fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/split_embedding_configs.py
@@ -0,0 +1,452 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import enum
+ import itertools
+ from typing import Any, Dict # noqa: F401
+
+ import torch
+
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+     EmbeddingLocation,
+     SplitState,
+ )
+
+
+ def pad4(value: int) -> int:
+     """
+     Compute the smallest multiple of 4 that is greater than or equal to the given value.
+
+     Parameters:
+         value (int): The integer to align (must be non-negative).
+
+     Returns:
+         int: The aligned value.
+
+     Raises:
+         ValueError: If the input is negative.
+         TypeError: If the input is not an integer.
+     """
+     return (int(value) + 3) & ~3
+
+
+ def pad16(value: int) -> int:
+     """
+     Compute the smallest multiple of 16 that is greater than or equal to the given value.
+
+     Parameters:
+         value (int): The integer to align (must be non-negative).
+
+     Returns:
+         int: The aligned value.
+
+     Raises:
+         ValueError: If the input is negative.
+         TypeError: If the input is not an integer.
+     """
+     return (int(value) + 15) & ~15
+
+
+ @enum.unique
+ class EmbOptimType(enum.Enum):
+     SGD = "sgd" # uses non-deterministic updates (atomicAdd(..)) with duplicate ids
+     EXACT_SGD = (
+         "exact_sgd" # uses deterministic updates (via sorting + segment reduction)
+     )
+     LAMB = "lamb"
+     ADAM = "adam"
+     # exact/dedup: gradients to the same row are applied with coalesce then apply
+     # together, instead of applied in sequence (approx).
+     EXACT_ADAGRAD = "exact_adagrad"
+     EXACT_ROWWISE_ADAGRAD = "exact_row_wise_adagrad"
+     LARS_SGD = "lars_sgd"
+     PARTIAL_ROWWISE_ADAM = "partial_row_wise_adam"
+     PARTIAL_ROWWISE_LAMB = "partial_row_wise_lamb"
+     ROWWISE_ADAGRAD = "row_wise_adagrad"
+     SHAMPOO = "shampoo" # not currently supported for sparse embedding tables
+     SHAMPOO_V2 = "shampoo_v2" # not currently supported for sparse embedding tables
+     MADGRAD = "madgrad"
+     EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad" # deprecated
+     ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"
+     EMAINPLACE_ROWWISE_ADAGRAD = "ema_in_place_row_wise_adagrad"
+     NONE = "none"
+
+     def __str__(self) -> str:
+         return self.value
+
+     def _extract_dtype(
+         self, optimizer_state_dtypes: dict[str, "SparseType"], name: str
+     ) -> torch.dtype:
+         if optimizer_state_dtypes is None or name not in optimizer_state_dtypes:
+             return torch.float32
+         return optimizer_state_dtypes[name].as_dtype()
+
+     def state_names(self) -> list[str]:
+         """
+         Returns the names of the optimizer states. The order of the states will
+         be the order in which they are processed and returned in
+         SSDTableBatchedEmbeddingBags.split_optimizer_states(), but this is not
+         necessarily the same as the order they are stored in the memory layout.
+         """
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             return ["momentum1"]
+         elif self in [EmbOptimType.PARTIAL_ROWWISE_ADAM, EmbOptimType.ADAM]:
+             return ["momentum1", "momentum2"]
+         else:
+             return []
+
+     def state_size_table(self, D: int) -> dict[str, int]:
+         """
+         Returns the table of state names to state sizes in terms of number of
+         elements (per table row)
+         """
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             return {"momentum1": 1}
+         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+             return {"momentum1": D, "momentum2": 1}
+         elif self == EmbOptimType.ADAM:
+             return {"momentum1": D, "momentum2": D}
+         else:
+             return {}
+
+     def state_size_nbytes(
+         self,
+         D: int,
+         optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+     ) -> int:
+         """
+         Returns the size of the data (in bytes) required to hold the optimizer
+         state (per table row). This size includes byte-padding.
+         """
+         momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+         momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             return momentum1_dtype.itemsize
+
+         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+             return pad4(1 * momentum2_dtype.itemsize) + D * momentum1_dtype.itemsize
+
+         elif self == EmbOptimType.ADAM:
+             return (D * momentum1_dtype.itemsize) + (D * momentum2_dtype.itemsize)
+
+         else:
+             return 0
+
+     def byte_offsets_along_row(
+         self,
+         D: int,
+         weights_precision: "SparseType",
+         optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+     ) -> dict[str, tuple[int, int]]:
+         """
+         Returns the start and end byte offsets of each optimizer state along a
+         cache row with optimizer state offloading enabled.
+         """
+         # Extract the optimizer state dtypes
+         momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+         momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+
+         # This is the pointer to where the optimizer state begins in the memory
+         p0 = pad4(D) * weights_precision.as_dtype().itemsize
+
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             return {"momentum1": (p0, p0 + momentum1_dtype.itemsize)}
+
+         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+             # momentum1 lies after momentum2
+             p1 = p0 + pad4(1 * momentum2_dtype.itemsize)
+             return {
+                 "momentum2": (p0, p0 + momentum2_dtype.itemsize),
+                 "momentum1": (
+                     p1,
+                     p1 + D * momentum1_dtype.itemsize,
+                 ),
+             }
+
+         elif self == EmbOptimType.ADAM:
+             # momentum2 lies after momentum1
+             p1 = p0 + (D * momentum1_dtype.itemsize)
+
+             return {
+                 "momentum1": (p0, p1),
+                 "momentum2": (p1, p1 + D * momentum2_dtype.itemsize),
+             }
+
+         else:
+             return {}
+
+     def empty_states(
+         self,
+         rows: list[int],
+         dims: list[int],
+         optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+     ) -> list[list[torch.Tensor]]:
+         """
+         Creates sets of empty tensors per table to hold optimizer states based
+         on the specified optimizer type, state dtypes, embedding specs, and
+         (optionally) local row counts.
+         """
+         # Else, check that the local row count for each table is set
+         assert len(rows) == len(dims)
+
+         opt_states_set: list[list[torch.Tensor]] = []
+
+         for r, D in zip(rows, dims):
+             # Set up the table of state names to state sizes, ordered by their
+             # memory layout
+             state_size_table = self.state_size_table(D)
+             ordered_state_sizes = [(k, state_size_table[k]) for k in self.state_names()]
+
+             # Create the optimizer states for this table
+             opt_states_set.append(
+                 [
+                     torch.empty(
+                         # If the state size is 1, then fix tensor to 1D to be
+                         # consistent with training.py code
+                         # pyre-ignore [6]
+                         (r, d) if d > 1 else r,
+                         dtype=self._extract_dtype(optimizer_state_dtypes, state_name),
+                         device="cpu",
+                     )
+                     for state_name, d in ordered_state_sizes
+                 ]
+             )
+
+         return opt_states_set
+
+     def ssd_state_splits(
+         self,
+         embedding_specs: list[tuple[int, int]], # Tuple of (rows, dims)
+         optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+         enable_optimizer_offloading: bool = False,
+     ) -> list[tuple[SplitState, str, torch.dtype]]:
+         """
+         Returns the split planning for the optimizer states
+         """
+         (rows, _) = zip(*embedding_specs)
+         T_ = len(embedding_specs)
+
+         # This is the cumulative row counts for rowwise states
+         row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
+         # This is the cumulative element counts for elementwise states
+         table_size_cumsum: list[int] = [0] + list(
+             itertools.accumulate([r * d for r, d in embedding_specs])
+         )
+
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             params = {"momentum1": row_count_cumsum}
+         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+             params = {"momentum1": table_size_cumsum, "momentum2": row_count_cumsum}
+         elif self == EmbOptimType.ADAM:
+             params = {
+                 "momentum1": table_size_cumsum,
+                 "momentum2": table_size_cumsum,
+                 "row_counter": row_count_cumsum,
+             }
+         else:
+             params = {}
+
+         return [
+             (
+                 SplitState(
+                     dev_size=(
+                         cumsum_table[-1] if not enable_optimizer_offloading else 0
+                     ),
+                     host_size=0,
+                     uvm_size=0,
+                     placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
+                     offsets=cumsum_table[:-1],
+                 ),
+                 name,
+                 self._extract_dtype(optimizer_state_dtypes, name),
+             )
+             for (name, cumsum_table) in params.items()
+         ]
+
+
+ # Base class for quantization configuration (in case other numeric types have
+ # configs)
+ class QuantizationConfig:
+     def __init__(self) -> None:
+         self.config = {} # type: Dict[str, Any]
+
+     def get(self, name: str) -> int:
+         return -1
+
+
+ # FP8 quantization configuration
+ # Compute necessary parameters in the constructor
+ class FP8QuantizationConfig(QuantizationConfig):
+     def __init__(self, exponent_bits: int, exponent_bias: int) -> None:
+         super(FP8QuantizationConfig, self).__init__()
+         self.config = {
+             "exponent_bits": exponent_bits,
+             "exponent_bias": exponent_bias,
+             "max_position": (1 << ((1 << exponent_bits) - 2 - exponent_bias))
+             * (2 - 2 ** (exponent_bits - 7)),
+         } # type: Dict[str, Any]
+
+     def get(self, name: str) -> int:
+         if name not in self.config:
+             raise RuntimeError("{} must be set in config".format(name))
+         return self.config[name]
+
+
+ def sparse_type_to_int(sparse_type: "SparseType") -> int:
+     return {
+         SparseType.FP32.value: 0,
+         SparseType.FP16.value: 1,
+         SparseType.INT8.value: 2,
+         SparseType.INT4.value: 3,
+         SparseType.INT2.value: 4,
+         SparseType.BF16.value: 5,
+         SparseType.FP8.value: 6,
+         SparseType.MX4.value: 7,
+         SparseType.NFP8.value: 8,
+     }[sparse_type.value]
+
+
+ @enum.unique
+ class SparseType(enum.Enum):
+     FP32 = "fp32"
+     FP16 = "fp16"
+     FP8 = "fp8"
+     # NFP8 refers to "native" FP8 in that it uses the GPU implementations
+     # of E4M3 whereas the other FP8 sparsetype uses a custom format. Use of
+     # NFP8 allows us to use hardware casting intrinsics which can be much faster.
+     # Eventually, we should merge these two types.
+     NFP8 = "nfp8"
+     INT8 = "int8"
+     INT4 = "int4"
+     INT2 = "int2"
+     BF16 = "bf16"
+     MX4 = "mx4"
+
+     def __str__(self) -> str:
+         return self.value
+
+     @staticmethod
+     def from_int(ty: int) -> "SparseType":
+         if ty == 0:
+             return SparseType("fp32")
+         elif ty == 1:
+             return SparseType("fp16")
+         elif ty == 2:
+             return SparseType("int8")
+         elif ty == 3:
+             return SparseType("int4")
+         elif ty == 4:
+             return SparseType("int2")
+         elif ty == 5:
+             return SparseType("bf16")
+         elif ty == 6:
+             return SparseType("fp8")
+         elif ty == 8:
+             return SparseType("mx4")
+         elif ty == 9:
+             return SparseType("nfp8")
+         else: # Invalid is 7 or non enumerated.
+             raise ValueError(f"Unsupported sparse type: {ty}")
+
+     def as_int(self) -> int:
+         return sparse_type_to_int(self)
+
+     @staticmethod
+     def from_dtype(dtype: torch.dtype, is_mx: bool = False) -> "SparseType":
+         if dtype == torch.float32:
+             return SparseType("fp32")
+         elif dtype == torch.float16:
+             return SparseType("fp16")
+         elif (dtype == torch.int8 or dtype == torch.uint8) and not is_mx:
+             return SparseType("int8")
+         elif dtype == torch.quint4x2:
+             return SparseType("int4")
+         elif dtype == torch.quint2x4:
+             return SparseType("int2")
+         elif dtype == torch.bfloat16:
+             return SparseType("bf16")
+         elif dtype == torch.uint8:
+             return SparseType("mx4")
+         elif dtype == torch.float8_e4m3fnuz or dtype == torch.float8_e4m3fn:
+             return SparseType("nfp8")
+         else:
+             raise ValueError(f"Unsupported sparse dtype: {dtype}")
+
+     def as_dtype(self) -> torch.dtype:
+         return {
+             SparseType.FP32.value: torch.float32,
+             SparseType.FP16.value: torch.float16,
+             SparseType.FP8.value: torch.uint8,
+             SparseType.INT8.value: torch.uint8,
+             SparseType.INT4.value: torch.quint4x2,
+             SparseType.INT2.value: torch.quint2x4,
+             SparseType.BF16.value: torch.bfloat16,
+             SparseType.MX4.value: torch.uint8,
+             SparseType.NFP8.value: (
+                 torch.float8_e4m3fnuz
+                 if torch.version.hip is not None
+                 else torch.float8_e4m3fn
+             ),
+         }[self.value]
+
+     def bit_rate(self) -> int:
+         return {
+             SparseType.FP32.value: 32,
+             SparseType.FP16.value: 16,
+             SparseType.FP8.value: 8,
+             SparseType.INT8.value: 8,
+             SparseType.INT4.value: 4,
+             SparseType.INT2.value: 2,
+             SparseType.BF16.value: 16,
+             SparseType.MX4.value: 4,
+             SparseType.NFP8.value: 8,
+         }[self.value]
+
+     def align_size(self) -> int:
+         return {
+             SparseType.FP32.value: 1,
+             SparseType.FP16.value: 2,
+             SparseType.FP8.value: 4,
+             SparseType.INT8.value: 4,
+             SparseType.INT4.value: 8,
+             SparseType.INT2.value: 16,
+             SparseType.BF16.value: 2,
+             SparseType.MX4.value: 8,
+             SparseType.NFP8.value: 4,
+         }[self.value]
+
+     def is_float(self) -> bool:
+         if (
+             self.value == SparseType.FP32.value
+             or self.value == SparseType.FP16.value
+             or self.value == SparseType.FP8.value
+             or self.value == SparseType.BF16.value
+             or self.value == SparseType.NFP8.value
+         ):
+             return True
+         else:
+             return False
+
+     def default_config(self) -> QuantizationConfig:
+         if self.value == SparseType.FP8.value:
+             return FP8QuantizationConfig(4, 7)
+         else:
+             return QuantizationConfig()
+
+
+ ELEMENT_SIZE: dict[SparseType, int] = {
+     SparseType.FP32: 4,
+     SparseType.FP16: 2,
+     SparseType.FP8: 1,
+     SparseType.INT8: 1,
+     SparseType.BF16: 2,
+     SparseType.NFP8: 1,
+     # SparseType.INT4: 0.5,
+ }
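
For orientation, here is a minimal sketch (not part of the packaged files) of how the helpers in the hunk above compose; the dimension D = 128 is an arbitrary example value and the optimizer-state dtypes are left at their FP32 defaults:

from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType

D = 128  # example embedding dimension
opt = EmbOptimType.PARTIAL_ROWWISE_ADAM

# Optimizer-state bytes per row: momentum2 is a single padded scalar and
# momentum1 spans the full dimension, so pad4(4) + 128 * 4 = 516.
print(opt.state_size_nbytes(D))

# Byte offsets of each state along an offloaded cache row holding FP16 weights;
# the states start at p0 = pad4(128) * 2 = 256.
print(opt.byte_offsets_along_row(D, weights_precision=SparseType.FP16))

# Worked instance of the FP8 max_position formula for the default (4, 7) config:
# (1 << (2**4 - 2 - 7)) * (2 - 2**(4 - 7)) = 128 * 1.875 = 240.0
print(SparseType.FP8.default_config().get("max_position"))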
fbgemm_gpu/split_embedding_inference_converter.py
@@ -0,0 +1,175 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+
+ import logging
+ import math
+ from typing import cast, Optional
+
+ import torch
+
+ from fbgemm_gpu.split_embedding_configs import (
+     FP8QuantizationConfig,
+     QuantizationConfig,
+     SparseType,
+ )
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
+ from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
+     IntNBitTableBatchedEmbeddingBagsCodegen,
+ )
+ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+     ComputeDevice,
+     SplitTableBatchedEmbeddingBagsCodegen,
+ )
+ from fbgemm_gpu.tbe.utils import quantize_embs
+ from torch import Tensor # usort:skip
+
+
+ # TODO: add per-feature based converter option (based on embedding_specs during inference)
+ # TODO: optimize embedding pruning and quantization latency.
+ class SplitEmbInferenceConverter:
+     # pyre-fixme[3]: Return type must be annotated.
+     def __init__(
+         self,
+         quantize_type: SparseType,
+         pruning_ratio: Optional[float],
+         use_array_for_index_remapping: bool = True,
+         quantization_config: Optional[QuantizationConfig] = None,
+     ):
+         self.quantize_type = quantize_type
+         # TODO(yingz): Change the pruning ratio to per-table settings.
+         self.pruning_ratio = pruning_ratio
+         self.use_array_for_index_remapping = use_array_for_index_remapping
+         self.quantization_config = quantization_config
+
+     def convert_model(self, model: torch.nn.Module) -> torch.nn.Module:
+         self._process_split_embs(model)
+         return model
+
+     # pyre-fixme[2]: Parameter must be annotated.
+     def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> tuple[Tensor, float]:
+         assert new_num_rows > 0
+         from numpy.linalg import norm
+
+         indicators = []
+         for row in weights:
+             indicators.append(norm(row.cpu().numpy(), ord=2))
+         sorted_indicators = sorted(indicators, reverse=True)
+         threshold = None
+         for i in range(new_num_rows, len(sorted_indicators)):
+             if sorted_indicators[i] < sorted_indicators[new_num_rows - 1]:
+                 threshold = sorted_indicators[i]
+                 break
+         if threshold is None:
+             threshold = sorted_indicators[-1] - 1
+         return (torch.tensor(indicators), threshold)
+
+     def _prune_embs(
+         self,
+         idx: int,
+         num_rows: int,
+         module: SplitTableBatchedEmbeddingBagsCodegen,
+     ) -> tuple[Tensor, Optional[Tensor]]:
+         # TODO(yingz): Avoid DtoH / HtoD overhead.
+         weights = module.split_embedding_weights()[idx].cpu()
+         if self.pruning_ratio is None:
+             return (weights, None)
+         new_num_rows = int(math.ceil(num_rows * (1.0 - self.pruning_ratio))) # type: ignore
+         if new_num_rows == num_rows:
+             return (weights, None)
+
+         (indicators, threshold) = self._prune_by_weights_l2_norm(new_num_rows, weights)
+
+         return torch.ops.fbgemm.embedding_bag_rowwise_prune(
+             weights, indicators, threshold, torch.int32
+         )
+
+     # pyre-fixme[3]: Return type must be annotated.
+     # pyre-fixme[2]: Parameter must be annotated.
+     def _get_quantization_config(self, name):
+         quantization_config = self.quantization_config
+         if quantization_config is None:
+             raise RuntimeError("quantization_config must be set for FP8 weight")
+         return quantization_config.get(name)
+
+     def _quantize_embs(
+         self, weight: Tensor, weight_ty: SparseType
+     ) -> tuple[Tensor, Optional[Tensor]]:
+         fp8_quant_config = cast(FP8QuantizationConfig, self.quantization_config)
+         return quantize_embs(weight, weight_ty, fp8_quant_config)
+
+     def _process_split_embs(self, model: torch.nn.Module) -> None:
+         for name, child in model.named_children():
+             if isinstance(
+                 child,
+                 SplitTableBatchedEmbeddingBagsCodegen,
+             ):
+                 embedding_specs = []
+                 use_cpu = child.embedding_specs[0][3] == ComputeDevice.CPU
+                 for E, D, _, _ in child.embedding_specs:
+                     weights_ty = self.quantize_type
+                     if D % weights_ty.align_size() != 0:
+                         logging.warning(
+                             f"Embedding dim {D} couldn't be divided by align size {weights_ty.align_size()}!"
+                         )
+                         assert D % 4 == 0
+                         weights_ty = (
+                             SparseType.FP16
+                         ) # fall back to FP16 if dimension couldn't be aligned with the required size
+                     embedding_specs.append(("", E, D, weights_ty))
+
+                 weight_lists = []
+                 new_embedding_specs = []
+                 index_remapping_list = []
+                 for t, (_, E, D, weight_ty) in enumerate(embedding_specs):
+                     # Try to prune embeddings.
+                     (pruned_weight, index_remapping) = self._prune_embs(t, E, child)
+                     new_embedding_specs.append(
+                         (
+                             "",
+                             pruned_weight.size()[0],
+                             D,
+                             weight_ty,
+                             (
+                                 EmbeddingLocation.HOST
+                                 if use_cpu
+                                 else EmbeddingLocation.DEVICE
+                             ),
+                         )
+                     )
+                     index_remapping_list.append(index_remapping)
+
+                     # Try to quantize embeddings.
+                     weight_lists.append(self._quantize_embs(pruned_weight, weight_ty))
+
+                 is_fp8_weight = self.quantize_type == SparseType.FP8
+
+                 q_child = IntNBitTableBatchedEmbeddingBagsCodegen(
+                     embedding_specs=new_embedding_specs,
+                     index_remapping=(
+                         index_remapping_list if self.pruning_ratio is not None else None
+                     ),
+                     pooling_mode=child.pooling_mode,
+                     device="cpu" if use_cpu else torch.cuda.current_device(),
+                     weight_lists=weight_lists,
+                     use_array_for_index_remapping=self.use_array_for_index_remapping,
+                     fp8_exponent_bits=(
+                         self._get_quantization_config("exponent_bits")
+                         if is_fp8_weight
+                         else None
+                     ),
+                     fp8_exponent_bias=(
+                         self._get_quantization_config("exponent_bias")
+                         if is_fp8_weight
+                         else None
+                     ),
+                 )
+                 setattr(model, name, q_child)
+             else:
+                 self._process_split_embs(child)
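
A minimal usage sketch (not part of the packaged files) for the converter in the hunk above, which corresponds to fbgemm_gpu/split_embedding_inference_converter.py in the file list; trained_model stands for any module that contains SplitTableBatchedEmbeddingBagsCodegen children:

import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_embedding_inference_converter import SplitEmbInferenceConverter


def convert_for_inference(trained_model: torch.nn.Module) -> torch.nn.Module:
    # Replace every SplitTableBatchedEmbeddingBagsCodegen child with an INT8
    # IntNBitTableBatchedEmbeddingBagsCodegen; pass e.g. pruning_ratio=0.5 to
    # also drop the rows with the smallest L2 norms before quantizing.
    converter = SplitEmbInferenceConverter(
        quantize_type=SparseType.INT8,
        pruning_ratio=None,
    )
    return converter.convert_model(trained_model)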
fbgemm_gpu/split_embedding_optimizer_ops.py
@@ -0,0 +1,21 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ # flake8: noqa F401
+
+ # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_optimizer_codegen
+ from fbgemm_gpu.split_embedding_optimizer_codegen.optimizer_args import (
+     SplitEmbeddingArgs,
+     SplitEmbeddingOptimizerParams,
+ )
+
+ # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_optimizer_codegen
+ from fbgemm_gpu.split_embedding_optimizer_codegen.split_embedding_optimizer_rowwise_adagrad import (
+     SplitEmbeddingRowwiseAdagrad,
+ )
fbgemm_gpu/split_embedding_utils.py
@@ -0,0 +1,29 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import warnings
+
+ from fbgemm_gpu.tbe.utils import ( # noqa: F401
+     b_indices, # noqa: F401
+     fake_quantize_embs, # noqa: F401
+     generate_requests, # noqa: F401
+     get_device, # noqa: F401
+     get_table_batched_offsets_from_dense, # noqa: F401
+     quantize_embs, # noqa: F401
+     round_up, # noqa: F401
+     TBERequest, # noqa: F401
+     to_device, # noqa: F401
+ )
+
+ warnings.warn( # noqa: B028
+     f"""\033[93m
+     The Python module {__name__} is now DEPRECATED and will be removed in the
+     future. Users should import fbgemm_gpu.tbe.utils into their scripts instead.
+     \033[0m""",
+     DeprecationWarning,
+ )
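
The shim above only re-exports symbols and emits a DeprecationWarning; per the warning text, new code should import the utilities from their new home directly, e.g.:

# Import from fbgemm_gpu.tbe.utils rather than through the deprecated shim.
from fbgemm_gpu.tbe.utils import generate_requests, TBERequest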