fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -8,10 +8,51 @@
 # pyre-strict

 import enum
+import itertools
 from typing import Any, Dict # noqa: F401

 import torch

+# fmt:skip
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    EmbeddingLocation,
+    SplitState,
+)
+
+
+def pad4(value: int) -> int:
+    """
+    Compute the smallest multiple of 4 that is greater than or equal to the given value.
+
+    Parameters:
+        value (int): The integer to align (must be non-negative).
+
+    Returns:
+        int: The aligned value.
+
+    Raises:
+        ValueError: If the input is negative.
+        TypeError: If the input is not an integer.
+    """
+    return (int(value) + 3) & ~3
+
+
+def pad16(value: int) -> int:
+    """
+    Compute the smallest multiple of 16 that is greater than or equal to the given value.
+
+    Parameters:
+        value (int): The integer to align (must be non-negative).
+
+    Returns:
+        int: The aligned value.
+
+    Raises:
+        ValueError: If the input is negative.
+        TypeError: If the input is not an integer.
+    """
+    return (int(value) + 15) & ~15
+

 @enum.unique
 class EmbOptimType(enum.Enum):
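These additions appear to land in `fbgemm_gpu/split_embedding_configs.py`, the +287/-3 entry in the file list above. A quick, illustrative check of what the new alignment helpers compute; this snippet only mirrors the bit trick shown in the hunk and is not package code:

```python
# Illustrative only: round a non-negative integer up to the next multiple
# of 4 or 16, mirroring the pad4/pad16 helpers added in the hunk above.
def pad4(value: int) -> int:
    return (int(value) + 3) & ~3

def pad16(value: int) -> int:
    return (int(value) + 15) & ~15

assert pad4(0) == 0 and pad4(5) == 8 and pad4(8) == 8
assert pad16(1) == 16 and pad16(17) == 32 and pad16(32) == 32
```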
@@ -40,6 +81,196 @@ class EmbOptimType(enum.Enum):
     def __str__(self) -> str:
         return self.value

+    def _extract_dtype(
+        self, optimizer_state_dtypes: dict[str, "SparseType"], name: str
+    ) -> torch.dtype:
+        if optimizer_state_dtypes is None or name not in optimizer_state_dtypes:
+            return torch.float32
+        return optimizer_state_dtypes[name].as_dtype()
+
+    def state_names(self) -> list[str]:
+        """
+        Returns the names of the optimizer states. The order of the states will
+        be the order in which they are processed and returned in
+        SSDTableBatchedEmbeddingBags.split_optimizer_states(), but this is not
+        necessarily the same as the order they are stored in the memory layout.
+        """
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            return ["momentum1"]
+        elif self in [EmbOptimType.PARTIAL_ROWWISE_ADAM, EmbOptimType.ADAM]:
+            return ["momentum1", "momentum2"]
+        else:
+            return []
+
+    def state_size_table(self, D: int) -> dict[str, int]:
+        """
+        Returns the table of state names to state sizes in terms of number of
+        elements (per table row)
+        """
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            return {"momentum1": 1}
+        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+            return {"momentum1": D, "momentum2": 1}
+        elif self == EmbOptimType.ADAM:
+            return {"momentum1": D, "momentum2": D}
+        else:
+            return {}
+
+    def state_size_nbytes(
+        self,
+        D: int,
+        optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+    ) -> int:
+        """
+        Returns the size of the data (in bytes) required to hold the optimizer
+        state (per table row). This size includes byte-padding.
+        """
+        momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+        momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            return momentum1_dtype.itemsize
+
+        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+            return pad4(1 * momentum2_dtype.itemsize) + D * momentum1_dtype.itemsize
+
+        elif self == EmbOptimType.ADAM:
+            return (D * momentum1_dtype.itemsize) + (D * momentum2_dtype.itemsize)
+
+        else:
+            return 0
+
+    def byte_offsets_along_row(
+        self,
+        D: int,
+        weights_precision: "SparseType",
+        optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+    ) -> dict[str, tuple[int, int]]:
+        """
+        Returns the start and end byte offsets of each optimizer state along a
+        cache row with optimizer state offloading enabled.
+        """
+        # Extract the optimizer state dtypes
+        momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+        momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+
+        # This is the pointer to where the optimizer state begins in the memory
+        p0 = pad4(D) * weights_precision.as_dtype().itemsize
+
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            return {"momentum1": (p0, p0 + momentum1_dtype.itemsize)}
+
+        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+            # momentum1 lies after momentum2
+            p1 = p0 + pad4(1 * momentum2_dtype.itemsize)
+            return {
+                "momentum2": (p0, p0 + momentum2_dtype.itemsize),
+                "momentum1": (
+                    p1,
+                    p1 + D * momentum1_dtype.itemsize,
+                ),
+            }
+
+        elif self == EmbOptimType.ADAM:
+            # momentum2 lies after momentum1
+            p1 = p0 + (D * momentum1_dtype.itemsize)
+
+            return {
+                "momentum1": (p0, p1),
+                "momentum2": (p1, p1 + D * momentum2_dtype.itemsize),
+            }
+
+        else:
+            return {}
+
+    def empty_states(
+        self,
+        rows: list[int],
+        dims: list[int],
+        optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+    ) -> list[list[torch.Tensor]]:
+        """
+        Creates sets of empty tensors per table to hold optimizer states based
+        on the specified optimizer type, state dtypes, embedding specs, and
+        (optionally) local row counts.
+        """
+        # Else, check that the local row count for each table is set
+        assert len(rows) == len(dims)
+
+        opt_states_set: list[list[torch.Tensor]] = []
+
+        for r, D in zip(rows, dims):
+            # Set up the table of state names to state sizes, ordered by their
+            # memory layout
+            state_size_table = self.state_size_table(D)
+            ordered_state_sizes = [(k, state_size_table[k]) for k in self.state_names()]
+
+            # Create the optimizer states for this table
+            opt_states_set.append(
+                [
+                    torch.empty(
+                        # If the state size is 1, then fix tensor to 1D to be
+                        # consistent with training.py code
+                        # pyre-ignore [6]
+                        (r, d) if d > 1 else r,
+                        dtype=self._extract_dtype(optimizer_state_dtypes, state_name),
+                        device="cpu",
+                    )
+                    for state_name, d in ordered_state_sizes
+                ]
+            )
+
+        return opt_states_set
+
+    def ssd_state_splits(
+        self,
+        embedding_specs: list[tuple[int, int]], # Tuple of (rows, dims)
+        optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+        enable_optimizer_offloading: bool = False,
+    ) -> list[tuple[SplitState, str, torch.dtype]]:
+        """
+        Returns the split planning for the optimizer states
+        """
+        rows, _ = zip(*embedding_specs)
+        T_ = len(embedding_specs)
+
+        # This is the cumulative row counts for rowwise states
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
+        # This is the cumulative element counts for elementwise states
+        table_size_cumsum: list[int] = [0] + list(
+            itertools.accumulate([r * d for r, d in embedding_specs])
+        )
+
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            params = {"momentum1": row_count_cumsum}
+        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+            params = {"momentum1": table_size_cumsum, "momentum2": row_count_cumsum}
+        elif self == EmbOptimType.ADAM:
+            params = {
+                "momentum1": table_size_cumsum,
+                "momentum2": table_size_cumsum,
+                "row_counter": row_count_cumsum,
+            }
+        else:
+            params = {}
+
+        return [
+            (
+                SplitState(
+                    dev_size=(
+                        cumsum_table[-1] if not enable_optimizer_offloading else 0
+                    ),
+                    host_size=0,
+                    uvm_size=0,
+                    placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
+                    offsets=cumsum_table[:-1],
+                ),
+                name,
+                self._extract_dtype(optimizer_state_dtypes, name),
+            )
+            for (name, cumsum_table) in params.items()
+        ]
+

 # Base class for quantization configuration (in case other numeric types have
 # configs)
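To make the row layout concrete, here is a small, self-contained sketch of the arithmetic that the `byte_offsets_along_row()` addition above performs for `PARTIAL_ROWWISE_ADAM`. The parameter values (D=128, FP16 weights, default FP32 optimizer states) are assumptions chosen for the example, and the padding helper is re-implemented locally rather than imported from fbgemm_gpu:

```python
# Illustrative sketch only: mirrors the offset arithmetic of
# byte_offsets_along_row() for PARTIAL_ROWWISE_ADAM with offloading enabled.
def pad4(value: int) -> int:
    return (int(value) + 3) & ~3

D = 128                # embedding dimension (assumed for this example)
weight_itemsize = 2    # FP16 weights
state_itemsize = 4     # FP32 momentum1 / momentum2 (the default state dtypes)

p0 = pad4(D) * weight_itemsize             # optimizer state starts after padded weights
momentum2 = (p0, p0 + state_itemsize)      # row-wise state: a single element
p1 = p0 + pad4(1 * state_itemsize)         # momentum1 starts after padded momentum2
momentum1 = (p1, p1 + D * state_itemsize)  # element-wise state: D elements

print(momentum2, momentum1)  # (256, 260) (260, 772)
```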
@@ -79,14 +310,54 @@ def sparse_type_to_int(sparse_type: "SparseType") -> int:
         SparseType.BF16.value: 5,
         SparseType.FP8.value: 6,
         SparseType.MX4.value: 7,
+        SparseType.NFP8.value: 8,
     }[sparse_type.value]


+def sparse_type_int_to_dtype(ty: int) -> torch.dtype:
+    """
+    TorchScript-compatible function to convert an SparseType enum as integer) to torch.dtype.
+
+    This is a standalone function equivalent to SparseType.from_int(dtype_int).as_dtype() that works
+    with TorchScript. TorchScript does not support @staticmethod on Enum classes,
+    so this function provides a workaround.
+    """
+    if ty == 0: # fp32
+        return torch.float32
+    elif ty == 1: # fp16
+        return torch.float16
+    elif ty == 2: # int8
+        return torch.uint8
+    elif ty == 3: # int4
+        return torch.quint4x2
+    elif ty == 4: # int2
+        return torch.quint2x4
+    elif ty == 5: # bf16
+        return torch.bfloat16
+    elif ty == 6: # fp8
+        return torch.uint8
+    elif ty == 7: # mx4
+        return torch.uint8
+    elif ty == 9:
+        return (
+            torch.float8_e4m3fnuz
+            if torch.version.hip is not None
+            else torch.float8_e4m3fn
+        )
+    else: # Invalid is 7 or non enumerated.
+        raise ValueError(f"Unsupported sparse type: {ty}")
+
+
 @enum.unique
 class SparseType(enum.Enum):
     FP32 = "fp32"
     FP16 = "fp16"
     FP8 = "fp8"
+    # NFP8 refers to "native" FP8 in that it uses the GPU implementations
+    # of E4M3 whereas the other FP8 sparsetype uses a custom format. Use of
+    # NFP8 allows us to use hardware casting intrinsics which can be much faster.
+    # Eventually, we should merge these two types.
+    NFP8 = "nfp8"
     INT8 = "int8"
     INT4 = "int4"
     INT2 = "int2"

@@ -112,9 +383,11 @@ class SparseType(enum.Enum):
             return SparseType("bf16")
         elif ty == 6:
             return SparseType("fp8")
-        elif ty ==
+        elif ty == 8:
             return SparseType("mx4")
-
+        elif ty == 9:
+            return SparseType("nfp8")
+        else: # Invalid is 7 or non enumerated.
             raise ValueError(f"Unsupported sparse type: {ty}")

     def as_int(self) -> int:

@@ -136,6 +409,8 @@ class SparseType(enum.Enum):
             return SparseType("bf16")
         elif dtype == torch.uint8:
             return SparseType("mx4")
+        elif dtype == torch.float8_e4m3fnuz or dtype == torch.float8_e4m3fn:
+            return SparseType("nfp8")
         else:
             raise ValueError(f"Unsupported sparse dtype: {dtype}")


@@ -149,6 +424,11 @@ class SparseType(enum.Enum):
             SparseType.INT2.value: torch.quint2x4,
             SparseType.BF16.value: torch.bfloat16,
             SparseType.MX4.value: torch.uint8,
+            SparseType.NFP8.value: (
+                torch.float8_e4m3fnuz
+                if torch.version.hip is not None
+                else torch.float8_e4m3fn
+            ),
         }[self.value]

     def bit_rate(self) -> int:

@@ -161,6 +441,7 @@ class SparseType(enum.Enum):
             SparseType.INT2.value: 2,
             SparseType.BF16.value: 16,
             SparseType.MX4.value: 4,
+            SparseType.NFP8.value: 8,
         }[self.value]

     def align_size(self) -> int:

@@ -173,6 +454,7 @@ class SparseType(enum.Enum):
             SparseType.INT2.value: 16,
             SparseType.BF16.value: 2,
             SparseType.MX4.value: 8,
+            SparseType.NFP8.value: 4,
         }[self.value]

     def is_float(self) -> bool:

@@ -181,6 +463,7 @@ class SparseType(enum.Enum):
             or self.value == SparseType.FP16.value
             or self.value == SparseType.FP8.value
             or self.value == SparseType.BF16.value
+            or self.value == SparseType.NFP8.value
         ):
             return True
         else:

@@ -193,11 +476,12 @@ class SparseType(enum.Enum):
         return QuantizationConfig()


-ELEMENT_SIZE:
+ELEMENT_SIZE: dict[SparseType, int] = {
     SparseType.FP32: 4,
     SparseType.FP16: 2,
     SparseType.FP8: 1,
     SparseType.INT8: 1,
     SparseType.BF16: 2,
+    SparseType.NFP8: 1,
     # SparseType.INT4: 0.5,
 }
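The `SparseType` hunks above add an `NFP8` ("native" FP8) member whose backing dtype depends on the GPU stack. A minimal check of how that conditional resolves, mirroring the expression in the diff; illustrative only, and it assumes a PyTorch build that ships the float8 dtypes:

```python
import torch

# Illustrative only: ROCm builds (torch.version.hip set) resolve NFP8 to the
# e4m3fnuz variant, while CUDA/CPU builds resolve it to e4m3fn.
nfp8_dtype = (
    torch.float8_e4m3fnuz if torch.version.hip is not None else torch.float8_e4m3fn
)
print(nfp8_dtype, nfp8_dtype.itemsize)  # 1 byte per element, matching bit_rate() == 8
```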
@@ -10,10 +10,11 @@

 import logging
 import math
-from typing import cast, Optional
+from typing import cast, Optional

 import torch

+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import (
     FP8QuantizationConfig,
     QuantizationConfig,

@@ -53,7 +54,7 @@ class SplitEmbInferenceConverter:
         return model

     # pyre-fixme[2]: Parameter must be annotated.
-    def _prune_by_weights_l2_norm(self, new_num_rows, weights) ->
+    def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> tuple[Tensor, float]:
         assert new_num_rows > 0
         from numpy.linalg import norm


@@ -75,7 +76,7 @@ class SplitEmbInferenceConverter:
         idx: int,
         num_rows: int,
         module: SplitTableBatchedEmbeddingBagsCodegen,
-    ) ->
+    ) -> tuple[Tensor, Optional[Tensor]]:
         # TODO(yingz): Avoid DtoH / HtoD overhead.
         weights = module.split_embedding_weights()[idx].cpu()
         if self.pruning_ratio is None:

@@ -84,7 +85,7 @@ class SplitEmbInferenceConverter:
         if new_num_rows == num_rows:
             return (weights, None)

-
+        indicators, threshold = self._prune_by_weights_l2_norm(new_num_rows, weights)

         return torch.ops.fbgemm.embedding_bag_rowwise_prune(
             weights, indicators, threshold, torch.int32

@@ -100,7 +101,7 @@ class SplitEmbInferenceConverter:

     def _quantize_embs(
         self, weight: Tensor, weight_ty: SparseType
-    ) ->
+    ) -> tuple[Tensor, Optional[Tensor]]:
         fp8_quant_config = cast(FP8QuantizationConfig, self.quantization_config)
         return quantize_embs(weight, weight_ty, fp8_quant_config)


@@ -129,7 +130,7 @@ class SplitEmbInferenceConverter:
         index_remapping_list = []
         for t, (_, E, D, weight_ty) in enumerate(embedding_specs):
             # Try to prune embeddings.
-
+            pruned_weight, index_remapping = self._prune_embs(t, E, child)
             new_embedding_specs.append(
                 (
                     "",
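The `SplitEmbInferenceConverter` hunks above mainly restore return-type annotations that the old rendering truncated (for example `tuple[Tensor, Optional[Tensor]]`). For context, a self-contained sketch of the L2-norm row-pruning idea behind `_prune_by_weights_l2_norm`; it is an illustration only, not fbgemm_gpu code, and it does not call `torch.ops.fbgemm.embedding_bag_rowwise_prune`:

```python
import torch

# Illustrative only: score each embedding row by its L2 norm and derive the
# threshold that keeps the new_num_rows highest-scoring rows, roughly the
# (indicators, threshold) pair that _prune_by_weights_l2_norm() computes.
weights = torch.randn(10, 4)
new_num_rows = 6

indicators = weights.norm(dim=1)                        # per-row L2 norm
threshold = indicators.topk(new_num_rows).values.min()  # smallest kept norm

kept_rows = weights[indicators >= threshold]
print(kept_rows.shape)  # torch.Size([6, 4]), barring exact ties
```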
@@ -4,6 +4,8 @@
 ## Template Source: training/python/optimizer_args.py
 ################################################################################

+__template_source_file__ = "training/python/optimizer_args.py"
+
 #!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.

@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_optimizer_codegen.template
 ################################################################################

+__template_source_file__ = "training/python/split_embedding_optimizer_codegen.template"
+
 #!/usr/bin/env python3

 # Copyright (c) Meta Platforms, Inc. and affiliates.