fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
@@ -0,0 +1,452 @@ fbgemm_gpu/split_embedding_configs.py

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import enum
import itertools
from typing import Any, Dict  # noqa: F401

import torch

from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    EmbeddingLocation,
    SplitState,
)


def pad4(value: int) -> int:
    """
    Compute the smallest multiple of 4 that is greater than or equal to the given value.

    Parameters:
        value (int): The integer to align (assumed to be non-negative).

    Returns:
        int: The aligned value.
    """
    return (int(value) + 3) & ~3


def pad16(value: int) -> int:
    """
    Compute the smallest multiple of 16 that is greater than or equal to the given value.

    Parameters:
        value (int): The integer to align (assumed to be non-negative).

    Returns:
        int: The aligned value.
    """
    return (int(value) + 15) & ~15


@enum.unique
class EmbOptimType(enum.Enum):
    SGD = "sgd"  # uses non-deterministic updates (atomicAdd(..)) with duplicate ids
    EXACT_SGD = (
        "exact_sgd"  # uses deterministic updates (via sorting + segment reduction)
    )
    LAMB = "lamb"
    ADAM = "adam"
    # exact/dedup: gradients to the same row are coalesced and then applied
    # together, instead of being applied in sequence (approx).
    EXACT_ADAGRAD = "exact_adagrad"
    EXACT_ROWWISE_ADAGRAD = "exact_row_wise_adagrad"
    LARS_SGD = "lars_sgd"
    PARTIAL_ROWWISE_ADAM = "partial_row_wise_adam"
    PARTIAL_ROWWISE_LAMB = "partial_row_wise_lamb"
    ROWWISE_ADAGRAD = "row_wise_adagrad"
    SHAMPOO = "shampoo"  # not currently supported for sparse embedding tables
    SHAMPOO_V2 = "shampoo_v2"  # not currently supported for sparse embedding tables
    MADGRAD = "madgrad"
    EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad"  # deprecated
    ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"
    EMAINPLACE_ROWWISE_ADAGRAD = "ema_in_place_row_wise_adagrad"
    NONE = "none"

    def __str__(self) -> str:
        return self.value

    def _extract_dtype(
        self, optimizer_state_dtypes: dict[str, "SparseType"], name: str
    ) -> torch.dtype:
        if optimizer_state_dtypes is None or name not in optimizer_state_dtypes:
            return torch.float32
        return optimizer_state_dtypes[name].as_dtype()

    def state_names(self) -> list[str]:
        """
        Returns the names of the optimizer states. The order of the states will
        be the order in which they are processed and returned in
        SSDTableBatchedEmbeddingBags.split_optimizer_states(), but this is not
        necessarily the same as the order they are stored in the memory layout.
        """
        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
            return ["momentum1"]
        elif self in [EmbOptimType.PARTIAL_ROWWISE_ADAM, EmbOptimType.ADAM]:
            return ["momentum1", "momentum2"]
        else:
            return []

    def state_size_table(self, D: int) -> dict[str, int]:
        """
        Returns the table of state names to state sizes in terms of number of
        elements (per table row)
        """
        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
            return {"momentum1": 1}
        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
            return {"momentum1": D, "momentum2": 1}
        elif self == EmbOptimType.ADAM:
            return {"momentum1": D, "momentum2": D}
        else:
            return {}

    def state_size_nbytes(
        self,
        D: int,
        optimizer_state_dtypes: dict[str, "SparseType"] = {},  # noqa: B006
    ) -> int:
        """
        Returns the size of the data (in bytes) required to hold the optimizer
        state (per table row). This size includes byte-padding.
        """
        momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
        momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")

        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
            return momentum1_dtype.itemsize

        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
            return pad4(1 * momentum2_dtype.itemsize) + D * momentum1_dtype.itemsize

        elif self == EmbOptimType.ADAM:
            return (D * momentum1_dtype.itemsize) + (D * momentum2_dtype.itemsize)

        else:
            return 0

    def byte_offsets_along_row(
        self,
        D: int,
        weights_precision: "SparseType",
        optimizer_state_dtypes: dict[str, "SparseType"] = {},  # noqa: B006
    ) -> dict[str, tuple[int, int]]:
        """
        Returns the start and end byte offsets of each optimizer state along a
        cache row with optimizer state offloading enabled.
        """
        # Extract the optimizer state dtypes
        momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
        momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")

        # This is the pointer to where the optimizer state begins in the memory
        p0 = pad4(D) * weights_precision.as_dtype().itemsize

        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
            return {"momentum1": (p0, p0 + momentum1_dtype.itemsize)}

        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
            # momentum1 lies after momentum2
            p1 = p0 + pad4(1 * momentum2_dtype.itemsize)
            return {
                "momentum2": (p0, p0 + momentum2_dtype.itemsize),
                "momentum1": (
                    p1,
                    p1 + D * momentum1_dtype.itemsize,
                ),
            }

        elif self == EmbOptimType.ADAM:
            # momentum2 lies after momentum1
            p1 = p0 + (D * momentum1_dtype.itemsize)

            return {
                "momentum1": (p0, p1),
                "momentum2": (p1, p1 + D * momentum2_dtype.itemsize),
            }

        else:
            return {}

    def empty_states(
        self,
        rows: list[int],
        dims: list[int],
        optimizer_state_dtypes: dict[str, "SparseType"] = {},  # noqa: B006
    ) -> list[list[torch.Tensor]]:
        """
        Creates sets of empty tensors per table to hold optimizer states based
        on the specified optimizer type, state dtypes, embedding specs, and
        (optionally) local row counts.
        """
        # Check that a row count and a dimension are provided for each table
        assert len(rows) == len(dims)

        opt_states_set: list[list[torch.Tensor]] = []

        for r, D in zip(rows, dims):
            # Set up the table of state names to state sizes, ordered by their
            # memory layout
            state_size_table = self.state_size_table(D)
            ordered_state_sizes = [(k, state_size_table[k]) for k in self.state_names()]

            # Create the optimizer states for this table
            opt_states_set.append(
                [
                    torch.empty(
                        # If the state size is 1, then fix tensor to 1D to be
                        # consistent with training.py code
                        # pyre-ignore [6]
                        (r, d) if d > 1 else r,
                        dtype=self._extract_dtype(optimizer_state_dtypes, state_name),
                        device="cpu",
                    )
                    for state_name, d in ordered_state_sizes
                ]
            )

        return opt_states_set

    def ssd_state_splits(
        self,
        embedding_specs: list[tuple[int, int]],  # Tuple of (rows, dims)
        optimizer_state_dtypes: dict[str, "SparseType"] = {},  # noqa: B006
        enable_optimizer_offloading: bool = False,
    ) -> list[tuple[SplitState, str, torch.dtype]]:
        """
        Returns the split planning for the optimizer states
        """
        (rows, _) = zip(*embedding_specs)
        T_ = len(embedding_specs)

        # This is the cumulative row counts for rowwise states
        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
        # This is the cumulative element counts for elementwise states
        table_size_cumsum: list[int] = [0] + list(
            itertools.accumulate([r * d for r, d in embedding_specs])
        )

        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
            params = {"momentum1": row_count_cumsum}
        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
            params = {"momentum1": table_size_cumsum, "momentum2": row_count_cumsum}
        elif self == EmbOptimType.ADAM:
            params = {
                "momentum1": table_size_cumsum,
                "momentum2": table_size_cumsum,
                "row_counter": row_count_cumsum,
            }
        else:
            params = {}

        return [
            (
                SplitState(
                    dev_size=(
                        cumsum_table[-1] if not enable_optimizer_offloading else 0
                    ),
                    host_size=0,
                    uvm_size=0,
                    placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
                    offsets=cumsum_table[:-1],
                ),
                name,
                self._extract_dtype(optimizer_state_dtypes, name),
            )
            for (name, cumsum_table) in params.items()
        ]


# Base class for quantization configuration (in case other numeric types have
# configs)
class QuantizationConfig:
    def __init__(self) -> None:
        self.config = {}  # type: Dict[str, Any]

    def get(self, name: str) -> int:
        return -1


# FP8 quantization configuration
# Compute necessary parameters in the constructor
class FP8QuantizationConfig(QuantizationConfig):
    def __init__(self, exponent_bits: int, exponent_bias: int) -> None:
        super(FP8QuantizationConfig, self).__init__()
        self.config = {
            "exponent_bits": exponent_bits,
            "exponent_bias": exponent_bias,
            "max_position": (1 << ((1 << exponent_bits) - 2 - exponent_bias))
            * (2 - 2 ** (exponent_bits - 7)),
        }  # type: Dict[str, Any]

    def get(self, name: str) -> int:
        if name not in self.config:
            raise RuntimeError("{} must be set in config".format(name))
        return self.config[name]


def sparse_type_to_int(sparse_type: "SparseType") -> int:
    return {
        SparseType.FP32.value: 0,
        SparseType.FP16.value: 1,
        SparseType.INT8.value: 2,
        SparseType.INT4.value: 3,
        SparseType.INT2.value: 4,
        SparseType.BF16.value: 5,
        SparseType.FP8.value: 6,
        SparseType.MX4.value: 7,
        SparseType.NFP8.value: 8,
    }[sparse_type.value]


@enum.unique
class SparseType(enum.Enum):
    FP32 = "fp32"
    FP16 = "fp16"
    FP8 = "fp8"
    # NFP8 refers to "native" FP8 in that it uses the GPU implementations
    # of E4M3, whereas the other FP8 sparse type uses a custom format. Use of
    # NFP8 allows us to use hardware casting intrinsics which can be much faster.
    # Eventually, we should merge these two types.
    NFP8 = "nfp8"
    INT8 = "int8"
    INT4 = "int4"
    INT2 = "int2"
    BF16 = "bf16"
    MX4 = "mx4"

    def __str__(self) -> str:
        return self.value

    @staticmethod
    def from_int(ty: int) -> "SparseType":
        if ty == 0:
            return SparseType("fp32")
        elif ty == 1:
            return SparseType("fp16")
        elif ty == 2:
            return SparseType("int8")
        elif ty == 3:
            return SparseType("int4")
        elif ty == 4:
            return SparseType("int2")
        elif ty == 5:
            return SparseType("bf16")
        elif ty == 6:
            return SparseType("fp8")
        elif ty == 7:
            # Keep this mapping the inverse of sparse_type_to_int() above
            return SparseType("mx4")
        elif ty == 8:
            return SparseType("nfp8")
        else:
            raise ValueError(f"Unsupported sparse type: {ty}")

    def as_int(self) -> int:
        return sparse_type_to_int(self)

    @staticmethod
    def from_dtype(dtype: torch.dtype, is_mx: bool = False) -> "SparseType":
        if dtype == torch.float32:
            return SparseType("fp32")
        elif dtype == torch.float16:
            return SparseType("fp16")
        elif (dtype == torch.int8 or dtype == torch.uint8) and not is_mx:
            return SparseType("int8")
        elif dtype == torch.quint4x2:
            return SparseType("int4")
        elif dtype == torch.quint2x4:
            return SparseType("int2")
        elif dtype == torch.bfloat16:
            return SparseType("bf16")
        elif dtype == torch.uint8:
            return SparseType("mx4")
        elif dtype == torch.float8_e4m3fnuz or dtype == torch.float8_e4m3fn:
            return SparseType("nfp8")
        else:
            raise ValueError(f"Unsupported sparse dtype: {dtype}")

    def as_dtype(self) -> torch.dtype:
        return {
            SparseType.FP32.value: torch.float32,
            SparseType.FP16.value: torch.float16,
            SparseType.FP8.value: torch.uint8,
            SparseType.INT8.value: torch.uint8,
            SparseType.INT4.value: torch.quint4x2,
            SparseType.INT2.value: torch.quint2x4,
            SparseType.BF16.value: torch.bfloat16,
            SparseType.MX4.value: torch.uint8,
            SparseType.NFP8.value: (
                torch.float8_e4m3fnuz
                if torch.version.hip is not None
                else torch.float8_e4m3fn
            ),
        }[self.value]

    def bit_rate(self) -> int:
        return {
            SparseType.FP32.value: 32,
            SparseType.FP16.value: 16,
            SparseType.FP8.value: 8,
            SparseType.INT8.value: 8,
            SparseType.INT4.value: 4,
            SparseType.INT2.value: 2,
            SparseType.BF16.value: 16,
            SparseType.MX4.value: 4,
            SparseType.NFP8.value: 8,
        }[self.value]

    def align_size(self) -> int:
        return {
            SparseType.FP32.value: 1,
            SparseType.FP16.value: 2,
            SparseType.FP8.value: 4,
            SparseType.INT8.value: 4,
            SparseType.INT4.value: 8,
            SparseType.INT2.value: 16,
            SparseType.BF16.value: 2,
            SparseType.MX4.value: 8,
            SparseType.NFP8.value: 4,
        }[self.value]

    def is_float(self) -> bool:
        if (
            self.value == SparseType.FP32.value
            or self.value == SparseType.FP16.value
            or self.value == SparseType.FP8.value
            or self.value == SparseType.BF16.value
            or self.value == SparseType.NFP8.value
        ):
            return True
        else:
            return False

    def default_config(self) -> QuantizationConfig:
        if self.value == SparseType.FP8.value:
            return FP8QuantizationConfig(4, 7)
        else:
            return QuantizationConfig()


ELEMENT_SIZE: dict[SparseType, int] = {
    SparseType.FP32: 4,
    SparseType.FP16: 2,
    SparseType.FP8: 1,
    SparseType.INT8: 1,
    SparseType.BF16: 2,
    SparseType.NFP8: 1,
    # SparseType.INT4: 0.5,
}
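In byte_offsets_along_row() above, each offloaded cache row stores the weights first (padded to a multiple of 4 elements), followed by the optimizer states. A minimal standalone sketch of that arithmetic, assuming FP16 weights, D = 128, FP32 momentum dtypes, and the PARTIAL_ROWWISE_ADAM layout (illustrative only, not part of the wheel):

def pad4(value: int) -> int:
    # Round up to the next multiple of 4 by clearing the two low bits.
    return (int(value) + 3) & ~3

D = 128
weight_itemsize = 2    # FP16
momentum_itemsize = 4  # FP32

# The optimizer state begins after the 4-element-padded weight segment.
p0 = pad4(D) * weight_itemsize                      # 256

# PARTIAL_ROWWISE_ADAM: one momentum2 scalar, then D momentum1 elements.
p1 = p0 + pad4(1 * momentum_itemsize)               # 260
offsets = {
    "momentum2": (p0, p0 + momentum_itemsize),      # (256, 260)
    "momentum1": (p1, p1 + D * momentum_itemsize),  # (260, 772)
}

# Agrees with state_size_nbytes() for this optimizer type.
state_nbytes = pad4(momentum_itemsize) + D * momentum_itemsize  # 516
assert offsets["momentum1"][1] - p0 == state_nbytes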
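ssd_state_splits() above plans offsets with two cumulative sums: rowwise states (one element per row) advance by row counts, while elementwise states (D elements per row) advance by rows * dim. A standalone sketch with two hypothetical tables (not part of the wheel):

import itertools

embedding_specs = [(100, 8), (50, 16)]  # hypothetical (rows, dim) per table
rows = [r for r, _ in embedding_specs]

row_count_cumsum = [0] + list(itertools.accumulate(rows))
table_size_cumsum = [0] + list(
    itertools.accumulate(r * d for r, d in embedding_specs)
)

assert row_count_cumsum == [0, 100, 150]    # per-row states, e.g. ADAM's row_counter
assert table_size_cumsum == [0, 800, 1600]  # per-element states, e.g. momentum1

# In the resulting SplitState, offsets = cumsum[:-1] and dev_size = cumsum[-1]
# (dev_size drops to 0 when optimizer offloading is enabled).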
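default_config() above builds FP8QuantizationConfig(4, 7), i.e. a 1-sign/4-exponent/3-mantissa layout, for which the max_position expression evaluates to 240. A quick worked check of that arithmetic in plain Python (not part of the wheel):

exponent_bits, exponent_bias = 4, 7
mantissa_bits = 7 - exponent_bits                        # 8 bits total, 1 sign bit
max_exponent = (1 << exponent_bits) - 2 - exponent_bias  # 16 - 2 - 7 = 7
max_mantissa = 2 - 2**-mantissa_bits                     # 2 - 2**-3 = 1.875
max_position = (1 << max_exponent) * max_mantissa        # 128 * 1.875
assert max_position == 240.0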
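The integer codes in sparse_type_to_int() and SparseType.from_int() are intended to be inverses of each other. A usage sketch (assumes the wheel is installed and imports cleanly):

import torch
from fbgemm_gpu.split_embedding_configs import SparseType

st = SparseType.MX4
assert st.as_int() == 7
assert SparseType.from_int(st.as_int()) is st  # round-trips through the int code
assert st.as_dtype() == torch.uint8            # MX4 payloads are stored as uint8
assert st.bit_rate() == 4                      # 4 bits per element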
@@ -0,0 +1,175 @@ fbgemm_gpu/split_embedding_inference_converter.py

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import logging
import math
from typing import cast, Optional

import torch

from fbgemm_gpu.split_embedding_configs import (
    FP8QuantizationConfig,
    QuantizationConfig,
    SparseType,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.utils import quantize_embs
from torch import Tensor  # usort:skip


# TODO: add per-feature based converter option (based on embedding_specs during inference)
# TODO: optimize embedding pruning and quantization latency.
class SplitEmbInferenceConverter:
    # pyre-fixme[3]: Return type must be annotated.
    def __init__(
        self,
        quantize_type: SparseType,
        pruning_ratio: Optional[float],
        use_array_for_index_remapping: bool = True,
        quantization_config: Optional[QuantizationConfig] = None,
    ):
        self.quantize_type = quantize_type
        # TODO(yingz): Change the pruning ratio to per-table settings.
        self.pruning_ratio = pruning_ratio
        self.use_array_for_index_remapping = use_array_for_index_remapping
        self.quantization_config = quantization_config

    def convert_model(self, model: torch.nn.Module) -> torch.nn.Module:
        self._process_split_embs(model)
        return model

    # pyre-fixme[2]: Parameter must be annotated.
    def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> tuple[Tensor, float]:
        assert new_num_rows > 0
        from numpy.linalg import norm

        indicators = []
        for row in weights:
            indicators.append(norm(row.cpu().numpy(), ord=2))
        sorted_indicators = sorted(indicators, reverse=True)
        threshold = None
        for i in range(new_num_rows, len(sorted_indicators)):
            if sorted_indicators[i] < sorted_indicators[new_num_rows - 1]:
                threshold = sorted_indicators[i]
                break
        if threshold is None:
            threshold = sorted_indicators[-1] - 1
        return (torch.tensor(indicators), threshold)

    def _prune_embs(
        self,
        idx: int,
        num_rows: int,
        module: SplitTableBatchedEmbeddingBagsCodegen,
    ) -> tuple[Tensor, Optional[Tensor]]:
        # TODO(yingz): Avoid DtoH / HtoD overhead.
        weights = module.split_embedding_weights()[idx].cpu()
        if self.pruning_ratio is None:
            return (weights, None)
        new_num_rows = int(math.ceil(num_rows * (1.0 - self.pruning_ratio)))  # type: ignore
        if new_num_rows == num_rows:
            return (weights, None)

        (indicators, threshold) = self._prune_by_weights_l2_norm(new_num_rows, weights)

        return torch.ops.fbgemm.embedding_bag_rowwise_prune(
            weights, indicators, threshold, torch.int32
        )

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def _get_quantization_config(self, name):
        quantization_config = self.quantization_config
        if quantization_config is None:
            raise RuntimeError("quantization_config must be set for FP8 weight")
        return quantization_config.get(name)

    def _quantize_embs(
        self, weight: Tensor, weight_ty: SparseType
    ) -> tuple[Tensor, Optional[Tensor]]:
        fp8_quant_config = cast(FP8QuantizationConfig, self.quantization_config)
        return quantize_embs(weight, weight_ty, fp8_quant_config)

    def _process_split_embs(self, model: torch.nn.Module) -> None:
        for name, child in model.named_children():
            if isinstance(
                child,
                SplitTableBatchedEmbeddingBagsCodegen,
            ):
                embedding_specs = []
                use_cpu = child.embedding_specs[0][3] == ComputeDevice.CPU
                for E, D, _, _ in child.embedding_specs:
                    weights_ty = self.quantize_type
                    if D % weights_ty.align_size() != 0:
                        logging.warning(
                            f"Embedding dim {D} is not divisible by align size {weights_ty.align_size()}!"
                        )
                        assert D % 4 == 0
                        weights_ty = (
                            SparseType.FP16
                        )  # fall back to FP16 if the dimension cannot be aligned to the required size
                    embedding_specs.append(("", E, D, weights_ty))

                weight_lists = []
                new_embedding_specs = []
                index_remapping_list = []
                for t, (_, E, D, weight_ty) in enumerate(embedding_specs):
                    # Try to prune embeddings.
                    (pruned_weight, index_remapping) = self._prune_embs(t, E, child)
                    new_embedding_specs.append(
                        (
                            "",
                            pruned_weight.size()[0],
                            D,
                            weight_ty,
                            (
                                EmbeddingLocation.HOST
                                if use_cpu
                                else EmbeddingLocation.DEVICE
                            ),
                        )
                    )
                    index_remapping_list.append(index_remapping)

                    # Try to quantize embeddings.
                    weight_lists.append(self._quantize_embs(pruned_weight, weight_ty))

                is_fp8_weight = self.quantize_type == SparseType.FP8

                q_child = IntNBitTableBatchedEmbeddingBagsCodegen(
                    embedding_specs=new_embedding_specs,
                    index_remapping=(
                        index_remapping_list if self.pruning_ratio is not None else None
                    ),
                    pooling_mode=child.pooling_mode,
                    device="cpu" if use_cpu else torch.cuda.current_device(),
                    weight_lists=weight_lists,
                    use_array_for_index_remapping=self.use_array_for_index_remapping,
                    fp8_exponent_bits=(
                        self._get_quantization_config("exponent_bits")
                        if is_fp8_weight
                        else None
                    ),
                    fp8_exponent_bias=(
                        self._get_quantization_config("exponent_bias")
                        if is_fp8_weight
                        else None
                    ),
                )
                setattr(model, name, q_child)
            else:
                self._process_split_embs(child)
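_prune_by_weights_l2_norm() above scores each row by its L2 norm and picks the threshold as the first indicator strictly below the new_num_rows-th largest, so rows tied with the boundary value survive. A standalone sketch of that selection on a toy tensor (not part of the wheel; the final keep rule shown is an assumption, since the actual row filtering happens inside torch.ops.fbgemm.embedding_bag_rowwise_prune):

import torch
from numpy.linalg import norm

weights = torch.tensor(
    [[3.0, 4.0], [1.0, 0.0], [0.0, 2.0], [6.0, 8.0]]  # L2 norms: 5, 1, 2, 10
)
new_num_rows = 2

indicators = [norm(row.numpy(), ord=2) for row in weights]
sorted_indicators = sorted(indicators, reverse=True)  # [10, 5, 2, 1]
threshold = None
for i in range(new_num_rows, len(sorted_indicators)):
    if sorted_indicators[i] < sorted_indicators[new_num_rows - 1]:
        threshold = sorted_indicators[i]  # 2.0: first value below the boundary (5.0)
        break
if threshold is None:
    threshold = sorted_indicators[-1] - 1

kept = [i for i, v in enumerate(indicators) if v > threshold]  # assumed: strict >
assert kept == [0, 3]  # the two rows with the largest norms survive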
@@ -0,0 +1,21 @@ fbgemm_gpu/split_embedding_optimizer_ops.py

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# flake8: noqa F401

# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_optimizer_codegen
from fbgemm_gpu.split_embedding_optimizer_codegen.optimizer_args import (
    SplitEmbeddingArgs,
    SplitEmbeddingOptimizerParams,
)

# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_optimizer_codegen
from fbgemm_gpu.split_embedding_optimizer_codegen.split_embedding_optimizer_rowwise_adagrad import (
    SplitEmbeddingRowwiseAdagrad,
)
@@ -0,0 +1,29 @@ fbgemm_gpu/split_embedding_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import warnings

from fbgemm_gpu.tbe.utils import (  # noqa: F401
    b_indices,  # noqa: F401
    fake_quantize_embs,  # noqa: F401
    generate_requests,  # noqa: F401
    get_device,  # noqa: F401
    get_table_batched_offsets_from_dense,  # noqa: F401
    quantize_embs,  # noqa: F401
    round_up,  # noqa: F401
    TBERequest,  # noqa: F401
    to_device,  # noqa: F401
)

warnings.warn(  # noqa: B028
    f"""\033[93m
    The Python module {__name__} is now DEPRECATED and will be removed in the
    future. Users should import fbgemm_gpu.tbe.utils into their scripts instead.
    \033[0m""",
    DeprecationWarning,
)
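Importing this legacy shim re-exports the fbgemm_gpu.tbe.utils symbols and emits the DeprecationWarning above. A usage sketch (assumes the wheel is installed and that this is the first import of the shim in the process, since the module-level warning only fires once):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import fbgemm_gpu.split_embedding_utils  # noqa: F401  # legacy path

assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Preferred import going forward:
from fbgemm_gpu.tbe.utils import round_up

assert round_up(10, 4) == 12  # assumed semantics: round 10 up to a multiple of 4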