fbgemm-gpu-genai-nightly 2025.12.23__cp311-cp311-manylinux_2_28_x86_64.whl → 2026.1.14__cp311-cp311-manylinux_2_28_x86_64.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
Files changed (46)
  1. fbgemm_gpu/batched_unary_embeddings_ops.py +0 -1
  2. fbgemm_gpu/config/feature_list.py +3 -0
  3. fbgemm_gpu/docs/target.genai.json.py +1 -1
  4. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  5. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +0 -4
  6. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +0 -3
  7. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +0 -1
  8. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +0 -1
  9. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +0 -2
  10. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +0 -3
  11. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +0 -3
  12. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  13. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +0 -1
  14. fbgemm_gpu/experimental/gen_ai/moe/activation.py +0 -1
  15. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +0 -1
  16. fbgemm_gpu/experimental/gen_ai/moe/layers.py +0 -2
  17. fbgemm_gpu/experimental/gen_ai/quantize.py +0 -1
  18. fbgemm_gpu/fbgemm.so +0 -0
  19. fbgemm_gpu/permute_pooled_embedding_modules.py +0 -1
  20. fbgemm_gpu/quantize/quantize_ops.py +0 -1
  21. fbgemm_gpu/quantize_comm.py +8 -13
  22. fbgemm_gpu/quantize_utils.py +61 -7
  23. fbgemm_gpu/sll/__init__.py +0 -1
  24. fbgemm_gpu/sll/triton/__init__.py +0 -10
  25. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +0 -1
  26. fbgemm_gpu/sparse_ops.py +0 -1
  27. fbgemm_gpu/split_embedding_configs.py +0 -1
  28. fbgemm_gpu/split_embedding_inference_converter.py +0 -1
  29. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +10 -11
  30. fbgemm_gpu/tbe/bench/bench_runs.py +0 -1
  31. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +2 -4
  32. fbgemm_gpu/tbe/bench/eeg_cli.py +0 -1
  33. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +0 -1
  34. fbgemm_gpu/tbe/bench/tbe_data_config.py +5 -3
  35. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +0 -2
  36. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +0 -1
  37. fbgemm_gpu/tbe/bench/utils.py +0 -1
  38. fbgemm_gpu/tbe/ssd/common.py +0 -1
  39. fbgemm_gpu/tbe_input_multiplexer.py +0 -1
  40. fbgemm_gpu/triton/quantize.py +13 -8
  41. fbgemm_gpu/uvm.py +0 -1
  42. {fbgemm_gpu_genai_nightly-2025.12.23.dist-info → fbgemm_gpu_genai_nightly-2026.1.14.dist-info}/METADATA +1 -1
  43. {fbgemm_gpu_genai_nightly-2025.12.23.dist-info → fbgemm_gpu_genai_nightly-2026.1.14.dist-info}/RECORD +46 -46
  44. list_versions/cli_run.py +0 -2
  45. {fbgemm_gpu_genai_nightly-2025.12.23.dist-info → fbgemm_gpu_genai_nightly-2026.1.14.dist-info}/WHEEL +0 -0
  46. {fbgemm_gpu_genai_nightly-2025.12.23.dist-info → fbgemm_gpu_genai_nightly-2026.1.14.dist-info}/top_level.txt +0 -0
fbgemm_gpu/batched_unary_embeddings_ops.py CHANGED
@@ -11,7 +11,6 @@
  from math import sqrt
 
  import torch
-
  from fbgemm_gpu.utils.loader import load_torch_module
 
  try:
fbgemm_gpu/config/feature_list.py CHANGED
@@ -63,6 +63,9 @@ class FeatureGateName(Enum):
  # Enable TBE input parameters extraction
  TBE_REPORT_INPUT_PARAMS = auto()
 
+ # Enable tuned max segment length per CTA for B200
+ TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200 = auto()
+
  def is_enabled(self) -> bool:
  return FeatureGate.is_enabled(self)
 
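A minimal sketch of how a caller could query the new gate, using only the FeatureGateName API visible in this hunk (the fbgemm_gpu.config import path is the one shown in the split_table_batched_embeddings_ops_training.py hunk later in this diff); whether the gate is enabled by default is not shown here:

    from fbgemm_gpu.config import FeatureGateName

    # is_enabled() delegates to FeatureGate.is_enabled(self), per the context lines above.
    if FeatureGateName.TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200.is_enabled():
        # Tuned max segment lengths per CTA for B200 are in effect.
        pass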
fbgemm_gpu/docs/target.genai.json.py CHANGED
@@ -1,6 +1,6 @@
 
  {
- "version": "2025.12.23",
+ "version": "2026.1.14",
  "target": "genai",
  "variant": "cuda"
  }
fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py CHANGED
@@ -12,9 +12,7 @@ from typing import Optional, Union
 
  import torch
  import triton # @manual
-
  import triton.language as tl # @manual
-
  from fbgemm_gpu.experimental.gemm.triton_gemm.matmul_perf_model import (
  early_config_prune,
  estimate_matmul_time,
@@ -23,10 +21,8 @@ from fbgemm_gpu.experimental.gemm.triton_gemm.utils import (
  map_dtype_to_triton,
  TmaAutoTuneHelper,
  )
-
  from packaging import version
  from torch._tensor import Tensor
-
  from triton import Config # @manual
  from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper # @manual
 
fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py CHANGED
@@ -9,14 +9,11 @@
  import functools
  import inspect
  import warnings
-
  from typing import Optional
 
  import torch
-
  import triton
  import triton.language as tl
-
  from fbgemm_gpu.experimental.gemm.triton_gemm import utils
  from triton.runtime import driver # @manual
 
fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py CHANGED
@@ -12,7 +12,6 @@ import functools
  import heapq
 
  import torch
-
  from triton import cdiv # @manual
  from triton.runtime import driver # @manual
  from triton.testing import ( # @manual
fbgemm_gpu/experimental/gemm/triton_gemm/utils.py CHANGED
@@ -9,7 +9,6 @@ import sys
 
  import torch
  import triton # @manual
-
  import triton.language as tl # @manual
 
 
fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py CHANGED
@@ -6,7 +6,6 @@
 
 
  import argparse
-
  import os
  import tempfile
  import uuid
@@ -15,7 +14,6 @@ from pprint import pprint
 
  import fbgemm_gpu.experimental.gen_ai # noqa: F401
  import pandas as pd
-
  import torch
  import torch.distributed as dist
  import torch.distributed._symmetric_memory as symm_mem
fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py CHANGED
@@ -7,16 +7,13 @@
  import itertools
  import os
  import sys
-
  from dataclasses import dataclass
  from datetime import datetime
  from typing import Any, Optional
 
  import click
-
  import matplotlib.pyplot as plt
  import numpy as np
-
  import pandas as pd
  import seaborn as sns
  import torch
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py CHANGED
@@ -9,10 +9,8 @@ import abc
 
  import fbgemm_gpu.experimental.gen_ai # noqa: F401
  import numpy as np
-
  import torch
  import triton # @manual=//triton:triton
-
  from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
  _to_blocked,
  calculate_group_max,
@@ -26,7 +24,6 @@ from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
  triton_scale_nvfp4_quant_rms,
  triton_scale_nvfp4_quant_silu,
  )
-
  from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
  get_fp8_constants,
  matmul_fp8_block,
fbgemm_gpu/experimental/gen_ai/moe/__init__.py CHANGED
@@ -56,7 +56,6 @@ from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import ( # noqa F401
  )
 
  from .activation import silu_mul, silu_mul_quant # noqa F401
-
  from .gather_scatter import ( # noqa F401
  gather_scale_dense_tokens,
  gather_scale_quant_dense_tokens,
fbgemm_gpu/experimental/gen_ai/moe/activation.py CHANGED
@@ -11,7 +11,6 @@ from typing import Optional
  import torch
  import triton
  import triton.language as tl
-
  from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import get_fp8_constants
 
 
fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py CHANGED
@@ -11,7 +11,6 @@ from typing import Optional
  import torch
  import triton
  import triton.language as tl
-
  from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import get_fp8_constants
 
 
fbgemm_gpu/experimental/gen_ai/moe/layers.py CHANGED
@@ -12,9 +12,7 @@ from functools import cached_property
  from typing import Callable, Optional, Union
 
  import torch
-
  from fairscale.nn.model_parallel.initialize import get_model_parallel_world_size
-
  from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_row
  from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
  grouped_gemm,
fbgemm_gpu/experimental/gen_ai/quantize.py CHANGED
@@ -10,7 +10,6 @@
 
 
  import torch
-
  from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
 
 
fbgemm_gpu/fbgemm.so CHANGED
Binary file
fbgemm_gpu/permute_pooled_embedding_modules.py CHANGED
@@ -11,7 +11,6 @@ from itertools import accumulate
  from typing import Optional
 
  import torch
-
  from fbgemm_gpu.utils.loader import load_torch_module
 
  try:
fbgemm_gpu/quantize/quantize_ops.py CHANGED
@@ -8,7 +8,6 @@
  from typing import Union
 
  import torch
-
  from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode
 
 
fbgemm_gpu/quantize_comm.py CHANGED
@@ -16,7 +16,6 @@ import logging
  from typing import Optional, TypeVar
 
  import torch
-
  from fbgemm_gpu.quantize_utils import (
  bf16_to_fp32,
  fp16_to_fp32,
@@ -25,12 +24,10 @@ from fbgemm_gpu.quantize_utils import (
  fp32_to_hfp8_with_clamp,
  fp32_to_mx4,
  hfp8_to_fp32,
- mx4_to_fp32,
+ mx4_to_float,
  RoundingMode,
  )
-
  from fbgemm_gpu.split_embedding_configs import SparseType
-
  from torch.autograd.profiler import record_function # usort:skip
  from dataclasses import dataclass
 
@@ -123,7 +120,7 @@ def _dequantize_tensor(
  comm_precision: SparseType,
  ctx: Optional[QuantizationContext] = None,
  is_fwd: bool = True,
- fp8_output_dtype: Optional[SparseType] = None,
+ output_dtype: Optional[SparseType] = None,
  ) -> torch.Tensor:
  if comm_precision == SparseType.FP32:
  assert quantized_tensor.dtype == torch.float
@@ -138,10 +135,8 @@ def _dequantize_tensor(
  if ctx is not None and ctx.row_dim > 0:
  row_dim_quant = ctx.row_dim_quant
  quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
- # use provided fp8_output_dtype or default to FP32 (0)
- output_dtype_int = (
- fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
- )
+ # use provided output_dtype or default to FP32 (0)
+ output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
  dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
  quantized_tensor_2d,
  is_fwd,
@@ -161,7 +156,7 @@ def _dequantize_tensor(
  return dequant_tensor.view(-1)
  elif comm_precision == SparseType.MX4:
  mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
- return mx4_to_fp32(quantized_tensor, mx_group_size)
+ return mx4_to_float(quantized_tensor, mx_group_size, output_dtype=output_dtype)
  else:
  raise ValueError(f"comm_precision={comm_precision} is not supported")
 
@@ -175,7 +170,7 @@ class QuantizedCommCodec:
  row_dim: Optional[int] = None,
  is_fwd: bool = True,
  rounding_mode: Optional[RoundingMode] = None,
- fp8_output_dtype: Optional[SparseType] = None,
+ output_dtype: Optional[SparseType] = None,
  ) -> None:
  if loss_scale is not None:
  if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -193,7 +188,7 @@ class QuantizedCommCodec:
  self._is_fwd = is_fwd
  self._row_dim: int = -1 if row_dim is None else row_dim
  self._rounding_mode: Optional[RoundingMode] = rounding_mode
- self._fp8_output_dtype: Optional[SparseType] = fp8_output_dtype
+ self._output_dtype: Optional[SparseType] = output_dtype
  if self._comm_precision == SparseType.MX4:
  self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
  self._rounding_mode = (
@@ -229,7 +224,7 @@ class QuantizedCommCodec:
  self._comm_precision,
  ctx,
  self._is_fwd,
- fp8_output_dtype=self._fp8_output_dtype,
+ output_dtype=self._output_dtype,
  )
  return dequantized_tensor
 
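The net effect of the quantize_comm.py hunks above is a rename of fp8_output_dtype to output_dtype plus routing the MX4 path through mx4_to_float(). A rough sketch of constructing the codec with the renamed argument follows; the comm_precision and loss_scale keywords are assumptions inferred from the __init__ body shown above, since the full constructor signature is not part of this hunk:

    from fbgemm_gpu.quantize_comm import QuantizedCommCodec
    from fbgemm_gpu.split_embedding_configs import SparseType

    codec = QuantizedCommCodec(
        comm_precision=SparseType.MX4,  # assumed keyword; referenced in the __init__ body above
        loss_scale=None,                # assumed keyword; referenced in the __init__ body above
        output_dtype=SparseType.BF16,   # renamed from fp8_output_dtype in this release
    )
    # The dequantize path now calls mx4_to_float(..., output_dtype=self._output_dtype),
    # so MX4 payloads can be dequantized directly to BF16 instead of FP32.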
fbgemm_gpu/quantize_utils.py CHANGED
@@ -13,10 +13,18 @@ from typing import Optional, Union
  import torch # isort:skip
 
  import fbgemm_gpu
-
- from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
+ from fbgemm_gpu.split_embedding_configs import SparseType
+ from fbgemm_gpu.triton.common import RoundingMode
  from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
 
+ try:
+ if torch.cuda.is_available():
+ from fbgemm_gpu.triton import quantize_mx4
+ from fbgemm_gpu.triton.quantize import triton_dequantize_mx4
+ except Exception:
+ pass
+
+
  try:
@@ -126,25 +134,71 @@ def mx4_to_fp32(
  ) -> torch.Tensor:
  """Dequantize an MX4 tensor to FP32 with triton or native cuda impl.
 
+ This function is kept for backward compatibility and always returns FP32.
+ For BF16 output, use mx4_to_float() with output_dtype=SparseType.BF16.
+ """
+ return mx4_to_float(
+ tensor,
+ group_size,
+ use_triton,
+ ebits,
+ mbits,
+ output_dtype=None, # None = FP32 default for backward compatibility
+ )
+
+
+ def mx4_to_float(
+ tensor: torch.Tensor,
+ group_size: int = 32,
+ use_triton: bool = True,
+ ebits: int = 2,
+ mbits: int = 1,
+ output_dtype: Optional[SparseType] = None,
+ ) -> torch.Tensor:
+ """Dequantize an MX4 tensor to FP32 or BF16 with triton or native cuda impl.
+
  Args:
  tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
  group_size (int): Compute scale in chunks of group_size.
  use_triton (bool): If set, use triton quantization, otherwise cuda.
  ebits (int): Number of exponent bits in target mx4 format.
  mbits (int): Number of mantissa bits in target mx4 format.
+ output_dtype (Optional[SparseType]): Output dtype (FP32 or BF16).
+ Defaults to None (FP32) for backward compatibility.
 
  Return:
- output: FP32 tensor with total elements (M).
+ output: Tensor with dtype matching output_dtype and total elements (M).
  """
+ # Validate output_dtype
+ supported_dtypes = {SparseType.FP32, SparseType.BF16}
+ if output_dtype is not None and output_dtype not in supported_dtypes:
+ raise ValueError(
+ f"output_dtype must be one of {supported_dtypes}, got {output_dtype}. "
+ f"FP16 is not supported due to potential overflow/underflow with MX4's wide exponent range. "
+ f"Use BF16 for memory savings with same dynamic range as FP32."
+ )
+
+ target_dtype = (
+ output_dtype.as_dtype() if output_dtype is not None else torch.float32
+ )
+
  # Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
  if not tensor.is_cuda and not tensor.is_mtia:
- return py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+ result = py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+ return result.to(target_dtype) if output_dtype is not None else result
  if use_triton:
  if tensor.is_mtia:
- return mtia_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
- return dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+ return mtia_dequantize_mx4(
+ tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+ )
+ return triton_dequantize_mx4(
+ tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+ )
  else:
- return torch.ops.fbgemm.dequantize_mx_cuda(tensor.flatten(), group_size)
+ output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
+ return torch.ops.fbgemm.dequantize_mx_cuda(
+ tensor.flatten(), group_size, output_dtype_int
+ )
 
 
  def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
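As a usage sketch of the API added above (illustrative tensors and group size; mx4_to_fp32 keeps its old FP32-only behavior by delegating with output_dtype=None):

    import torch
    from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_float, mx4_to_fp32
    from fbgemm_gpu.split_embedding_configs import SparseType

    x = torch.randn(1024, device="cuda")      # illustrative input
    packed = fp32_to_mx4(x, group_size=32)    # packing path is unchanged by this diff

    out_fp32 = mx4_to_fp32(packed, group_size=32)  # backward-compatible, always FP32
    out_bf16 = mx4_to_float(packed, group_size=32, output_dtype=SparseType.BF16)
    # Passing SparseType.FP16 raises ValueError, per the validation added above.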
fbgemm_gpu/sll/__init__.py CHANGED
@@ -8,7 +8,6 @@
  # pyre-strict
 
  import torch
-
  from fbgemm_gpu.sll.cpu import op_registrations as sll_cpu_registrations
  from fbgemm_gpu.sll.meta import op_registrations as sll_meta_registrations
  from fbgemm_gpu.utils import TorchLibraryFragment
fbgemm_gpu/sll/triton/__init__.py CHANGED
@@ -10,19 +10,16 @@
  from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
  dense_jagged_cat_jagged_out,
  )
-
  from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import ( # noqa F401
  jagged2_to_padded_dense,
  Jagged2ToPaddedDense, # noqa F401
  )
-
  from fbgemm_gpu.sll.triton.triton_jagged_bmm import ( # noqa F401
  jagged_dense_bmm,
  jagged_jagged_bmm,
  JaggedDenseBmm, # noqa F401
  JaggedJaggedBmm, # noqa F401
  )
-
  from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import ( # noqa F401
  array_jagged_bmm_jagged_out,
  ArrayJaggedBmmNopadding, # noqa F401
@@ -31,38 +28,31 @@ from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import ( # noqa F401
  triton_array_jagged_bmm_jagged_out, # noqa F401
  triton_jagged_jagged_bmm_jagged_out, # noqa F401
  )
-
  from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import ( # noqa F401
  jagged_dense_elementwise_add,
  JaggedDenseAdd, # noqa F401
  )
-
  from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import ( # noqa F401
  jagged_dense_elementwise_mul_jagged_out,
  JaggedDenseElementwiseMul, # noqa F401
  )
-
  from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import ( # noqa F401
  jagged_dense_flash_attention,
  JaggedDenseFlashAttention, # noqa F401
  )
-
  from fbgemm_gpu.sll.triton.triton_jagged_flash_attention_basic import ( # noqa F401
  jagged_flash_attention_basic,
  JaggedFlashAttentionBasic, # noqa F401
  )
-
  from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import (
  triton_jagged_self_substraction_jagged_out,
  )
-
  from fbgemm_gpu.sll.triton.triton_jagged_softmax import ( # noqa F401
  jagged2_softmax,
  Jagged2Softmax, # noqa F401
  jagged_softmax,
  JaggedSoftmax, # noqa F401
  )
-
  from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import ( # noqa F401
  multi_head_jagged_flash_attention,
  MultiHeadJaggedFlashAttention, # noqa F401
fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py CHANGED
@@ -7,7 +7,6 @@
  # pyre-unsafe
 
  import torch
-
  from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
  dense_to_jagged,
  jagged_to_dense,
fbgemm_gpu/sparse_ops.py CHANGED
@@ -11,7 +11,6 @@ from collections.abc import Sequence
  from typing import Callable, Optional
 
  import torch
-
  from fbgemm_gpu.split_embedding_configs import SparseType
  from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
  from fbgemm_gpu.utils.loader import load_torch_module
fbgemm_gpu/split_embedding_configs.py CHANGED
@@ -12,7 +12,6 @@ import itertools
  from typing import Any, Dict # noqa: F401
 
  import torch
-
  from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
  EmbeddingLocation,
  SplitState,
fbgemm_gpu/split_embedding_inference_converter.py CHANGED
@@ -13,7 +13,6 @@ import math
  from typing import cast, Optional
 
  import torch
-
  from fbgemm_gpu.split_embedding_configs import (
  FP8QuantizationConfig,
  QuantizationConfig,
fbgemm_gpu/split_table_batched_embeddings_ops_training.py CHANGED
@@ -26,7 +26,6 @@ from torch.autograd.profiler import record_function # usort:skip
 
  # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
  import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
-
  from fbgemm_gpu.config import FeatureGate, FeatureGateName
  from fbgemm_gpu.runtime_monitor import (
  AsyncSeriesTimer,
@@ -59,7 +58,6 @@ from fbgemm_gpu.tbe_input_multiplexer import (
  TBEInputMultiplexer,
  TBEInputMultiplexerConfig,
  )
-
  from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
  from fbgemm_gpu.utils.writeback_util import writeback_gradient
 
@@ -2764,20 +2762,21 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  self.prefetch_stream != forward_stream
  ), "prefetch_stream and forward_stream should not be the same stream"
 
- indices, offsets, _, vbe_metadata = self.prepare_inputs(
- indices,
- offsets,
- per_sample_weights=None,
- batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
- force_cast_input_types=False,
- prefetch_pipeline=self.prefetch_pipeline,
- )
-
  with self._recording_to_timer(
  self.prefetch_duration_timer,
  context=self.step,
  stream=torch.cuda.current_stream(),
  ):
+
+ indices, offsets, _, vbe_metadata = self.prepare_inputs(
+ indices,
+ offsets,
+ per_sample_weights=None,
+ batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+ force_cast_input_types=False,
+ prefetch_pipeline=self.prefetch_pipeline,
+ )
+
  self._prefetch(
  indices,
  offsets,
fbgemm_gpu/tbe/bench/bench_runs.py CHANGED
@@ -15,7 +15,6 @@ from subprocess import Popen
  from typing import Callable, Optional
 
  import torch
-
  from fbgemm_gpu.tbe.utils import b_indices, TBERequest
  from fbgemm_gpu.tbe.utils.common import get_device
 
fbgemm_gpu/tbe/bench/benchmark_click_interface.py CHANGED
@@ -8,12 +8,10 @@
  # pyre-strict
 
  import click
-
  from fbgemm_gpu.split_embedding_configs import SparseType
  from fbgemm_gpu.split_table_batched_embeddings_ops_common import BoundsCheckMode
-
- from .bench_config import TBEBenchmarkingHelperText
- from .tbe_data_config_loader import TBEDataConfigHelperText
+ from .bench_config import TBEBenchmarkingHelperText # usort:skip
+ from .tbe_data_config_loader import TBEDataConfigHelperText # usort:skip
 
 
  class TbeBenchClickInterface:
fbgemm_gpu/tbe/bench/eeg_cli.py CHANGED
@@ -9,7 +9,6 @@
 
  import click
  import torch
-
  from fbgemm_gpu.tbe.bench import IndicesParams
 
 
fbgemm_gpu/tbe/bench/embedding_ops_common_config.py CHANGED
@@ -12,7 +12,6 @@ from typing import Any, Optional
 
  import click
  import torch
-
  from fbgemm_gpu.split_embedding_configs import SparseType
  from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
  BoundsCheckMode,
fbgemm_gpu/tbe/bench/tbe_data_config.py CHANGED
@@ -13,10 +13,12 @@ import logging
  from typing import Any, Optional
 
  import torch
-
  from fbgemm_gpu.tbe.utils.common import get_device
-
- from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
+ from .tbe_data_config_param_models import (
+ BatchParams,
+ IndicesParams,
+ PoolingParams,
+ ) # usort:skip
 
  try:
  torch.ops.load_library(
fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py CHANGED
@@ -11,10 +11,8 @@ from typing import Optional
 
  import numpy as np
  import torch
-
  from fbgemm_gpu.tbe.bench.tbe_data_config import TBEDataConfig
  from fbgemm_gpu.tbe.utils.common import get_device, round_up
-
  from fbgemm_gpu.tbe.utils.requests import (
  generate_batch_sizes_from_stats,
  generate_pooling_factors_from_stats,
fbgemm_gpu/tbe/bench/tbe_data_config_loader.py CHANGED
@@ -13,7 +13,6 @@ from enum import Enum
  import click
  import torch
  import yaml
-
  from fbgemm_gpu.tbe.bench.tbe_data_config import (
  BatchParams,
  IndicesParams,
fbgemm_gpu/tbe/bench/utils.py CHANGED
@@ -10,7 +10,6 @@ import logging
 
  import numpy as np
  import torch
-
  from fbgemm_gpu.split_embedding_configs import SparseType
 
  logging.basicConfig(level=logging.DEBUG)
fbgemm_gpu/tbe/ssd/common.py CHANGED
@@ -9,7 +9,6 @@
  # pyre-ignore-all-errors[56]
 
  import torch
-
  from fbgemm_gpu.utils.loader import load_torch_module
 
  try:
fbgemm_gpu/tbe_input_multiplexer.py CHANGED
@@ -8,7 +8,6 @@
  # pyre-unsafe
 
  import abc
-
  from dataclasses import dataclass
  from typing import Optional
 
fbgemm_gpu/triton/quantize.py CHANGED
@@ -11,7 +11,6 @@ from typing import Union
  import torch
  import triton # @manual
-
  import triton.language as tl # @manual
 
  from .common import get_mx4_exp_bias, get_mx4_lookup_table, RoundingMode
 
@@ -575,7 +574,7 @@ def _kernel_dequantize_mx4(
  # Write final outputs.
  tl.store(
  out + output_offset,
- scaled_fp32,
+ scaled_fp32.to(out.dtype.element_ty),
  # Mask values that are out of this chunk or the main array.
  mask=(output_offset < OUTPUT_SIZE)
  & (output_offset < OUTPUT_CHUNK_SIZE * (pid + 1)),
@@ -588,10 +587,14 @@
 
 
  def triton_dequantize_mx4(
- a: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1
+ a: torch.Tensor,
+ group_size: int = 32,
+ ebits: int = 2,
+ mbits: int = 1,
+ output_dtype: torch.dtype = torch.float32,
  ) -> torch.Tensor:
  """
- Dequantize a tensor from mx4 format to fp32.
+ Dequantize a tensor from mx4 format to fp32 or bf16.
 
  Args:
  a (Tensor): [M / 2 + M / group_size] MX4 tensor packed into int8 values
@@ -599,13 +602,15 @@ def triton_dequantize_mx4(
  group_size (int): Size of chunks that use the same shared exponent.
  ebits (int): Number of bits to use for exponent in target mx4 format.
  mbits (int): Number of bits to use for mantissa in target mx4 format.
+ output_dtype (torch.dtype): Output dtype (FP32 or BF16).
+ Defaults to torch.float32 for backward compatibility.
 
  Returns:
- torch.Tensor: [M, K] dequantized fp32 tensor.
+ torch.Tensor: [M, K] dequantized tensor in the specified dtype.
  """
  # If given an empty shape, return an empty tensor.
  if a.numel() == 0:
- return torch.empty(a.shape, device=a.device, dtype=torch.float32)
+ return torch.empty(a.shape, device=a.device, dtype=output_dtype)
  # View a as 2D for simplicity.
  orig_shape = a.shape
  a = a.flatten()
@@ -622,9 +627,9 @@
  # Use a lookup table to convert
  mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, a.device)
 
- # Create output tensor.
+ # Create output tensor in target dtype.
  output_elems = num_groups * group_size
- out = torch.empty([output_elems], device=a.device, dtype=torch.float)
+ out = torch.empty([output_elems], device=a.device, dtype=output_dtype)
  # Check if we need to use int64 for indexing.
  use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1
  # Invoke triton dequantization kernel over rows.
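A short sketch of calling the updated Triton wrapper directly with the new output_dtype parameter (the packed input here is produced with fp32_to_mx4 from quantize_utils, purely for illustration):

    import torch
    from fbgemm_gpu.quantize_utils import fp32_to_mx4
    from fbgemm_gpu.triton.quantize import triton_dequantize_mx4

    packed = fp32_to_mx4(torch.randn(4096, device="cuda"), group_size=32)
    out = triton_dequantize_mx4(packed, group_size=32, output_dtype=torch.bfloat16)
    assert out.dtype == torch.bfloat16  # the kernel now stores via out.dtype.element_ty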
fbgemm_gpu/uvm.py CHANGED
@@ -11,7 +11,6 @@ from enum import Enum
  from typing import Optional
 
  import torch
-
  from fbgemm_gpu.enums import create_enums
 
  try:
{fbgemm_gpu_genai_nightly-2025.12.23.dist-info → fbgemm_gpu_genai_nightly-2026.1.14.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fbgemm_gpu_genai_nightly
- Version: 2025.12.23
+ Version: 2026.1.14
  Home-page: https://github.com/pytorch/fbgemm
  Author: FBGEMM Team
  Author-email: packages@pytorch.org
{fbgemm_gpu_genai_nightly-2025.12.23.dist-info → fbgemm_gpu_genai_nightly-2026.1.14.dist-info}/RECORD CHANGED
@@ -1,29 +1,29 @@
  fbgemm_gpu/__init__.py,sha256=bL2dL7uYeXb1GvdjIDUTcLXLRGNfmnI4MQoE3-Gg5m8,6361
  fbgemm_gpu/asmjit.so,sha256=RxTYI8zY4PpIBRpSKT_-U7bRIVeTRohdtRFUmLNU1tQ,501728
- fbgemm_gpu/batched_unary_embeddings_ops.py,sha256=GYeJ9pg-Wc9FokXVci_npDsL6UV18-pJXID2xzrJ9O8,2904
+ fbgemm_gpu/batched_unary_embeddings_ops.py,sha256=Zst_OhYCBgbNMWfUADp1W1pGL1pT5t_8XX2q-QT50TI,2903
  fbgemm_gpu/enums.py,sha256=37ewGSfO1x7sO31ZkRiqV1yKuklfHXT5qZIxzeeGogo,755
- fbgemm_gpu/fbgemm.so,sha256=U864UANx-CVyFYk5ADawCd0uWRfntHaVcyl6AVty_3Q,5642616
+ fbgemm_gpu/fbgemm.so,sha256=pW04240G6WyXyaVJszQM0R7p8Jr1ZoyfblH5OJsmCyo,5675384
  fbgemm_gpu/metrics.py,sha256=TsurFLJf0nJvPDN7urWb4LMQlf5RgdWPTTTDO7S4wtI,5663
- fbgemm_gpu/permute_pooled_embedding_modules.py,sha256=vOXMYclaGnwSt0St_SOAlAe18kz6WjMyTeHnC9jLhcE,5130
+ fbgemm_gpu/permute_pooled_embedding_modules.py,sha256=dGQ8o3wN0yaLj8adx4oR6ncmkOH3PT7_zGZ8yYTnnk0,5129
  fbgemm_gpu/permute_pooled_embedding_modules_split.py,sha256=f3VJvH_kw9Ltd_DXtaf_PJPHmlmEWrQgzQ7MDkhh5Nw,2746
- fbgemm_gpu/quantize_comm.py,sha256=ZfXtRHfqpVpV6k2PDL6oTUkKYzopqAV2M6vavp_RLSM,12022
- fbgemm_gpu/quantize_utils.py,sha256=q8Aokk6nlHbXF6HcDBbhBCAGSZV4klM8uPF-MUFFtAw,8324
+ fbgemm_gpu/quantize_comm.py,sha256=yKKDJF_aMIYJG_22KG4BX1-AF_88ulgOXLvRO2a4RNI,11980
+ fbgemm_gpu/quantize_utils.py,sha256=sROgIdOrAjQT5_CmFafg40GMo0-pe4d56bAZTI57548,10243
  fbgemm_gpu/runtime_monitor.py,sha256=YXRUv6nXCsoTgh5_RzailTGvCYzwoYDb-eR4rlGwtaw,7619
- fbgemm_gpu/sparse_ops.py,sha256=_EJC1pAbNnAnVQQ5JBg4DAV2TboIj-4XQkiKMmg1vXI,50417
- fbgemm_gpu/split_embedding_configs.py,sha256=EuVFKIDrgRQpRC5mmB4Du6WftK5GXJvDue9_ezt_eBI,16575
- fbgemm_gpu/split_embedding_inference_converter.py,sha256=AghGW22MgMsdHzdwdPMPYDjgas5AE_estckY8rMgXVU,7056
+ fbgemm_gpu/sparse_ops.py,sha256=uCmtitnCJnDAIq1TCYvk24COyUnbvjIHVob37JgSDkg,50416
+ fbgemm_gpu/split_embedding_configs.py,sha256=awc9gAhCsRulXmQM089gxJwW0G3PeIw48gUesf13AKc,16574
+ fbgemm_gpu/split_embedding_inference_converter.py,sha256=rKILaM_C5Y-4Ypl1uHG4pZfiMZ-XlzjMwgik4X-wWeU,7055
  fbgemm_gpu/split_embedding_optimizer_ops.py,sha256=wXuGazClBMk62yL_r9udUIKaPgQP7SlkSb5ugB75wrQ,711
  fbgemm_gpu/split_embedding_utils.py,sha256=Gb40ZKeATxIKEKI3aVQMgDDBanNpKMc53Z43mnzdR_I,851
  fbgemm_gpu/split_table_batched_embeddings_ops.py,sha256=_MIp6uHYHLn4GxGdrGsfddfSsZ2Z9mjsYIrih3ncI1I,2339
  fbgemm_gpu/split_table_batched_embeddings_ops_common.py,sha256=eFxb_bDfBV8G76pmd-SxDXXXnqgbuGYOS4pSU8JS5dg,19295
  fbgemm_gpu/split_table_batched_embeddings_ops_inference.py,sha256=dGC85xjQiRUrequBibSf9oMAVHT5Q49zsVo2zW4n_88,81679
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py,sha256=rNGMELM_xFIsdS_340PB7bsn9h_VjONq_JJG1SjHyvQ,188992
+ fbgemm_gpu/split_table_batched_embeddings_ops_training.py,sha256=kzTVo_o7ouCdPuGdziPSz3LZbEi3jI0aTLp4u7fuWRs,189023
  fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py,sha256=jofAN2UB_iSk53Id6MBvn9Bi3Qxw67IL0_VE_EHlw_Q,7593
  fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py,sha256=7qGkO8FARku38mFYl4Bc4qL8dS1wrfyorS9l1m5ZAVA,718
- fbgemm_gpu/tbe_input_multiplexer.py,sha256=TQjwkJ2JkOaQsMYuRdk9RbNa9759EPEtx8bYclChtZY,3063
- fbgemm_gpu/uvm.py,sha256=guNK8ZzR80jmv-CyRgEhxhVYhjz3R9d6tB8Hu1uWDUo,1047
+ fbgemm_gpu/tbe_input_multiplexer.py,sha256=MbZF8aZdm_kV-JRMaooeZrqlh6Pn5IuNkSXBXODp-LE,3062
+ fbgemm_gpu/uvm.py,sha256=V6LvMN7_Oc0YifB6AgwD37ymZzyZO9ydDWany1FoDf0,1046
  fbgemm_gpu/config/__init__.py,sha256=yN0KAneCICgF2BTfOYGsd0qU1PvZX_6msC6YHHZKLMg,292
- fbgemm_gpu/config/feature_list.py,sha256=iDOGr9nwTqUhWsqOefRIqIo1jwLSeSII4jGnLeU01kg,2359
+ fbgemm_gpu/config/feature_list.py,sha256=hhDNkkafd-Oetvuqv9ylBVTNM-lKPi029mpRqq-JZCA,2467
  fbgemm_gpu/docs/__init__.py,sha256=DR6hMSQrsZALfH2AnuJQ4Zq2CfBUUhMN8YjD6APjiAE,523
  fbgemm_gpu/docs/common.py,sha256=8ipXTwVb222X-aZ71O6n8fhxHCHPNhJEHMFiO7epcIs,273
  fbgemm_gpu/docs/examples.py,sha256=ZMN_6sL74LH_hrp2bF_hmg8gi29GhcgvwV3kCMjxkoE,2377
@@ -32,47 +32,47 @@ fbgemm_gpu/docs/merge_pooled_embedding_ops.py,sha256=oJLgSgZQmhsyGLbTmZTxNgQrk65
  fbgemm_gpu/docs/permute_pooled_embedding_ops.py,sha256=tZUqLVXlk5O6VAKKDA-OEMx2fCu5QPOOeoAPZA9_nLY,4454
  fbgemm_gpu/docs/quantize_ops.py,sha256=xTtOaVK1P02ymreE_i21YiyYDZCqhoZY9eWp_mEIRlo,1297
  fbgemm_gpu/docs/sparse_ops.py,sha256=gSLUFdnu8lle_6gLewFkM20wL3ek2jKLvDGMKR6POaY,27292
- fbgemm_gpu/docs/target.genai.json.py,sha256=_wCZZFTZPnoCRTnmtXfpjGrdRZuWl7T171wr-JhtC-Y,79
+ fbgemm_gpu/docs/target.genai.json.py,sha256=ruseG1ciUe-WwzabUt1S-x9bEycf4pNFzFtE7-nSnuk,78
  fbgemm_gpu/experimental/example/__init__.py,sha256=OvJHZgWnycL1gWKyCXFJCTKuys3KAqx4iadjx3R-tBQ,723
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so,sha256=ebm5zEzVjAj-j6DP1W41ZD2_UB4DrV-3xEq9iIAkCqg,190656
+ fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so,sha256=nZblYdz9XBtJD3YAP1GZkCLIryI1DDh5ri9rb0pR90Y,358592
  fbgemm_gpu/experimental/example/utils.py,sha256=Je__VkMlBMLOhh7NXOocOdvaa2gz9kl9Dkqeu25tpFA,562
  fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py,sha256=1CqUfzlYyXTvU-BNaUq4RZpLV-2lKAVCAHeJzSIZFWw,419
  fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py,sha256=R4VNZdPSgmRmwDfTt2CShED2SGUF6dCXSUW2C4LISgE,215713
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py,sha256=KrI-wZeIf4AqcjXo5XoxAUWzOeM5MHTvhKBKzbQ-Hc0,153178
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py,sha256=5ClZ-GDrx6q0uaqWOOmKGVANBQfAd1KFBt0LneFeZDY,42364
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py,sha256=SltbY_dsit5e7B8lDIB_VYPrEq0t9kckthj9mQaVNfA,7571
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py,sha256=rULXIpVaaRS3GKUZ1RHcWUrUyy0xMVREwS1SFShGgcw,4302
+ fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py,sha256=8I3qGh9lzio3Wt67X0Vt0aZvkqcecyO5mpktHRrl8jc,153174
+ fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py,sha256=OXvsVGtULWPYIyWXqdvRf_v-ZgeG5qiDCdmjbvmR2nE,42361
+ fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py,sha256=_7qbOVZAPvaTBAEFA5lvQIFtcgd-iCXAZ4KWlwEkcAE,7570
+ fbgemm_gpu/experimental/gemm/triton_gemm/utils.py,sha256=HR4sVGYswh_h3aSGUoZrN76WX01mTYCGDVMdCXt9Ruc,4301
  fbgemm_gpu/experimental/gen_ai/__init__.py,sha256=r3NlNCXuIh0pfKwKU5v14y6AZkpoIkKWbtzxSprgeKA,1713
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so,sha256=eCZ54iUjb6Z4A1IJcGwiZVm2uwjF6yDSHl2ZEWlokC8,65238760
- fbgemm_gpu/experimental/gen_ai/quantize.py,sha256=KAljWSdN-1_c5DWfT-3MDxWLMULK49Yu36t6TmQI9Tw,12599
+ fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so,sha256=NBnThgRqEGzgO0DMhmF2fQcOGqQyEp8sAScSL-rFafs,229525416
+ fbgemm_gpu/experimental/gen_ai/quantize.py,sha256=EOfTJI2efb37hivgJd__xe8-YdWRzCBbGpXd4rSu-ck,12598
  fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py,sha256=-R_LxyHpdXMILU9TNuYoRisBCkfK0_VLyixefaeZf4g,1463
  fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py,sha256=gbhNU3mDTKJb3yt3inIDbiUjX_SG1oZfzgDygtHvMpk,10101
  fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py,sha256=fD39_WH7TfNCiP5Vl46ToX6PsLMLUFLhizT26Qe7TWg,17282
  fbgemm_gpu/experimental/gen_ai/bench/__init__.py,sha256=XpAK_eyqDSKeFC5J9KpnKtbZG07mrDh9d2j1LFKzr-8,404
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py,sha256=ApEyJOf_rdIo8V_EgvhZXBGNov8ITC_dnB95v8szulI,8515
+ fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py,sha256=aQbX9JzNeC_7Ka2EjJhShBWCgOmDg3bDYXWHhipYjps,8513
  fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py,sha256=K9Nib6D7xJbw1QwEVuCJrVyI1qs988moo3cieVKYuFY,12057
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py,sha256=BWl6t-4acbuRSEX2aVNDlFrSWZkqMWK2sI3VONaMd3Q,24047
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py,sha256=Kq4zSfxrzmSL75RWWdhPSTWq3AxClu_RO3onn5vzx8s,104983
+ fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py,sha256=7OiaaOvVIQJLNgxEeqW6t8ZkFtXRd7js-6ZAJ29zuRs,24044
+ fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py,sha256=bl23zzD4LanvTinNHMrlcCOxBvsRiTCnxBpB-Ed4yO0,104980
  fbgemm_gpu/experimental/gen_ai/moe/README.md,sha256=z9ybHmv4KFJ1drj5OByuFaOY0tRQwwiIW3Q22TB_2-k,904
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py,sha256=lwSvff07yEav024B1XyfgW8r8hwNe--aEDywcO7rnbM,1905
- fbgemm_gpu/experimental/gen_ai/moe/activation.py,sha256=NiXhWyCNagI3P9N3N89iSX7xKuShdkq9DxEUAzoV6y0,7892
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py,sha256=8inrE4dkpfO9NFkrmXyXOCM262LMcTA3SQldxPoosT8,21044
- fbgemm_gpu/experimental/gen_ai/moe/layers.py,sha256=QLwoKjyYUHT5vXAvp_maRSxyruwGXaNURgtW8ataVyg,42693
+ fbgemm_gpu/experimental/gen_ai/moe/__init__.py,sha256=SeASfWgbuYq4p6_YIax-8KhRFaqyL5933dadUKRJNgo,1904
+ fbgemm_gpu/experimental/gen_ai/moe/activation.py,sha256=GeIcBKXpfvJWSn1P0nlbMqzuLYvlyyaZ8pQsSf1GHT0,7891
+ fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py,sha256=I_pNn2-VUD9_tiEvLlJMJJFP3orOepnfQ--75ytnfIo,21043
+ fbgemm_gpu/experimental/gen_ai/moe/layers.py,sha256=fuCPsRzM62zrfW5lHkjtpbyRi9YP3ytp0GgYmwweHa8,42691
  fbgemm_gpu/experimental/gen_ai/moe/shuffling.py,sha256=VDGEUdLZyj6mblJkAIReLICxU5BGnvmUjgZDP0VVqt8,11077
  fbgemm_gpu/quantize/__init__.py,sha256=pftciXHE7csekDFkl7Ui1AWglVMMnSrOO04mREnUdb0,921
- fbgemm_gpu/quantize/quantize_ops.py,sha256=25AIOv9n2UoxamMUaI6EK1Ur4gSHxbZIReHBtgOjjCs,2228
- fbgemm_gpu/sll/__init__.py,sha256=rgXh35-OFUE54E9gGBq3NGxouGvgMv2ccY2bWUTxONY,4191
+ fbgemm_gpu/quantize/quantize_ops.py,sha256=BhOS3PPKJ6-UFyKFYBB3qtRESSDmHo0UKl2zlXKeKhQ,2227
+ fbgemm_gpu/sll/__init__.py,sha256=dvFBTqA7Rw8bvZclAAH-l1eMxD9-haQ9lKYUnZXCmIM,4190
  fbgemm_gpu/sll/cpu/__init__.py,sha256=glsukNpXtf47VRIdBktILD-4CmVcf4621SGB55lT_ho,2692
  fbgemm_gpu/sll/cpu/cpu_sll.py,sha256=3zRsDZKCFPly1EZWl4LNB3ABJVy4JM4RVwmDuUeJZzc,27870
  fbgemm_gpu/sll/meta/__init__.py,sha256=2sMcD67XGsweBZ-UV2AEJmM4ELPsHeRAYED6kqfgAd4,1077
  fbgemm_gpu/sll/meta/meta_sll.py,sha256=Jk14EOW9VPFwawD7Bwky0R0A5rmbcLWMo52oH8J6Koc,8305
- fbgemm_gpu/sll/triton/__init__.py,sha256=dW_cEW0R8635sKLozsL88SP0Cch5QnBGvfnAmoqWMic,4109
+ fbgemm_gpu/sll/triton/__init__.py,sha256=ndvZ5OO81KP65HopJql91R9y_5fC88WnNIGYxCAVKwM,4099
  fbgemm_gpu/sll/triton/common.py,sha256=hISlX4Y-7FtGof-Xx4_B8-2vlF27F9t4p2qyLMUnJ8A,798
  fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py,sha256=J9qOqjNJ72LUBqs-pGI9wrFzzzBpsZ5fzYjgfKc2YhY,1885
  fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py,sha256=M_AMJfW9D67xa4ezhmBViKsrt_n9EiX-Ki_drI5K3Bo,5925
  fbgemm_gpu/sll/triton/triton_jagged_bmm.py,sha256=QFhaIQc8g-TRHr7wjm-Wd-atNJS1fDDkImHXXB3v-gU,11789
  fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py,sha256=hccLxsKoSZKiWid5P_yl-IVdBSXw1Rt0WeiRsjLD2Iw,13864
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py,sha256=_0hke_aaAdKQJpGUYX20NLss1_cXDIKxqblX4QQb7Io,1592
+ fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py,sha256=HpSn4BPFHAODTmXAsZUibAppL1x7qI50vpQhA_p98OE,1591
  fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py,sha256=9R7BOOe8SJiko1PgbiuHlFyPKtGaaCFSlZ1RaEQyICE,4198
  fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py,sha256=nebxJ7-1muDn-1oEuE46NbYbr6BcsPcuTOsQ49nCchI,22783
  fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py,sha256=po9Nx4uAGVu_YIZ9CWvrmzSwxDsnDuNAtnk9VR7-Ems,17750
@@ -82,22 +82,22 @@ fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py,sha256=nEo5I-b
  fbgemm_gpu/tbe/__init__.py,sha256=fE0IHi1JJpxsNVBNzWNee2thrNXFFRhY94c80RxNSIE,231
  fbgemm_gpu/tbe/bench/__init__.py,sha256=wgPBmxtQMmbA39cbQ2nO4PGAk5lXjFGjG8-9FoAXg34,1589
  fbgemm_gpu/tbe/bench/bench_config.py,sha256=xgtlGLCeZVW6jBYwkKsiQeCslCrWDgJbV2NLLwCRSn4,5452
- fbgemm_gpu/tbe/bench/bench_runs.py,sha256=vCblxjwvpzZ5oBxd6Z9fYy2KYmI--ySYlqRw_PLPX3k,23507
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py,sha256=Ey-3Rx4jfzam4QnYs-pNIe-UJvgmoeeM0zZ4C5j5ZuU,6891
- fbgemm_gpu/tbe/bench/eeg_cli.py,sha256=DuF0pjy1wjrGaqsf1Bo9IP_q5nNx237cv9j80pG5aCk,3569
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py,sha256=CXwupJIhtDQiOedqSYhJyXbiMOikML5torrXb5hqt2Y,4967
+ fbgemm_gpu/tbe/bench/bench_runs.py,sha256=CBxO3Jad091cuD3ARr3UxRdGWrpWJkop44Tc17_OaeM,23506
+ fbgemm_gpu/tbe/bench/benchmark_click_interface.py,sha256=_e86jTLSWxYSkj8aiHm53kVzPJirDZWRl4S_Zd5FuOo,6917
+ fbgemm_gpu/tbe/bench/eeg_cli.py,sha256=n7_9L2dbb2F65BSABH50HRzRQFgujnPESjzuHSVjG_U,3568
+ fbgemm_gpu/tbe/bench/embedding_ops_common_config.py,sha256=WvoPvw-pY7gHQuJZlcU5RL87-pDcKKdMPH5wwUUOmAc,4966
  fbgemm_gpu/tbe/bench/eval_compression.py,sha256=ulFMaNZF2g_vfkXLWZSh02ibotg1zpTz3swVU484mzU,3486
  fbgemm_gpu/tbe/bench/reporter.py,sha256=ZK5RFolUmZEcsEaife270_iOdXAQD5EjTUkuxctnAbY,804
- fbgemm_gpu/tbe/bench/tbe_data_config.py,sha256=M0lK6m3S7Kl34prQcC3z8POr93FgX1oEUZ6MdVXZq5M,4794
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py,sha256=tgNB_3qWqWpjR86BhgRSU74bdW_ilRjtG61Cxmy1_Vk,10923
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py,sha256=MNddYzoRlu0mNhnsVVG57JN7pBAepfaRL7UCEzS2KoI,10007
+ fbgemm_gpu/tbe/bench/tbe_data_config.py,sha256=zV8gzA9wcpDqh8y9JC9mUCEt-_6IxrcJn3SlvpqMBo4,4823
+ fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py,sha256=A-iFmpRZkMMH2qgJuAoRplH5CyT1MUFTvgSDf1n6e4A,10921
+ fbgemm_gpu/tbe/bench/tbe_data_config_loader.py,sha256=2pz1HBhQ4UP6dHtxECdxWUhEb05wv6ZkG1u33Sy1EJA,10006
  fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py,sha256=sptdqcNE9JlgyIJ17neZaMxagKG469_ynX0mVx_JKBY,6090
- fbgemm_gpu/tbe/bench/utils.py,sha256=cq_6FJHlgZ5femAK6XKpj7nJ9jc03qXI16N1ht1CcLg,1721
+ fbgemm_gpu/tbe/bench/utils.py,sha256=kxc3mqsZKq_tjlCN65TPevuKt6JUvwZs9LN8lu8Pfds,1720
  fbgemm_gpu/tbe/cache/__init__.py,sha256=lrYwhvqX2eWN0vAPe89HYgMW_O1vccoOcoFHJ9cyM-s,398
  fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py,sha256=VmG9EennGcq2By8Tj8VkFsJG0oOCGw8EhlPo8-t--Fk,14604
  fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py,sha256=vZHj7KIe1DoJDy5eft29XtGg6I-tRx60tjKOcTHRAYI,1321
  fbgemm_gpu/tbe/ssd/__init__.py,sha256=wzfMT10cp_dqK2lrebC449hOdexBnizcf_98lA1NyHs,483
- fbgemm_gpu/tbe/ssd/common.py,sha256=1J8K7sTQswgCYWaVwF-ZdCJj7mNN6O9GI70AaZWzJGE,1044
+ fbgemm_gpu/tbe/ssd/common.py,sha256=QP9Cz2t3dxzSQ2P4x0R2ekQY2Dk1TzijqXBdMJ-uLkQ,1043
  fbgemm_gpu/tbe/ssd/inference.py,sha256=B_uX66ajGA9YKGlFa5TmGWs7b-b1RFigzwxmENZ9Oio,22816
  fbgemm_gpu/tbe/ssd/training.py,sha256=C6M3H_f8oWWRkC4R-BJED73au-Gl9SUVllxOoFSiDkI,212234
  fbgemm_gpu/tbe/ssd/utils/__init__.py,sha256=5DgmR2HA6NtmYh2ddkUgpDsZ6a7hF0DPedA1gMpdh18,250
@@ -111,7 +111,7 @@ fbgemm_gpu/tbe/utils/quantize.py,sha256=icN2MXnl5rNqtKhGKkjpelx5pYBMYUv-6CrghxeV
  fbgemm_gpu/tbe/utils/requests.py,sha256=rQkEoaUUWEYCQM-1K_Lxg1wPcyIVw8sbdaGFTpsaE5I,18040
  fbgemm_gpu/triton/__init__.py,sha256=kPn_Ye6J9DAzWtqi76KYGwfKSqw0IhqG3Bir5aUpkWM,658
  fbgemm_gpu/triton/common.py,sha256=wnkLd2a8fKpefymLL-LjNKEL4hDVSxFiF5g3aF8mzsw,2131
- fbgemm_gpu/triton/quantize.py,sha256=z3y74-DCbGcQDsO70b2jK_HQDIYC0UJ7IEG2vvMu0_Y,26816
+ fbgemm_gpu/triton/quantize.py,sha256=I0pxyfIx04zyq55x4Pvj-28Cb2ZeF-SGtFhAymFagkg,27073
  fbgemm_gpu/triton/quantize_ref.py,sha256=q4RBmFaqPVPELU52lbSgB0n26Aun7apeK7bRF2MWS80,11553
  fbgemm_gpu/triton/jagged/__init__.py,sha256=om0yhjuzKuE1UQakFMWHsXN4WNb8mvNkZtYofQ8hdn4,246
  fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py,sha256=F2eQWjkWMR5RWQ48oIr-8OU_CRZyLazDpT7DFrDWS6g,29871
@@ -121,8 +121,8 @@ fbgemm_gpu/utils/loader.py,sha256=1hCEhNvkflniH46fGcrguLeP1z-6uyOu2QFwqKU5CIM,99
  fbgemm_gpu/utils/torch_library.py,sha256=ywsAHjbuwesj50LjEu99WkAH17FlaVgePZ9OmFg6YE4,4193
  fbgemm_gpu/utils/writeback_util.py,sha256=PyVbHp1EuF-GKrJv_CTP6B50Z0oBblXKucf7Rhd6KKY,4614
  list_versions/__init__.py,sha256=UmTeqCk-UJWFtlZQWvZao3xvui2w9E3X_JdOXVjRaNw,315
- list_versions/cli_run.py,sha256=CChZoXQ-tiKaWboXAYlPVJ5w8K5zAKiKcncA087I1sc,4508
- fbgemm_gpu_genai_nightly-2025.12.23.dist-info/METADATA,sha256=f_OA5iQSJM23ogn2epFKT0jnDM5ggh7Xg6c7FtDT0ag,2657
- fbgemm_gpu_genai_nightly-2025.12.23.dist-info/WHEEL,sha256=V2Q6mQKbouIadCxoRjt9FQ9oKfi45-uZUcoc77zzs0M,108
- fbgemm_gpu_genai_nightly-2025.12.23.dist-info/top_level.txt,sha256=_2s1Aa08r_eDn0JP4FjOhzK09Q8bVlEI7q8pMep51UY,25
- fbgemm_gpu_genai_nightly-2025.12.23.dist-info/RECORD,,
+ list_versions/cli_run.py,sha256=BCRaJvjVFBFmD5WPdjC_yJwlLv1w_TYOe3eYlf_9ZMo,4506
+ fbgemm_gpu_genai_nightly-2026.1.14.dist-info/METADATA,sha256=oPv8amMA9l2QQ5sKkeeABPCccKVbYAZkVJX7314f1kY,2656
+ fbgemm_gpu_genai_nightly-2026.1.14.dist-info/WHEEL,sha256=V2Q6mQKbouIadCxoRjt9FQ9oKfi45-uZUcoc77zzs0M,108
+ fbgemm_gpu_genai_nightly-2026.1.14.dist-info/top_level.txt,sha256=_2s1Aa08r_eDn0JP4FjOhzK09Q8bVlEI7q8pMep51UY,25
+ fbgemm_gpu_genai_nightly-2026.1.14.dist-info/RECORD,,
list_versions/cli_run.py CHANGED
@@ -13,9 +13,7 @@ from datetime import datetime
  from typing import Union
 
  import click
-
  import pandas as pd
-
  import torch
 