fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0

fbgemm_gpu/quantize/__init__.py
@@ -0,0 +1,43 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx  # noqa F401
+ from fbgemm_gpu.utils import TorchLibraryFragment
+
+ lib = TorchLibraryFragment("fbgemm")
+
+ lib.define(
+     """quantize_mx(
+         Tensor input,
+         int scale_bits,
+         int elem_ebits,
+         int elem_mbits,
+         float elem_max_norm,
+         int mx_group_size,
+         int? rounding_mode = None
+     ) -> Tensor
+     """
+ )
+
+ lib.define(
+     """dequantize_mx(
+         Tensor input,
+         int mx_group_size
+     ) -> Tensor
+     """
+ )
+
+ lib.register(
+     "quantize_mx",
+     {"CUDA": quantize_mx, "CPU": quantize_mx},
+ )
+
+ lib.register(
+     "dequantize_mx",
+     {"CUDA": dequantize_mx, "CPU": dequantize_mx},
+ )
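
The fragment above only declares schemas and dispatch entries. Below is a minimal sketch (not part of the package) of how the resulting operators could be exercised through the PyTorch dispatcher, assuming the wheel and its extensions import cleanly on the host. The tensor shape, group size, and keyword values are illustrative and simply mirror the schema defaults.

import torch

import fbgemm_gpu.quantize  # noqa: F401  # runs the lib.define()/lib.register() calls above

# 64 FP32 values -> two MX4 groups of 32 elements each.
x = torch.randn(64, dtype=torch.float32)

packed = torch.ops.fbgemm.quantize_mx(
    x,
    scale_bits=8,       # shared exponent bits (MX4 e2m1)
    elem_ebits=2,
    elem_mbits=3,
    elem_max_norm=6.0,
    mx_group_size=32,
)                       # rounding_mode is left at its schema default (None)

restored = torch.ops.fbgemm.dequantize_mx(packed, 32)
assert restored.numel() == x.numel()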

fbgemm_gpu/quantize/quantize_ops.py
@@ -0,0 +1,64 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from typing import Union
+
+ import torch
+
+ from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode
+
+
+ def quantize_mx(
+     input: torch.Tensor,
+     scale_bits: int = 8,
+     elem_ebits: int = 2,
+     elem_mbits: int = 3,
+     elem_max_norm: float = 6.0,
+     mx_group_size: int = 32,
+     rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
+ ) -> torch.Tensor:
+     """
+     Registered quantize_mx op for E2E comm
+     (registration is done in __init__.py).
+     We use the Triton implementation for quantization.
+     Args:
+         input: FP32 tensor of size total_elems to be quantized
+         scale_bits: num bits of the shared exponent (i.e., 8 for MX4 e2m1)
+         elem_ebits: num bits of the exponent (i.e., 2 for MX4 e2m1)
+         elem_mbits: num bits of the mantissa incl. sign and implicit bits (
+             i.e., 3 for MX4 e2m1)
+         elem_max_norm: max value of the float (i.e., 6.0 for MX4 e2m1)
+         mx_group_size: num elements that share the max shared_exponent
+         rounding_mode: which type of rounding to use when calculating the shared exponent
+
+     Return:
+         output: MX4 tensor packed into int8 values with size
+             (total_elems / 2 + total_elems / groupsize);
+             the shared exponent of each group is stored in the last byte
+             of that group's output
+     """
+     return fp32_to_mx4(
+         input, mx_group_size, rounding_mode=rounding_mode, use_triton=True
+     )
+
+
+ def dequantize_mx(
+     input: torch.Tensor,
+     mx_group_size: int = 32,
+ ) -> torch.Tensor:
+     """
+     Registered dequantize_mx op for E2E comm
+     (registration is done in __init__.py to prevent multiple loading).
+     We use the Triton implementation for dequantization.
+     Args:
+         input: FP8 tensor (MX4 packed in FP8)
+         mx_group_size: number of elements that share the same max shared_exponent
+
+     Return:
+         output: FP32 tensor with total elements (total_elems)
+     """
+     return mx4_to_fp32(input, mx_group_size, use_triton=True)
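
The size relationship stated in the docstrings above (half a byte per element plus one shared-exponent byte per group) can be checked directly against these Python-level entry points. The sketch below is illustrative, assumes the package imports on the host, and uses an arbitrary 1024-element input.

import torch

from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx

total_elems, group = 1024, 32
x = torch.randn(total_elems, dtype=torch.float32)

# Two FP4 values per byte plus one shared-exponent byte per group of 32.
packed = quantize_mx(x, mx_group_size=group)
assert packed.numel() == total_elems // 2 + total_elems // group  # 512 + 32 = 544

roundtrip = dequantize_mx(packed, mx_group_size=group)
assert roundtrip.numel() == total_elems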

fbgemm_gpu/quantize_comm.py
@@ -0,0 +1,315 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ # The code in this file is refactored from https://fburl.com/code/p2gy2gxb
+ # based on "Amy Yang et al., Training Deep Learning Recommendation Model with
+ # Quantized Collective Communications", DLP-KDD 2020.
+
+
+ import logging
+ from typing import Optional, TypeVar
+
+ import torch
+
+ from fbgemm_gpu.quantize_utils import (
+     bf16_to_fp32,
+     fp16_to_fp32,
+     fp32_to_bf16_with_clamp,
+     fp32_to_fp16_with_clamp,
+     fp32_to_hfp8_with_clamp,
+     fp32_to_mx4,
+     hfp8_to_fp32,
+     mx4_to_fp32,
+     RoundingMode,
+ )
+
+ from fbgemm_gpu.split_embedding_configs import SparseType
+
+ from torch.autograd.profiler import record_function  # usort:skip
+ from dataclasses import dataclass
+
+ import fbgemm_gpu.quantize.quantize_ops  # noqa F401
+
+ logger: logging.Logger = logging.getLogger()
+
+ # FP8 configurations
+ ebits, mbits, bias = 4, 3, 15
+ max_pos: float = (2 ** ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits))
+
+ # INT8 configurations
+ ROW_DIM_DEFAULT = 32
+
+ # MX4 configurations
+ MX_GROUP_SIZE_DEFAULT = 32
+
+
+ def none_throws(
+     # pyre-fixme[31]: Expression `typing.Optional[typing.TypeVar("_T")]` is not a
+     # valid type.
+     optional: Optional[TypeVar("_T")],
+     message: str = "Unexpected `None`",
+     # pyre-fixme[31]: Expression `typing.TypeVar("_T")` is not a valid type.
+ ) -> TypeVar("_T"):
+     if optional is None:
+         raise AssertionError(message)
+     return optional
+
+
+ @dataclass
+ class QuantizationContext:
+     row_dim: int = ROW_DIM_DEFAULT
+     row_dim_quant: int = -1
+     mx_group_size: int = MX_GROUP_SIZE_DEFAULT
+     rounding_mode: Optional[RoundingMode] = RoundingMode.even
+     padded_dim_sum_per_rank: Optional[list[int]] = None
+
+
+ def _quantize_tensor(
+     input_tensor: torch.Tensor,
+     comm_precision: SparseType,
+     ctx: Optional[QuantizationContext] = None,
+     is_fwd: bool = True,
+ ) -> torch.Tensor:
+     if comm_precision == SparseType.FP32:
+         return input_tensor
+     elif comm_precision == SparseType.FP16:
+         return fp32_to_fp16_with_clamp(input_tensor)
+     elif comm_precision == SparseType.BF16:
+         return fp32_to_bf16_with_clamp(input_tensor)
+     elif comm_precision == SparseType.FP8:
+         # return fp32_to_hfp8_with_clamp(input_tensor, ebits, mbits, bias)
+         if ctx is not None and ctx.row_dim > 0:
+             ctx = none_throws(ctx)
+             row_dim = ctx.row_dim
+             input_2d = input_tensor.view((-1, row_dim)) if row_dim > 0 else input_tensor
+             input_2d_quant = torch.ops.fbgemm.FloatToFP8RowwiseQuantized(
+                 input_2d, is_fwd
+             )
+             row_dim_quant = input_2d_quant.shape[1]
+             input_quant_all2all = None
+             input_quant_all2all = input_2d_quant.view((-1))
+             ctx.row_dim_quant = row_dim_quant
+             return input_quant_all2all
+         else:
+             return fp32_to_hfp8_with_clamp(input_tensor, ebits, mbits, bias)
+     elif comm_precision == SparseType.INT8:
+         ctx = none_throws(ctx)
+         row_dim = ctx.row_dim
+         input_2d = input_tensor.view((-1, row_dim)) if row_dim > 0 else input_tensor
+         input_2d_quant = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(input_2d)
+         row_dim_quant = input_2d_quant.shape[1]
+         input_quant_all2all = None
+         input_quant_all2all = input_2d_quant.view((-1))
+         ctx.row_dim_quant = row_dim_quant
+         return input_quant_all2all
+     elif comm_precision == SparseType.MX4:
+         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
+         rounding_mode = ctx.rounding_mode if ctx is not None else RoundingMode.even
+         return fp32_to_mx4(
+             input_tensor, mx_group_size, rounding_mode=rounding_mode
+         ).view(-1)
+     else:
+         raise ValueError(f"comm_precision={comm_precision} is not supported")
+
+
+ def _dequantize_tensor(
+     quantized_tensor: torch.Tensor,
+     comm_precision: SparseType,
+     ctx: Optional[QuantizationContext] = None,
+     is_fwd: bool = True,
+     fp8_output_dtype: Optional[SparseType] = None,
+ ) -> torch.Tensor:
+     if comm_precision == SparseType.FP32:
+         assert quantized_tensor.dtype == torch.float
+         return quantized_tensor
+     elif comm_precision == SparseType.FP16:
+         assert quantized_tensor.dtype == torch.half
+         return fp16_to_fp32(quantized_tensor)
+     elif comm_precision == SparseType.BF16:
+         assert quantized_tensor.dtype == torch.bfloat16
+         return bf16_to_fp32(quantized_tensor)
+     elif comm_precision == SparseType.FP8:
+         if ctx is not None and ctx.row_dim > 0:
+             row_dim_quant = ctx.row_dim_quant
+             quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+             # use provided fp8_output_dtype or default to FP32 (0)
+             output_dtype_int = (
+                 fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
+             )
+             dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
+                 quantized_tensor_2d,
+                 is_fwd,
+                 output_dtype_int,
+             )
+             return dequant_tensor.view(-1)
+         else:
+             assert quantized_tensor.dtype == torch.uint8
+             return hfp8_to_fp32(quantized_tensor, ebits, bias)
+     elif comm_precision == SparseType.INT8:
+         ctx = none_throws(ctx)
+         row_dim_quant = ctx.row_dim_quant
+         quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+         dequant_tensor = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(
+             quantized_tensor_2d
+         )
+         return dequant_tensor.view(-1)
+     elif comm_precision == SparseType.MX4:
+         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
+         return mx4_to_fp32(quantized_tensor, mx_group_size)
+     else:
+         raise ValueError(f"comm_precision={comm_precision} is not supported")
+
+
+ class QuantizedCommCodec:
+     # Concrete implementation of QuantizedCommCodec provided by FBGEMM functions.
+     def __init__(
+         self,
+         comm_precision: SparseType,
+         loss_scale: Optional[float] = None,
+         row_dim: Optional[int] = None,
+         is_fwd: bool = True,
+         rounding_mode: Optional[RoundingMode] = None,
+         fp8_output_dtype: Optional[SparseType] = None,
+     ) -> None:
+         if loss_scale is not None:
+             if comm_precision not in [SparseType.FP16, SparseType.BF16]:
+                 logger.warning(
+                     f"Setting loss scale for comm_precision={comm_precision} is not supported. Overriding to None"
+                 )
+                 loss_scale = None
+
+         logger.info(
+             f"Creating QuantizedCommsCodec comm_precision:{comm_precision}, loss_scale:{loss_scale}"
+         )
+
+         self._comm_precision = comm_precision
+         self._loss_scale = loss_scale
+         self._is_fwd = is_fwd
+         self._row_dim: int = -1 if row_dim is None else row_dim
+         self._rounding_mode: Optional[RoundingMode] = rounding_mode
+         self._fp8_output_dtype: Optional[SparseType] = fp8_output_dtype
+         if self._comm_precision == SparseType.MX4:
+             self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
+             self._rounding_mode = (
+                 RoundingMode.even if rounding_mode is None else rounding_mode
+             )
+
+     def encode(
+         self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
+     ) -> torch.Tensor:
+         if self._loss_scale is not None:
+             input_tensor = self._loss_scale * input_tensor
+         with record_function(
+             f"## encoder {self._comm_precision} {self._loss_scale} ##"
+         ):
+             output = _quantize_tensor(
+                 input_tensor,
+                 self._comm_precision,
+                 ctx,
+                 self._is_fwd,
+             )
+         return output
+
+     def decode(
+         self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
+     ) -> torch.Tensor:
+         if self._loss_scale is not None:
+             input_tensor = input_tensor / self._loss_scale
+         with record_function(
+             f"## decoder {self._comm_precision} {self._loss_scale} ##"
+         ):
+             dequantized_tensor = _dequantize_tensor(
+                 input_tensor,
+                 self._comm_precision,
+                 ctx,
+                 self._is_fwd,
+                 fp8_output_dtype=self._fp8_output_dtype,
+             )
+         return dequantized_tensor
+
+     def calc_quantized_size(
+         self, input_len: int, ctx: Optional[QuantizationContext] = None
+     ) -> int:
+         # Use the same logic in _float_to_fused8bitrowwise_gpu_t()
+         if self._comm_precision == SparseType.INT8 or (
+             self._comm_precision == SparseType.FP8 and self._row_dim > 0
+         ):
+             ctx = none_throws(ctx)
+             torch._check(
+                 input_len % ctx.row_dim == 0,
+                 lambda: f"input_len {input_len} is not a multiple of row dim {ctx.row_dim}",
+             )
+             assert input_len % ctx.row_dim == 0, (
+                 f"input_len {input_len} is not a multiple of row dim {ctx.row_dim} "
+                 "Please check your batch size (power of 2 batch size is recommended)"
+             )
+             nrows = input_len // ctx.row_dim
+             ncols = (ctx.row_dim + 3) // 4 * 4 + 2 * 4
+             return nrows * ncols
+         elif self._comm_precision == SparseType.MX4:
+             if ctx:
+                 group_size = ctx.mx_group_size
+             else:
+                 group_size = MX_GROUP_SIZE_DEFAULT
+             assert (
+                 input_len % group_size == 0
+             ), f"input_len {input_len} needs to be multiple of group_size {group_size}"
+             # quantized output size = half input size + number of groups (shared exp)
+             ctx = none_throws(ctx)
+             return (input_len // 2) + (input_len // ctx.mx_group_size)
+         else:
+             return input_len
+
+     @property
+     def quantized_dtype(self) -> torch.dtype:
+         return self._comm_precision.as_dtype()
+
+     def create_context(self) -> Optional[QuantizationContext]:
+         # fp8 rowwise is activated when row_dim > 0
+         if self._comm_precision == SparseType.FP8:
+             return QuantizationContext(self._row_dim)
+         if self._comm_precision == SparseType.MX4:
+             return QuantizationContext(
+                 row_dim=self._row_dim,
+                 mx_group_size=self._row_dim,
+                 rounding_mode=self._rounding_mode,
+             )
+         # int8 rowwise is default
+         return QuantizationContext()
+
+     def padded_size(
+         self,
+         input_tensor: torch.Tensor,
+         dim_per_rank: list[int],
+         my_rank: int,
+         qcomm_ctx: QuantizationContext,
+     ) -> tuple[int, int]:
+         if input_tensor.ndim == 1:
+             return input_tensor.shape[0], 0
+         # return padded size for the feature dimension (dim 1), 0 if no padding needed.
+         padded_dim_sum, padding_size = input_tensor.shape[1], 0
+         if self._comm_precision == SparseType.MX4:
+             group_size = qcomm_ctx.mx_group_size
+             padding_size_per_rank = [
+                 group_size - (t if (t := dim_sum % group_size) > 0 else group_size)
+                 for dim_sum in dim_per_rank
+             ]
+             padded_dim_sum_per_rank = [
+                 a + b for a, b in zip(dim_per_rank, padding_size_per_rank)
+             ]
+             dim_sum, padding_size = (
+                 dim_per_rank[my_rank],
+                 padding_size_per_rank[my_rank],
+             )
+             assert input_tensor.ndim == 2 and input_tensor.shape[1] == dim_sum
+             qcomm_ctx.padded_dim_sum_per_rank = padded_dim_sum_per_rank
+             padded_dim_sum = padding_size + dim_sum
+             return padded_dim_sum, padding_size
+
+         return padded_dim_sum, padding_size
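
To make the control flow above concrete, here is an illustrative sketch (not part of the package) of driving the codec the way a quantized collective-communication hook would: build a codec per precision, derive a context, and round-trip a flattened tensor. The FP16 path is pure PyTorch; the MX4 size check simply restates the arithmetic in calc_quantized_size(). The tensor shape and the loss scale are arbitrary choices for the example.

import torch

from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType

grad = torch.randn(4, 32)  # 128 elements, a multiple of the MX4 group size

# FP16 with loss scaling: encode scales, clamps, and casts to half; decode rescales.
fp16_codec = QuantizedCommCodec(SparseType.FP16, loss_scale=128.0)
ctx = fp16_codec.create_context()
decoded = fp16_codec.decode(fp16_codec.encode(grad.flatten(), ctx), ctx)
assert decoded.dtype == torch.float32 and decoded.numel() == grad.numel()

# MX4: quantized length is half the elements plus one shared-exponent byte per group.
mx4_codec = QuantizedCommCodec(SparseType.MX4)
mx4_ctx = mx4_codec.create_context()
n = grad.numel()
assert mx4_codec.calc_quantized_size(n, mx4_ctx) == n // 2 + n // 32  # 64 + 4 = 68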

fbgemm_gpu/quantize_utils.py
@@ -0,0 +1,246 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import logging
+ from typing import Optional, Union
+
+ import torch  # isort:skip
+
+ import fbgemm_gpu
+
+ from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
+ from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
+
+ try:
+     # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+     open_source = bool(getattr(fbgemm_gpu, "open_source", False))
+ except NotImplementedError:
+     open_source = False
+
+ # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+ if not open_source:
+     from mtia.kernels.triton.mx4.quantize import (
+         triton_dequantize_mx4 as mtia_dequantize_mx4,
+         triton_quantize_mx4 as mtia_quantize_mx4,
+     )
+
+ logger: logging.Logger = logging.getLogger()
+
+ try:
+     # pyre-ignore[21]
+     from fbgemm_gpu import open_source  # noqa: F401
+ except Exception:
+     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+
+ TORCH_HALF_MIN: float = torch.finfo(torch.float16).min
+ TORCH_HALF_MAX: float = torch.finfo(torch.float16).max
+
+ TORCH_BFLOAT16_MIN: float = torch.finfo(torch.bfloat16).min
+ TORCH_BFLOAT16_MAX: float = torch.finfo(torch.bfloat16).max
+
+
+ def fp32_to_mx4(
+     tensor: torch.Tensor,
+     group_size: int = 32,
+     ebits: int = 2,
+     mbits: int = 1,
+     rounding_mode: Optional[Union[RoundingMode, int]] = RoundingMode.even,
+     stochastic_casting: bool = False,
+     use_triton: bool = True,
+ ) -> torch.Tensor:
+     """Quantize an FP32 tensor to MX4 with triton or native cuda impl.
+
+     Args:
+         tensor (torch.Tensor): FP32 tensor to quantize with M total elements.
+         group_size (int): Compute scale in chunks of group_size.
+         ebits (int): Number of exponent bits in target mx4 format.
+         mbits (int): Number of mantissa bits in target mx4 format.
+         rounding_mode (RoundingMode or int): Which type of rounding to use when computing exponent.
+             Only supported with use_triton=True.
+         stochastic_casting (bool): Whether to use stochastic casting when downcasting.
+         use_triton (bool): If set, use triton quantization, otherwise cuda.
+
+     Return:
+         output: MX4 tensor packed into int8 values with total elements (M / 2 + M / groupsize)
+     """
+     # Accelerated MX4 is only available on cuda; if input is on cpu, use python.
+     # Operate on flattened input.
+     if rounding_mode is None:
+         rounding_mode = RoundingMode.even
+
+     if not tensor.is_cuda and not tensor.is_mtia:
+         return py_quantize_mx4(
+             tensor,
+             group_size,
+             ebits=ebits,
+             mbits=mbits,
+             rounding_mode=rounding_mode,
+             stochastic_casting=stochastic_casting,
+         )
+
+     if use_triton:
+         if tensor.is_mtia:
+             return mtia_quantize_mx4(
+                 tensor,
+                 group_size,
+                 ebits=ebits,
+                 mbits=mbits,
+                 rounding_mode=rounding_mode,
+                 stochastic_casting=stochastic_casting,
+             )
+         return quantize_mx4(
+             tensor,
+             group_size,
+             ebits=ebits,
+             mbits=mbits,
+             rounding_mode=rounding_mode,
+             stochastic_casting=stochastic_casting,
+         )
+     else:
+         out = torch.ops.fbgemm.quantize_mx_cuda(
+             tensor.flatten(),
+             scale_bits=8,
+             elem_ebits=2,
+             elem_mbits=3,
+             elem_max_norm=6.0,
+             mx_group_size=group_size,
+         )
+         # Preserve input dimensions.
+         output_shape = list(tensor.shape[:-1]) + [-1]
+         return out.view(output_shape)
+
+
+ def mx4_to_fp32(
+     tensor: torch.Tensor,
+     group_size: int = 32,
+     use_triton: bool = True,
+     ebits: int = 2,
+     mbits: int = 1,
+ ) -> torch.Tensor:
+     """Dequantize an MX4 tensor to FP32 with triton or native cuda impl.
+
+     Args:
+         tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
+         group_size (int): Compute scale in chunks of group_size.
+         use_triton (bool): If set, use triton dequantization, otherwise cuda.
+         ebits (int): Number of exponent bits in target mx4 format.
+         mbits (int): Number of mantissa bits in target mx4 format.
+
+     Return:
+         output: FP32 tensor with total elements (M).
+     """
+     # Accelerated MX4 dequantize is only available on cuda; if input is on cpu, use python.
+     if not tensor.is_cuda and not tensor.is_mtia:
+         return py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+     if use_triton:
+         if tensor.is_mtia:
+             return mtia_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+         return dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+     else:
+         return torch.ops.fbgemm.dequantize_mx_cuda(tensor.flatten(), group_size)
+
+
+ def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
+     return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()
+
+
+ def fp32_to_bf16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
+     return torch.clamp(tensor, TORCH_BFLOAT16_MIN, TORCH_BFLOAT16_MAX).bfloat16()
+
+
+ def fp32_to_hfp8_with_clamp(
+     tensor: torch.Tensor, ebits: int = 4, mbits: int = 3, bias: int = 15
+ ) -> torch.Tensor:
+     max_pos: float = (2 ** ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits))
+     return torch.ops.fbgemm.FloatToHFP8Quantized(
+         tensor.contiguous(),
+         ebits,
+         bias,
+         max_pos,
+     )
+
+
+ def fp16_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
+     return tensor.float()
+
+
+ def bf16_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
+     return tensor.view(torch.bfloat16).float()
+
+
+ def hfp8_to_fp32(tensor: torch.Tensor, ebits: int = 4, bias: int = 15) -> torch.Tensor:
+     return torch.ops.fbgemm.HFP8QuantizedToFloat(
+         tensor.contiguous().view(torch.uint8),
+         ebits,
+         bias,
+     )
+
+
+ def measure_fp16_quant_error(input_tensor: torch.Tensor) -> None:
+     # TODO: log to tensorboard
+
+     num_nan_fp32_tensor = torch.numel(input_tensor[torch.isnan(input_tensor)])
+     logger.info(
+         "num NaN in fp32 tensor: {}, ratio: {}.".format(
+             num_nan_fp32_tensor, num_nan_fp32_tensor / torch.numel(input_tensor)
+         )
+     )
+
+     logger.info(
+         "fp32 tensor profile: min: {}, max: {}, min abs:{}, max abs:{}.".format(
+             torch.min(input_tensor),
+             torch.max(input_tensor),
+             torch.min(torch.abs(input_tensor)),
+             torch.max(torch.abs(input_tensor)),
+         )
+     )
+
+     fp16_tensor = fp32_to_fp16_with_clamp(input_tensor)
+     num_nan_fp16_tensor = torch.numel(fp16_tensor[torch.isnan(fp16_tensor)])
+
+     logger.info(
+         "num NaN in fp16 tensor: {}, ratio: {}.".format(
+             num_nan_fp16_tensor, num_nan_fp16_tensor / torch.numel(input_tensor)
+         )
+     )
+
+     diff = torch.abs(input_tensor - fp16_tensor.float())
+     rel_diff = diff / torch.abs(input_tensor)
+     logger.info(
+         "fp32_to_fp16 abs error: min={}, max={}, avg={}.".format(
+             torch.min(diff), torch.max(diff), torch.mean(diff)
+         )
+     )
+
+     rel_diff_not_nan = rel_diff[torch.logical_not(torch.isnan(rel_diff))]
+     logger.info(
+         "fp32_to_fp16 rel error: min={}, max={}, avg={}.".format(
+             torch.min(rel_diff_not_nan),
+             torch.max(rel_diff_not_nan),
+             torch.mean(rel_diff_not_nan),
+         )
+     )
+
+     rel_diff_1_idx = torch.where(rel_diff == 1.0)
+     fp32_rel_err_1_vals = input_tensor[rel_diff_1_idx]
+     if torch.numel(fp32_rel_err_1_vals) > 0:
+         fp32_rel_err_1_vals = torch.abs(fp32_rel_err_1_vals)
+         logger.info(
+             "fp32_to_fp16 rel error == 1: fp32 min:{}, fp32 max:{}, fp32 avg:{}.".format(
+                 torch.min(fp32_rel_err_1_vals),
+                 torch.max(fp32_rel_err_1_vals),
+                 torch.mean(fp32_rel_err_1_vals),
+             )
+         )
+
+     subrange_ratio = torch.numel(fp16_tensor[rel_diff_1_idx]) / torch.numel(
+         fp16_tensor
+     )
+     logger.info("sub fp16 range ratio: {}".format(subrange_ratio))