fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +112 -19
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +118 -0
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +190 -54
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
- fbgemm_gpu/split_embedding_configs.py +134 -37
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +6 -1
- fbgemm_gpu/tbe/bench/bench_config.py +14 -3
- fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
- fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
- fbgemm_gpu/tbe/ssd/common.py +1 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +1292 -267
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +15 -15
- fbgemm_gpu/tbe_input_multiplexer.py +10 -11
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +6 -2
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +1 -0
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
|
@@ -9,10 +9,11 @@
|
|
|
9
9
|
|
|
10
10
|
import enum
|
|
11
11
|
import itertools
|
|
12
|
-
from typing import Any, Dict
|
|
12
|
+
from typing import Any, Dict # noqa: F401
|
|
13
13
|
|
|
14
14
|
import torch
|
|
15
15
|
|
|
16
|
+
# fmt:skip
|
|
16
17
|
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
|
17
18
|
EmbeddingLocation,
|
|
18
19
|
SplitState,
|
|
@@ -36,6 +37,23 @@ def pad4(value: int) -> int:
|
|
|
36
37
|
return (int(value) + 3) & ~3
|
|
37
38
|
|
|
38
39
|
|
|
40
|
+
def pad16(value: int) -> int:
|
|
41
|
+
"""
|
|
42
|
+
Compute the smallest multiple of 16 that is greater than or equal to the given value.
|
|
43
|
+
|
|
44
|
+
Parameters:
|
|
45
|
+
value (int): The integer to align (must be non-negative).
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
int: The aligned value.
|
|
49
|
+
|
|
50
|
+
Raises:
|
|
51
|
+
ValueError: If the input is negative.
|
|
52
|
+
TypeError: If the input is not an integer.
|
|
53
|
+
"""
|
|
54
|
+
return (int(value) + 15) & ~15
|
|
55
|
+
|
|
56
|
+
|
|
39
57
|
@enum.unique
|
|
40
58
|
class EmbOptimType(enum.Enum):
|
|
41
59
|
SGD = "sgd" # uses non-deterministic updates (atomicAdd(..)) with duplicate ids
|
|
@@ -64,13 +82,13 @@ class EmbOptimType(enum.Enum):
|
|
|
64
82
|
return self.value
|
|
65
83
|
|
|
66
84
|
def _extract_dtype(
|
|
67
|
-
self, optimizer_state_dtypes:
|
|
85
|
+
self, optimizer_state_dtypes: dict[str, "SparseType"], name: str
|
|
68
86
|
) -> torch.dtype:
|
|
69
87
|
if optimizer_state_dtypes is None or name not in optimizer_state_dtypes:
|
|
70
88
|
return torch.float32
|
|
71
89
|
return optimizer_state_dtypes[name].as_dtype()
|
|
72
90
|
|
|
73
|
-
def state_names(self) ->
|
|
91
|
+
def state_names(self) -> list[str]:
|
|
74
92
|
"""
|
|
75
93
|
Returns the names of the optimizer states. The order of the states will
|
|
76
94
|
be the order in which they are processed and returned in
|
|
@@ -79,12 +97,12 @@ class EmbOptimType(enum.Enum):
|
|
|
79
97
|
"""
|
|
80
98
|
if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
|
|
81
99
|
return ["momentum1"]
|
|
82
|
-
elif self
|
|
100
|
+
elif self in [EmbOptimType.PARTIAL_ROWWISE_ADAM, EmbOptimType.ADAM]:
|
|
83
101
|
return ["momentum1", "momentum2"]
|
|
84
102
|
else:
|
|
85
103
|
return []
|
|
86
104
|
|
|
87
|
-
def state_size_table(self, D: int) ->
|
|
105
|
+
def state_size_table(self, D: int) -> dict[str, int]:
|
|
88
106
|
"""
|
|
89
107
|
Returns the table of state names to state sizes in terms of number of
|
|
90
108
|
elements (per table row)
|
|
@@ -93,64 +111,84 @@ class EmbOptimType(enum.Enum):
|
|
|
93
111
|
return {"momentum1": 1}
|
|
94
112
|
elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
|
|
95
113
|
return {"momentum1": D, "momentum2": 1}
|
|
114
|
+
elif self == EmbOptimType.ADAM:
|
|
115
|
+
return {"momentum1": D, "momentum2": D}
|
|
96
116
|
else:
|
|
97
117
|
return {}
|
|
98
118
|
|
|
99
119
|
def state_size_nbytes(
|
|
100
|
-
self,
|
|
120
|
+
self,
|
|
121
|
+
D: int,
|
|
122
|
+
optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
|
|
101
123
|
) -> int:
|
|
102
124
|
"""
|
|
103
125
|
Returns the size of the data (in bytes) required to hold the optimizer
|
|
104
|
-
state (per table row)
|
|
126
|
+
state (per table row). This size includes byte-padding.
|
|
105
127
|
"""
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
128
|
+
momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
|
|
129
|
+
momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
|
|
130
|
+
|
|
131
|
+
if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
|
|
132
|
+
return momentum1_dtype.itemsize
|
|
133
|
+
|
|
134
|
+
elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
|
|
135
|
+
return pad4(1 * momentum2_dtype.itemsize) + D * momentum1_dtype.itemsize
|
|
136
|
+
|
|
137
|
+
elif self == EmbOptimType.ADAM:
|
|
138
|
+
return (D * momentum1_dtype.itemsize) + (D * momentum2_dtype.itemsize)
|
|
139
|
+
|
|
140
|
+
else:
|
|
141
|
+
return 0
|
|
114
142
|
|
|
115
143
|
def byte_offsets_along_row(
|
|
116
144
|
self,
|
|
117
145
|
D: int,
|
|
118
146
|
weights_precision: "SparseType",
|
|
119
|
-
optimizer_state_dtypes:
|
|
120
|
-
) ->
|
|
147
|
+
optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
|
|
148
|
+
) -> dict[str, tuple[int, int]]:
|
|
121
149
|
"""
|
|
122
150
|
Returns the start and end byte offsets of each optimizer state along a
|
|
123
151
|
cache row with optimizer state offloading enabled.
|
|
124
152
|
"""
|
|
153
|
+
# Extract the optimizer state dtypes
|
|
154
|
+
momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
|
|
155
|
+
momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
|
|
125
156
|
|
|
126
157
|
# This is the pointer to where the optimizer state begins in the memory
|
|
127
158
|
p0 = pad4(D) * weights_precision.as_dtype().itemsize
|
|
128
159
|
|
|
129
160
|
if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
|
|
130
|
-
momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
|
|
131
|
-
# Store one value for momentum per row
|
|
132
161
|
return {"momentum1": (p0, p0 + momentum1_dtype.itemsize)}
|
|
133
162
|
|
|
134
163
|
elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
|
|
135
|
-
|
|
136
|
-
|
|
164
|
+
# momentum1 lies after momentum2
|
|
165
|
+
p1 = p0 + pad4(1 * momentum2_dtype.itemsize)
|
|
137
166
|
return {
|
|
138
167
|
"momentum2": (p0, p0 + momentum2_dtype.itemsize),
|
|
139
168
|
"momentum1": (
|
|
140
|
-
|
|
141
|
-
|
|
169
|
+
p1,
|
|
170
|
+
p1 + D * momentum1_dtype.itemsize,
|
|
142
171
|
),
|
|
143
172
|
}
|
|
144
173
|
|
|
174
|
+
elif self == EmbOptimType.ADAM:
|
|
175
|
+
# momentum2 lies after momentum1
|
|
176
|
+
p1 = p0 + (D * momentum1_dtype.itemsize)
|
|
177
|
+
|
|
178
|
+
return {
|
|
179
|
+
"momentum1": (p0, p1),
|
|
180
|
+
"momentum2": (p1, p1 + D * momentum2_dtype.itemsize),
|
|
181
|
+
}
|
|
182
|
+
|
|
145
183
|
else:
|
|
146
184
|
return {}
|
|
147
185
|
|
|
148
186
|
def empty_states(
|
|
149
187
|
self,
|
|
150
|
-
rows:
|
|
151
|
-
dims:
|
|
152
|
-
optimizer_state_dtypes:
|
|
153
|
-
) ->
|
|
188
|
+
rows: list[int],
|
|
189
|
+
dims: list[int],
|
|
190
|
+
optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
|
|
191
|
+
) -> list[list[torch.Tensor]]:
|
|
154
192
|
"""
|
|
155
193
|
Creates sets of empty tensors per table to hold optimizer states based
|
|
156
194
|
on the specified optimizer type, state dtypes, embedding specs, and
|
|
@@ -159,7 +197,7 @@ class EmbOptimType(enum.Enum):
|
|
|
159
197
|
# Else, check that the local row count for each table is set
|
|
160
198
|
assert len(rows) == len(dims)
|
|
161
199
|
|
|
162
|
-
opt_states_set:
|
|
200
|
+
opt_states_set: list[list[torch.Tensor]] = []
|
|
163
201
|
|
|
164
202
|
for r, D in zip(rows, dims):
|
|
165
203
|
# Set up the table of state names to state sizes, ordered by their
|
|
@@ -186,20 +224,20 @@ class EmbOptimType(enum.Enum):
|
|
|
186
224
|
|
|
187
225
|
def ssd_state_splits(
|
|
188
226
|
self,
|
|
189
|
-
embedding_specs:
|
|
190
|
-
optimizer_state_dtypes:
|
|
227
|
+
embedding_specs: list[tuple[int, int]], # Tuple of (rows, dims)
|
|
228
|
+
optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
|
|
191
229
|
enable_optimizer_offloading: bool = False,
|
|
192
|
-
) ->
|
|
230
|
+
) -> list[tuple[SplitState, str, torch.dtype]]:
|
|
193
231
|
"""
|
|
194
232
|
Returns the split planning for the optimizer states
|
|
195
233
|
"""
|
|
196
|
-
|
|
234
|
+
rows, _ = zip(*embedding_specs)
|
|
197
235
|
T_ = len(embedding_specs)
|
|
198
236
|
|
|
199
237
|
# This is the cumulative row counts for rowwise states
|
|
200
|
-
row_count_cumsum:
|
|
238
|
+
row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
|
|
201
239
|
# This is the cumulative element counts for elementwise states
|
|
202
|
-
table_size_cumsum:
|
|
240
|
+
table_size_cumsum: list[int] = [0] + list(
|
|
203
241
|
itertools.accumulate([r * d for r, d in embedding_specs])
|
|
204
242
|
)
|
|
205
243
|
|
|
@@ -207,6 +245,12 @@ class EmbOptimType(enum.Enum):
|
|
|
207
245
|
params = {"momentum1": row_count_cumsum}
|
|
208
246
|
elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
|
|
209
247
|
params = {"momentum1": table_size_cumsum, "momentum2": row_count_cumsum}
|
|
248
|
+
elif self == EmbOptimType.ADAM:
|
|
249
|
+
params = {
|
|
250
|
+
"momentum1": table_size_cumsum,
|
|
251
|
+
"momentum2": table_size_cumsum,
|
|
252
|
+
"row_counter": row_count_cumsum,
|
|
253
|
+
}
|
|
210
254
|
else:
|
|
211
255
|
params = {}
|
|
212
256
|
|
|
@@ -266,14 +310,54 @@ def sparse_type_to_int(sparse_type: "SparseType") -> int:
|
|
|
266
310
|
SparseType.BF16.value: 5,
|
|
267
311
|
SparseType.FP8.value: 6,
|
|
268
312
|
SparseType.MX4.value: 7,
|
|
313
|
+
SparseType.NFP8.value: 8,
|
|
269
314
|
}[sparse_type.value]
|
|
270
315
|
|
|
271
316
|
|
|
317
|
+
def sparse_type_int_to_dtype(ty: int) -> torch.dtype:
|
|
318
|
+
"""
|
|
319
|
+
TorchScript-compatible function to convert an SparseType enum as integer) to torch.dtype.
|
|
320
|
+
|
|
321
|
+
This is a standalone function equivalent to SparseType.from_int(dtype_int).as_dtype() that works
|
|
322
|
+
with TorchScript. TorchScript does not support @staticmethod on Enum classes,
|
|
323
|
+
so this function provides a workaround.
|
|
324
|
+
"""
|
|
325
|
+
if ty == 0: # fp32
|
|
326
|
+
return torch.float32
|
|
327
|
+
elif ty == 1: # fp16
|
|
328
|
+
return torch.float16
|
|
329
|
+
elif ty == 2: # int8
|
|
330
|
+
return torch.uint8
|
|
331
|
+
elif ty == 3: # int4
|
|
332
|
+
return torch.quint4x2
|
|
333
|
+
elif ty == 4: # int2
|
|
334
|
+
return torch.quint2x4
|
|
335
|
+
elif ty == 5: # bf16
|
|
336
|
+
return torch.bfloat16
|
|
337
|
+
elif ty == 6: # fp8
|
|
338
|
+
return torch.uint8
|
|
339
|
+
elif ty == 7: # mx4
|
|
340
|
+
return torch.uint8
|
|
341
|
+
elif ty == 9:
|
|
342
|
+
return (
|
|
343
|
+
torch.float8_e4m3fnuz
|
|
344
|
+
if torch.version.hip is not None
|
|
345
|
+
else torch.float8_e4m3fn
|
|
346
|
+
)
|
|
347
|
+
else: # Invalid is 7 or non enumerated.
|
|
348
|
+
raise ValueError(f"Unsupported sparse type: {ty}")
|
|
349
|
+
|
|
350
|
+
|
|
272
351
|
@enum.unique
|
|
273
352
|
class SparseType(enum.Enum):
|
|
274
353
|
FP32 = "fp32"
|
|
275
354
|
FP16 = "fp16"
|
|
276
355
|
FP8 = "fp8"
|
|
356
|
+
# NFP8 refers to "native" FP8 in that it uses the GPU implementations
|
|
357
|
+
# of E4M3 whereas the other FP8 sparsetype uses a custom format. Use of
|
|
358
|
+
# NFP8 allows us to use hardware casting intrinsics which can be much faster.
|
|
359
|
+
# Eventually, we should merge these two types.
|
|
360
|
+
NFP8 = "nfp8"
|
|
277
361
|
INT8 = "int8"
|
|
278
362
|
INT4 = "int4"
|
|
279
363
|
INT2 = "int2"
|
|
@@ -299,9 +383,11 @@ class SparseType(enum.Enum):
|
|
|
299
383
|
return SparseType("bf16")
|
|
300
384
|
elif ty == 6:
|
|
301
385
|
return SparseType("fp8")
|
|
302
|
-
elif ty ==
|
|
386
|
+
elif ty == 8:
|
|
303
387
|
return SparseType("mx4")
|
|
304
|
-
|
|
388
|
+
elif ty == 9:
|
|
389
|
+
return SparseType("nfp8")
|
|
390
|
+
else: # Invalid is 7 or non enumerated.
|
|
305
391
|
raise ValueError(f"Unsupported sparse type: {ty}")
|
|
306
392
|
|
|
307
393
|
def as_int(self) -> int:
|
|
@@ -323,6 +409,8 @@ class SparseType(enum.Enum):
|
|
|
323
409
|
return SparseType("bf16")
|
|
324
410
|
elif dtype == torch.uint8:
|
|
325
411
|
return SparseType("mx4")
|
|
412
|
+
elif dtype == torch.float8_e4m3fnuz or dtype == torch.float8_e4m3fn:
|
|
413
|
+
return SparseType("nfp8")
|
|
326
414
|
else:
|
|
327
415
|
raise ValueError(f"Unsupported sparse dtype: {dtype}")
|
|
328
416
|
|
|
@@ -336,6 +424,11 @@ class SparseType(enum.Enum):
|
|
|
336
424
|
SparseType.INT2.value: torch.quint2x4,
|
|
337
425
|
SparseType.BF16.value: torch.bfloat16,
|
|
338
426
|
SparseType.MX4.value: torch.uint8,
|
|
427
|
+
SparseType.NFP8.value: (
|
|
428
|
+
torch.float8_e4m3fnuz
|
|
429
|
+
if torch.version.hip is not None
|
|
430
|
+
else torch.float8_e4m3fn
|
|
431
|
+
),
|
|
339
432
|
}[self.value]
|
|
340
433
|
|
|
341
434
|
def bit_rate(self) -> int:
|
|
@@ -348,6 +441,7 @@ class SparseType(enum.Enum):
|
|
|
348
441
|
SparseType.INT2.value: 2,
|
|
349
442
|
SparseType.BF16.value: 16,
|
|
350
443
|
SparseType.MX4.value: 4,
|
|
444
|
+
SparseType.NFP8.value: 8,
|
|
351
445
|
}[self.value]
|
|
352
446
|
|
|
353
447
|
def align_size(self) -> int:
|
|
@@ -360,6 +454,7 @@ class SparseType(enum.Enum):
|
|
|
360
454
|
SparseType.INT2.value: 16,
|
|
361
455
|
SparseType.BF16.value: 2,
|
|
362
456
|
SparseType.MX4.value: 8,
|
|
457
|
+
SparseType.NFP8.value: 4,
|
|
363
458
|
}[self.value]
|
|
364
459
|
|
|
365
460
|
def is_float(self) -> bool:
|
|
@@ -368,6 +463,7 @@ class SparseType(enum.Enum):
|
|
|
368
463
|
or self.value == SparseType.FP16.value
|
|
369
464
|
or self.value == SparseType.FP8.value
|
|
370
465
|
or self.value == SparseType.BF16.value
|
|
466
|
+
or self.value == SparseType.NFP8.value
|
|
371
467
|
):
|
|
372
468
|
return True
|
|
373
469
|
else:
|
|
@@ -380,11 +476,12 @@ class SparseType(enum.Enum):
|
|
|
380
476
|
return QuantizationConfig()
|
|
381
477
|
|
|
382
478
|
|
|
383
|
-
ELEMENT_SIZE:
|
|
479
|
+
ELEMENT_SIZE: dict[SparseType, int] = {
|
|
384
480
|
SparseType.FP32: 4,
|
|
385
481
|
SparseType.FP16: 2,
|
|
386
482
|
SparseType.FP8: 1,
|
|
387
483
|
SparseType.INT8: 1,
|
|
388
484
|
SparseType.BF16: 2,
|
|
485
|
+
SparseType.NFP8: 1,
|
|
389
486
|
# SparseType.INT4: 0.5,
|
|
390
487
|
}
|
|
@@ -10,10 +10,11 @@
|
|
|
10
10
|
|
|
11
11
|
import logging
|
|
12
12
|
import math
|
|
13
|
-
from typing import cast, Optional
|
|
13
|
+
from typing import cast, Optional
|
|
14
14
|
|
|
15
15
|
import torch
|
|
16
16
|
|
|
17
|
+
# fmt:skip
|
|
17
18
|
from fbgemm_gpu.split_embedding_configs import (
|
|
18
19
|
FP8QuantizationConfig,
|
|
19
20
|
QuantizationConfig,
|
|
@@ -53,7 +54,7 @@ class SplitEmbInferenceConverter:
|
|
|
53
54
|
return model
|
|
54
55
|
|
|
55
56
|
# pyre-fixme[2]: Parameter must be annotated.
|
|
56
|
-
def _prune_by_weights_l2_norm(self, new_num_rows, weights) ->
|
|
57
|
+
def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> tuple[Tensor, float]:
|
|
57
58
|
assert new_num_rows > 0
|
|
58
59
|
from numpy.linalg import norm
|
|
59
60
|
|
|
@@ -75,7 +76,7 @@ class SplitEmbInferenceConverter:
|
|
|
75
76
|
idx: int,
|
|
76
77
|
num_rows: int,
|
|
77
78
|
module: SplitTableBatchedEmbeddingBagsCodegen,
|
|
78
|
-
) ->
|
|
79
|
+
) -> tuple[Tensor, Optional[Tensor]]:
|
|
79
80
|
# TODO(yingz): Avoid DtoH / HtoD overhead.
|
|
80
81
|
weights = module.split_embedding_weights()[idx].cpu()
|
|
81
82
|
if self.pruning_ratio is None:
|
|
@@ -84,7 +85,7 @@ class SplitEmbInferenceConverter:
|
|
|
84
85
|
if new_num_rows == num_rows:
|
|
85
86
|
return (weights, None)
|
|
86
87
|
|
|
87
|
-
|
|
88
|
+
indicators, threshold = self._prune_by_weights_l2_norm(new_num_rows, weights)
|
|
88
89
|
|
|
89
90
|
return torch.ops.fbgemm.embedding_bag_rowwise_prune(
|
|
90
91
|
weights, indicators, threshold, torch.int32
|
|
@@ -100,7 +101,7 @@ class SplitEmbInferenceConverter:
|
|
|
100
101
|
|
|
101
102
|
def _quantize_embs(
|
|
102
103
|
self, weight: Tensor, weight_ty: SparseType
|
|
103
|
-
) ->
|
|
104
|
+
) -> tuple[Tensor, Optional[Tensor]]:
|
|
104
105
|
fp8_quant_config = cast(FP8QuantizationConfig, self.quantization_config)
|
|
105
106
|
return quantize_embs(weight, weight_ty, fp8_quant_config)
|
|
106
107
|
|
|
@@ -129,7 +130,7 @@ class SplitEmbInferenceConverter:
|
|
|
129
130
|
index_remapping_list = []
|
|
130
131
|
for t, (_, E, D, weight_ty) in enumerate(embedding_specs):
|
|
131
132
|
# Try to prune embeddings.
|
|
132
|
-
|
|
133
|
+
pruned_weight, index_remapping = self._prune_embs(t, E, child)
|
|
133
134
|
new_embedding_specs.append(
|
|
134
135
|
(
|
|
135
136
|
"",
|
|
@@ -11,12 +11,11 @@
|
|
|
11
11
|
|
|
12
12
|
import enum
|
|
13
13
|
from dataclasses import dataclass
|
|
14
|
-
from typing import
|
|
14
|
+
from typing import FrozenSet, NamedTuple, Optional, Tuple
|
|
15
15
|
|
|
16
16
|
import torch
|
|
17
17
|
from torch import Tensor
|
|
18
18
|
|
|
19
|
-
|
|
20
19
|
# Maximum number of times prefetch() can be called without
|
|
21
20
|
# a corresponding forward() call
|
|
22
21
|
MAX_PREFETCH_DEPTH = 100
|
|
@@ -62,10 +61,10 @@ class EmbeddingLocation(enum.IntEnum):
|
|
|
62
61
|
|
|
63
62
|
class EvictionPolicy(NamedTuple):
|
|
64
63
|
eviction_trigger_mode: int = (
|
|
65
|
-
0 # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
|
|
64
|
+
0 # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual 4: id count
|
|
66
65
|
)
|
|
67
66
|
eviction_strategy: int = (
|
|
68
|
-
0 # 0: timestamp, 1: counter
|
|
67
|
+
0 # 0: timestamp, 1: counter , 2: counter + timestamp, 3: feature l2 norm 4: timestamp threshold 5: feature score
|
|
69
68
|
)
|
|
70
69
|
eviction_step_intervals: Optional[int] = (
|
|
71
70
|
None # trigger_step_interval if trigger mode is iteration
|
|
@@ -73,18 +72,33 @@ class EvictionPolicy(NamedTuple):
|
|
|
73
72
|
eviction_mem_threshold_gb: Optional[int] = (
|
|
74
73
|
None # eviction trigger condition if trigger mode is mem_util
|
|
75
74
|
)
|
|
76
|
-
counter_thresholds: Optional[
|
|
77
|
-
None # count_thresholds for each table if eviction strategy is
|
|
75
|
+
counter_thresholds: Optional[list[int]] = (
|
|
76
|
+
None # count_thresholds for each table if eviction strategy is counter
|
|
78
77
|
)
|
|
79
|
-
ttls_in_mins: Optional[
|
|
78
|
+
ttls_in_mins: Optional[list[int]] = (
|
|
80
79
|
None # ttls_in_mins for each table if eviction strategy is timestamp
|
|
81
80
|
)
|
|
82
|
-
counter_decay_rates: Optional[
|
|
83
|
-
None # count_decay_rates for each table if eviction strategy is
|
|
81
|
+
counter_decay_rates: Optional[list[float]] = (
|
|
82
|
+
None # count_decay_rates for each table if eviction strategy is counter
|
|
83
|
+
)
|
|
84
|
+
feature_score_counter_decay_rates: Optional[list[float]] = (
|
|
85
|
+
None # feature_score_counter_decay_rates for each table if eviction strategy is feature score
|
|
86
|
+
)
|
|
87
|
+
training_id_eviction_trigger_count: Optional[list[int]] = (
|
|
88
|
+
None # Number of training IDs that, when exceeded, will trigger eviction for each table.
|
|
84
89
|
)
|
|
85
|
-
|
|
90
|
+
training_id_keep_count: Optional[list[int]] = (
|
|
91
|
+
None # Target number of training IDs to retain in each table after eviction.
|
|
92
|
+
)
|
|
93
|
+
l2_weight_thresholds: Optional[list[float]] = (
|
|
86
94
|
None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
|
|
87
95
|
)
|
|
96
|
+
threshold_calculation_bucket_stride: Optional[float] = (
|
|
97
|
+
0.2 # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
|
|
98
|
+
)
|
|
99
|
+
threshold_calculation_bucket_num: Optional[int] = (
|
|
100
|
+
1000000 # 1M, Total number of feature score buckets used for threshold calculation in feature score-based eviction.
|
|
101
|
+
)
|
|
88
102
|
interval_for_insufficient_eviction_s: int = (
|
|
89
103
|
# wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
|
|
90
104
|
# insufficient means we didn't evict enough rows, so we want to wait longer time to
|
|
@@ -95,18 +109,30 @@ class EvictionPolicy(NamedTuple):
|
|
|
95
109
|
# wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
|
|
96
110
|
60
|
|
97
111
|
)
|
|
98
|
-
|
|
112
|
+
interval_for_feature_statistics_decay_s: int = (
|
|
113
|
+
24 * 3600 # 1 day, interval for feature statistics decay
|
|
114
|
+
)
|
|
115
|
+
meta_header_lens: Optional[list[int]] = None # metaheader length for each table
|
|
116
|
+
eviction_free_mem_threshold_gb: Optional[int] = (
|
|
117
|
+
None # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
|
|
118
|
+
)
|
|
119
|
+
eviction_free_mem_check_interval_batch: Optional[int] = (
|
|
120
|
+
None # Number of batches between checks for free memory threshold when using free_mem trigger mode.
|
|
121
|
+
)
|
|
122
|
+
enable_eviction_for_feature_score_eviction_policy: Optional[list[bool]] = (
|
|
123
|
+
None # enable eviction if eviction policy is feature score, false means no eviction
|
|
124
|
+
)
|
|
99
125
|
|
|
100
126
|
def validate(self) -> None:
|
|
101
|
-
assert self.eviction_trigger_mode in [0, 1, 2, 3], (
|
|
102
|
-
"eviction_trigger_mode must be 0, 1, 2,
|
|
127
|
+
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
|
|
128
|
+
"eviction_trigger_mode must be 0, 1, 2, 3, 4, 5"
|
|
103
129
|
f"actual {self.eviction_trigger_mode}"
|
|
104
130
|
)
|
|
105
131
|
if self.eviction_trigger_mode == 0:
|
|
106
132
|
return
|
|
107
133
|
|
|
108
|
-
assert self.eviction_strategy in [0, 1, 2, 3], (
|
|
109
|
-
"eviction_strategy must be 0, 1, 2, or
|
|
134
|
+
assert self.eviction_strategy in [0, 1, 2, 3, 4, 5], (
|
|
135
|
+
"eviction_strategy must be 0, 1, 2, 3, 4 or 5, "
|
|
110
136
|
f"actual {self.eviction_strategy}"
|
|
111
137
|
)
|
|
112
138
|
if self.eviction_trigger_mode == 1:
|
|
@@ -121,6 +147,17 @@ class EvictionPolicy(NamedTuple):
|
|
|
121
147
|
assert (
|
|
122
148
|
self.eviction_mem_threshold_gb is not None
|
|
123
149
|
), "eviction_mem_threshold_gb must be set if eviction_trigger_mode is 2"
|
|
150
|
+
elif self.eviction_trigger_mode == 4:
|
|
151
|
+
assert (
|
|
152
|
+
self.training_id_eviction_trigger_count is not None
|
|
153
|
+
), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4"
|
|
154
|
+
elif self.eviction_trigger_mode == 5:
|
|
155
|
+
assert (
|
|
156
|
+
self.eviction_free_mem_threshold_gb is not None
|
|
157
|
+
), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5"
|
|
158
|
+
assert (
|
|
159
|
+
self.eviction_free_mem_check_interval_batch is not None
|
|
160
|
+
), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5"
|
|
124
161
|
|
|
125
162
|
if self.eviction_strategy == 0:
|
|
126
163
|
assert self.ttls_in_mins is not None, (
|
|
@@ -161,21 +198,58 @@ class EvictionPolicy(NamedTuple):
|
|
|
161
198
|
"counter_thresholds and ttls_in_mins must have the same length, "
|
|
162
199
|
f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
|
|
163
200
|
)
|
|
201
|
+
elif self.eviction_strategy == 5:
|
|
202
|
+
assert self.feature_score_counter_decay_rates is not None, (
|
|
203
|
+
"feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
|
|
204
|
+
f"actual {self.feature_score_counter_decay_rates}"
|
|
205
|
+
)
|
|
206
|
+
assert self.training_id_eviction_trigger_count is not None, (
|
|
207
|
+
"training_id_eviction_trigger_count must be set if eviction_strategy is 5,"
|
|
208
|
+
f"actual {self.training_id_eviction_trigger_count}"
|
|
209
|
+
)
|
|
210
|
+
assert self.training_id_keep_count is not None, (
|
|
211
|
+
"training_id_keep_count must be set if eviction_strategy is 5,"
|
|
212
|
+
f"actual {self.training_id_keep_count}"
|
|
213
|
+
)
|
|
214
|
+
assert self.threshold_calculation_bucket_stride is not None, (
|
|
215
|
+
"threshold_calculation_bucket_stride must be set if eviction_strategy is 5,"
|
|
216
|
+
f"actual {self.threshold_calculation_bucket_stride}"
|
|
217
|
+
)
|
|
218
|
+
assert self.threshold_calculation_bucket_num is not None, (
|
|
219
|
+
"threshold_calculation_bucket_num must be set if eviction_strategy is 5,"
|
|
220
|
+
f"actual {self.threshold_calculation_bucket_num}"
|
|
221
|
+
)
|
|
222
|
+
assert self.enable_eviction_for_feature_score_eviction_policy is not None, (
|
|
223
|
+
"enable_eviction_for_feature_score_eviction_policy must be set if eviction_strategy is 5,"
|
|
224
|
+
f"actual {self.enable_eviction_for_feature_score_eviction_policy}"
|
|
225
|
+
)
|
|
226
|
+
assert (
|
|
227
|
+
len(self.enable_eviction_for_feature_score_eviction_policy)
|
|
228
|
+
== len(self.training_id_keep_count)
|
|
229
|
+
== len(self.feature_score_counter_decay_rates)
|
|
230
|
+
), (
|
|
231
|
+
"feature_score_thresholds, enable_eviction_for_feature_score_eviction_policy, and training_id_keep_count must have the same length, "
|
|
232
|
+
f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.enable_eviction_for_feature_score_eviction_policy}"
|
|
233
|
+
)
|
|
164
234
|
|
|
165
235
|
|
|
166
236
|
class KVZCHParams(NamedTuple):
|
|
167
237
|
# global bucket id start and global bucket id end offsets for each logical table,
|
|
168
238
|
# where start offset is inclusive and end offset is exclusive
|
|
169
|
-
bucket_offsets:
|
|
239
|
+
bucket_offsets: list[tuple[int, int]] = []
|
|
170
240
|
# bucket size for each logical table
|
|
171
241
|
# the value indicates corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets
|
|
172
|
-
bucket_sizes:
|
|
242
|
+
bucket_sizes: list[int] = []
|
|
173
243
|
# enable optimizer offloading or not
|
|
174
244
|
enable_optimizer_offloading: bool = False
|
|
175
245
|
# when enabled, backend will return whole row(metaheader + weight + optimizer) instead of weight only
|
|
176
246
|
# can only be enabled when enable_optimizer_offloading is enabled
|
|
177
247
|
backend_return_whole_row: bool = False
|
|
178
248
|
eviction_policy: EvictionPolicy = EvictionPolicy()
|
|
249
|
+
embedding_cache_mode: bool = False
|
|
250
|
+
load_ckpt_without_opt: bool = False
|
|
251
|
+
optimizer_type_for_st: Optional[str] = None
|
|
252
|
+
optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
|
|
179
253
|
|
|
180
254
|
def validate(self) -> None:
|
|
181
255
|
assert len(self.bucket_offsets) == len(self.bucket_sizes), (
|
|
@@ -188,6 +262,25 @@ class KVZCHParams(NamedTuple):
|
|
|
188
262
|
), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"
|
|
189
263
|
|
|
190
264
|
|
|
265
|
+
class KVZCHTBEConfig(NamedTuple):
|
|
266
|
+
# Eviction trigger model for kvzch table: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
|
|
267
|
+
kvzch_eviction_trigger_mode: int = 2 # mem_util
|
|
268
|
+
# Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
|
|
269
|
+
eviction_free_mem_threshold_gb: int = 200 # 200GB
|
|
270
|
+
# Number of batches between checks for free memory threshold when using free_mem trigger mode.
|
|
271
|
+
eviction_free_mem_check_interval_batch: int = 1000
|
|
272
|
+
# The width of each feature score bucket used for threshold calculation in feature score-based eviction.
|
|
273
|
+
threshold_calculation_bucket_stride: float = 0.2
|
|
274
|
+
# Total number of feature score buckets used for threshold calculation in feature score-based eviction.
|
|
275
|
+
threshold_calculation_bucket_num: Optional[int] = 1000000 # 1M
|
|
276
|
+
# When true, we only save weight to kvzch backend and not optimizer state.
|
|
277
|
+
load_ckpt_without_opt: bool = False
|
|
278
|
+
# [DO NOT USE] This is for st publish only, do not set it in your config
|
|
279
|
+
optimizer_type_for_st: Optional[str] = None
|
|
280
|
+
# [DO NOT USE] This is for st publish only, do not set it in your config
|
|
281
|
+
optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
|
|
282
|
+
|
|
283
|
+
|
|
191
284
|
class BackendType(enum.IntEnum):
|
|
192
285
|
SSD = 0
|
|
193
286
|
DRAM = 1
|
|
@@ -288,8 +381,8 @@ SplitState: NamedTuple = NamedTuple(
|
|
|
288
381
|
("dev_size", int),
|
|
289
382
|
("host_size", int),
|
|
290
383
|
("uvm_size", int),
|
|
291
|
-
("placements",
|
|
292
|
-
("offsets",
|
|
384
|
+
("placements", list[EmbeddingLocation]),
|
|
385
|
+
("offsets", list[int]),
|
|
293
386
|
],
|
|
294
387
|
)
|
|
295
388
|
|
|
@@ -297,15 +390,15 @@ SplitState: NamedTuple = NamedTuple(
|
|
|
297
390
|
@dataclass
|
|
298
391
|
class CacheState:
|
|
299
392
|
# T + 1 elements and cache_hash_size_cumsum[-1] == total_cache_hash_size
|
|
300
|
-
cache_hash_size_cumsum:
|
|
301
|
-
cache_index_table_map:
|
|
393
|
+
cache_hash_size_cumsum: list[int]
|
|
394
|
+
cache_index_table_map: list[int]
|
|
302
395
|
total_cache_hash_size: int
|
|
303
396
|
|
|
304
397
|
|
|
305
398
|
def construct_cache_state(
|
|
306
|
-
row_list:
|
|
307
|
-
location_list:
|
|
308
|
-
feature_table_map:
|
|
399
|
+
row_list: list[int],
|
|
400
|
+
location_list: list[EmbeddingLocation],
|
|
401
|
+
feature_table_map: list[int],
|
|
309
402
|
) -> CacheState:
|
|
310
403
|
_cache_hash_size_cumsum = [0]
|
|
311
404
|
total_cache_hash_size = 0
|