fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
--- /dev/null
+++ b/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+# pyre-ignore-all-errors[56]
+# flake8: noqa F401
+
+import torch  # usort:skip
+import warnings
+
+# This module is a compatibility wrapper that re-exports the symbols from:
+#   fbgemm_gpu.split_table_batched_embeddings_ops_common
+#   fbgemm_gpu.split_table_batched_embeddings_ops_inference
+#   fbgemm_gpu.split_table_batched_embeddings_ops_training
+
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    BoundsCheckMode,
+    CacheAlgorithm,
+    CacheState,
+    DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
+    EmbeddingLocation,
+    PoolingMode,
+    RecordCacheMetrics,
+    round_up,
+    SplitState,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
+    align_to_cacheline,
+    IntNBitTableBatchedEmbeddingBagsCodegen,
+    rounded_row_size_in_bytes,
+    unpadded_row_size_in_bytes,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    ComputeDevice,
+    CounterBasedRegularizationDefinition,
+    CounterWeightDecayMode,
+    DEFAULT_ASSOC,
+    DenseTableBatchedEmbeddingBagsCodegen,
+    GradSumDecay,
+    INT8_EMB_ROW_DIM_OFFSET,
+    LearningRateMode,
+    SplitTableBatchedEmbeddingBagsCodegen,
+    TailIdThreshold,
+    WeightDecayMode,
+)
+
+try:
+    if torch.version.hip:
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip"
+        )
+    else:
+        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops")
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu")
+except Exception:
+    pass
+
+warnings.warn(
+    f"""\033[93m
+    The Python module {__name__} is now DEPRECATED and will be removed in the
+    future. Users should instead declare dependencies on
+    //deeplearning/fbgemm/fbgemm_gpu/split_table_batched_embeddings_ops_{{training, inference}}
+    in their TARGETS file and import the
+    fbgemm_gpu.split_table_batched_embeddings_ops_{{training, inference}}
+    modules as needed in their scripts.
+    \033[0m""",
+    DeprecationWarning,
+)
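The wrapper above does no work of its own: it re-exports symbols from the three successor modules and emits a DeprecationWarning on import. As a minimal sketch of the migration that warning asks for (using only module and symbol names that appear in the diff above):

# Before: pulling symbols through the deprecated wrapper emits the
# DeprecationWarning shown above at import time.
#   from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode
#
# After: import each symbol from the module that now owns it.
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    EmbeddingLocation,
    PoolingMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    SplitTableBatchedEmbeddingBagsCodegen,
)

The next hunk adds fbgemm_gpu/split_table_batched_embeddings_ops_common.py, the module that now owns most of these types.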
--- /dev/null
+++ b/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
@@ -0,0 +1,484 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+# pyre-ignore-all-errors[56]
+
+import enum
+from dataclasses import dataclass
+from typing import FrozenSet, NamedTuple, Optional, Tuple
+
+import torch
+from torch import Tensor
+
+
+# Maximum number of times prefetch() can be called without
+# a corresponding forward() call
+MAX_PREFETCH_DEPTH = 100
+
+# GPU and CPU use 16-bit scale and bias for quantized embedding bags in TBE
+# The total size is 2 + 2 = 4 bytes
+DEFAULT_SCALE_BIAS_SIZE_IN_BYTES = 4
+
+
+class EmbeddingLocation(enum.IntEnum):
+    DEVICE = 0
+    MANAGED = 1
+    MANAGED_CACHING = 2
+    HOST = 3
+    MTIA = 4
+
+    @classmethod
+    # pyre-ignore[3]
+    def str_values(cls):
+        return [
+            "device",
+            "managed",
+            "managed_caching",
+            "host",
+            "mtia",
+        ]
+
+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "device": EmbeddingLocation.DEVICE,
+            "managed": EmbeddingLocation.MANAGED,
+            "managed_caching": EmbeddingLocation.MANAGED_CACHING,
+            "host": EmbeddingLocation.HOST,
+            "mtia": EmbeddingLocation.MTIA,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
+
+
+class EvictionPolicy(NamedTuple):
+    eviction_trigger_mode: int = (
+        0  # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual 4: id count
+    )
+    eviction_strategy: int = (
+        0  # 0: timestamp, 1: counter , 2: counter + timestamp, 3: feature l2 norm 4: timestamp threshold 5: feature score
+    )
+    eviction_step_intervals: Optional[int] = (
+        None  # trigger_step_interval if trigger mode is iteration
+    )
+    eviction_mem_threshold_gb: Optional[int] = (
+        None  # eviction trigger condition if trigger mode is mem_util
+    )
+    counter_thresholds: Optional[list[int]] = (
+        None  # count_thresholds for each table if eviction strategy is counter
+    )
+    ttls_in_mins: Optional[list[int]] = (
+        None  # ttls_in_mins for each table if eviction strategy is timestamp
+    )
+    counter_decay_rates: Optional[list[float]] = (
+        None  # count_decay_rates for each table if eviction strategy is counter
+    )
+    feature_score_counter_decay_rates: Optional[list[float]] = (
+        None  # feature_score_counter_decay_rates for each table if eviction strategy is feature score
+    )
+    training_id_eviction_trigger_count: Optional[list[int]] = (
+        None  # Number of training IDs that, when exceeded, will trigger eviction for each table.
+    )
+    training_id_keep_count: Optional[list[int]] = (
+        None  # Target number of training IDs to retain in each table after eviction.
+    )
+    l2_weight_thresholds: Optional[list[float]] = (
+        None  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
+    )
+    threshold_calculation_bucket_stride: Optional[float] = (
+        0.2  # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
+    )
+    threshold_calculation_bucket_num: Optional[int] = (
+        1000000  # 1M, Total number of feature score buckets used for threshold calculation in feature score-based eviction.
+    )
+    interval_for_insufficient_eviction_s: int = (
+        # wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
+        # insufficient means we didn't evict enough rows, so we want to wait longer time to
+        # avoid another insufficient eviction
+        600
+    )
+    interval_for_sufficient_eviction_s: int = (
+        # wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
+        60
+    )
+    interval_for_feature_statistics_decay_s: int = (
+        24 * 3600  # 1 day, interval for feature statistics decay
+    )
+    meta_header_lens: Optional[list[int]] = None  # metaheader length for each table
+    eviction_free_mem_threshold_gb: Optional[int] = (
+        None  # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
+    )
+    eviction_free_mem_check_interval_batch: Optional[int] = (
+        None  # Number of batches between checks for free memory threshold when using free_mem trigger mode.
+    )
+    enable_eviction_for_feature_score_eviction_policy: Optional[list[bool]] = (
+        None  # enable eviction if eviction policy is feature score, false means no eviction
+    )
+
+    def validate(self) -> None:
+        assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
+            "eviction_trigger_mode must be 0, 1, 2, 3, 4, 5"
+            f"actual {self.eviction_trigger_mode}"
+        )
+        if self.eviction_trigger_mode == 0:
+            return
+
+        assert self.eviction_strategy in [0, 1, 2, 3, 4, 5], (
+            "eviction_strategy must be 0, 1, 2, 3, 4 or 5, "
+            f"actual {self.eviction_strategy}"
+        )
+        if self.eviction_trigger_mode == 1:
+            assert (
+                self.eviction_step_intervals is not None
+                and self.eviction_step_intervals > 0
+            ), (
+                "eviction_step_intervals must be positive if eviction_trigger_mode is 1, "
+                f"actual {self.eviction_step_intervals}"
+            )
+        elif self.eviction_trigger_mode == 2:
+            assert (
+                self.eviction_mem_threshold_gb is not None
+            ), "eviction_mem_threshold_gb must be set if eviction_trigger_mode is 2"
+        elif self.eviction_trigger_mode == 4:
+            assert (
+                self.training_id_eviction_trigger_count is not None
+            ), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4"
+        elif self.eviction_trigger_mode == 5:
+            assert (
+                self.eviction_free_mem_threshold_gb is not None
+            ), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5"
+            assert (
+                self.eviction_free_mem_check_interval_batch is not None
+            ), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5"
+
+        if self.eviction_strategy == 0:
+            assert self.ttls_in_mins is not None, (
+                "ttls_in_mins must be set if eviction_strategy is 0, "
+                f"actual {self.ttls_in_mins}"
+            )
+        elif self.eviction_strategy == 1:
+            assert self.counter_thresholds is not None, (
+                "counter_thresholds must be set if eviction_strategy is 1, "
+                f"actual {self.counter_thresholds}"
+            )
+            assert self.counter_decay_rates is not None, (
+                "counter_decay_rates must be set if eviction_strategy is 1, "
+                f"actual {self.counter_decay_rates}"
+            )
+            assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
+                "counter_thresholds and counter_decay_rates must have the same length, "
+                f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
+            )
+        elif self.eviction_strategy == 2:
+            assert self.counter_thresholds is not None, (
+                "counter_thresholds must be set if eviction_strategy is 2, "
+                f"actual {self.counter_thresholds}"
+            )
+            assert self.counter_decay_rates is not None, (
+                "counter_decay_rates must be set if eviction_strategy is 2, "
+                f"actual {self.counter_decay_rates}"
+            )
+            assert self.ttls_in_mins is not None, (
+                "ttls_in_mins must be set if eviction_strategy is 2, "
+                f"actual {self.ttls_in_mins}"
+            )
+            assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
+                "counter_thresholds and counter_decay_rates must have the same length, "
+                f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
+            )
+            assert len(self.counter_thresholds) == len(self.ttls_in_mins), (
+                "counter_thresholds and ttls_in_mins must have the same length, "
+                f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
+            )
+        elif self.eviction_strategy == 5:
+            assert self.feature_score_counter_decay_rates is not None, (
+                "feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
+                f"actual {self.feature_score_counter_decay_rates}"
+            )
+            assert self.training_id_eviction_trigger_count is not None, (
+                "training_id_eviction_trigger_count must be set if eviction_strategy is 5,"
+                f"actual {self.training_id_eviction_trigger_count}"
+            )
+            assert self.training_id_keep_count is not None, (
+                "training_id_keep_count must be set if eviction_strategy is 5,"
+                f"actual {self.training_id_keep_count}"
+            )
+            assert self.threshold_calculation_bucket_stride is not None, (
+                "threshold_calculation_bucket_stride must be set if eviction_strategy is 5,"
+                f"actual {self.threshold_calculation_bucket_stride}"
+            )
+            assert self.threshold_calculation_bucket_num is not None, (
+                "threshold_calculation_bucket_num must be set if eviction_strategy is 5,"
+                f"actual {self.threshold_calculation_bucket_num}"
+            )
+            assert self.enable_eviction_for_feature_score_eviction_policy is not None, (
+                "enable_eviction_for_feature_score_eviction_policy must be set if eviction_strategy is 5,"
+                f"actual {self.enable_eviction_for_feature_score_eviction_policy}"
+            )
+            assert (
+                len(self.enable_eviction_for_feature_score_eviction_policy)
+                == len(self.training_id_keep_count)
+                == len(self.feature_score_counter_decay_rates)
+            ), (
+                "feature_score_thresholds, enable_eviction_for_feature_score_eviction_policy, and training_id_keep_count must have the same length, "
+                f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.enable_eviction_for_feature_score_eviction_policy}"
+            )
+
+
+class KVZCHParams(NamedTuple):
+    # global bucket id start and global bucket id end offsets for each logical table,
+    # where start offset is inclusive and end offset is exclusive
+    bucket_offsets: list[tuple[int, int]] = []
+    # bucket size for each logical table
+    # the value indicates corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets
+    bucket_sizes: list[int] = []
+    # enable optimizer offloading or not
+    enable_optimizer_offloading: bool = False
+    # when enabled, backend will return whole row(metaheader + weight + optimizer) instead of weight only
+    # can only be enabled when enable_optimizer_offloading is enabled
+    backend_return_whole_row: bool = False
+    eviction_policy: EvictionPolicy = EvictionPolicy()
+    embedding_cache_mode: bool = False
+    load_ckpt_without_opt: bool = False
+    optimizer_type_for_st: Optional[str] = None
+    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
+
+    def validate(self) -> None:
+        assert len(self.bucket_offsets) == len(self.bucket_sizes), (
+            "bucket_offsets and bucket_sizes must have the same length, "
+            f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
+        )
+        self.eviction_policy.validate()
+        assert (
+            not self.backend_return_whole_row or self.enable_optimizer_offloading
+        ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"
+
+
+class KVZCHTBEConfig(NamedTuple):
+    # Eviction trigger model for kvzch table: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
+    kvzch_eviction_trigger_mode: int = 2  # mem_util
+    # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
+    eviction_free_mem_threshold_gb: int = 200  # 200GB
+    # Number of batches between checks for free memory threshold when using free_mem trigger mode.
+    eviction_free_mem_check_interval_batch: int = 1000
+    # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
+    threshold_calculation_bucket_stride: float = 0.2
+    # Total number of feature score buckets used for threshold calculation in feature score-based eviction.
+    threshold_calculation_bucket_num: Optional[int] = 1000000  # 1M
+    # When true, we only save weight to kvzch backend and not optimizer state.
+    load_ckpt_without_opt: bool = False
+    # [DO NOT USE] This is for st publish only, do not set it in your config
+    optimizer_type_for_st: Optional[str] = None
+    # [DO NOT USE] This is for st publish only, do not set it in your config
+    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
+
+
+class BackendType(enum.IntEnum):
+    SSD = 0
+    DRAM = 1
+    PS = 2
+
+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "ssd": BackendType.SSD,
+            "dram": BackendType.DRAM,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into BackendType: {key}")
+
+
+class CacheAlgorithm(enum.Enum):
+    LRU = 0
+    LFU = 1
+
+
+class MultiPassPrefetchConfig(NamedTuple):
+    # Number of passes to split indices tensor into. Actual number of passes may
+    # be less if indices tensor is too small to split.
+    num_passes: int = 12
+
+    # The minimal number of element in indices tensor to be able to split into
+    # two passes. This is useful to prevent too many prefetch kernels spamming
+    # the CUDA launch queue.
+    # The default 6M indices means 6M * 8 * 6 = approx. 300MB of memory overhead
+    # per pass.
+    min_splitable_pass_size: int = 6 * 1024 * 1024
+
+
+class PoolingMode(enum.IntEnum):
+    SUM = 0
+    MEAN = 1
+    NONE = 2
+
+    def do_pooling(self) -> bool:
+        return self is not PoolingMode.NONE
+
+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "sum": PoolingMode.SUM,
+            "mean": PoolingMode.MEAN,
+            "none": PoolingMode.NONE,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into PoolingMode: {key}")
+
+
+class BoundsCheckMode(enum.IntEnum):
+    # Raise an exception (CPU) or device-side assert (CUDA)
+    FATAL = 0
+    # Log the first out-of-bounds instance per kernel, and set to zero.
+    WARNING = 1
+    # Set to zero.
+    IGNORE = 2
+    # No bounds checks.
+    NONE = 3
+    # IGNORE with V2 enabled
+    V2_IGNORE = 4
+    # WARNING with V2 enabled
+    V2_WARNING = 5
+    # FATAL with V2 enabled
+    V2_FATAL = 6
+
+
+class ComputeDevice(enum.IntEnum):
+    CPU = 0
+    CUDA = 1
+    MTIA = 2
+
+
+class EmbeddingSpecInfo(enum.IntEnum):
+    feature_names = 0
+    rows = 1
+    dims = 2
+    sparse_type = 3
+    embedding_location = 4
+
+
+RecordCacheMetrics: NamedTuple = NamedTuple(
+    "RecordCacheMetrics",
+    [("record_cache_miss_counter", bool), ("record_tablewise_cache_miss", bool)],
+)
+
+SplitState: NamedTuple = NamedTuple(
+    "SplitState",
+    [
+        ("dev_size", int),
+        ("host_size", int),
+        ("uvm_size", int),
+        ("placements", list[EmbeddingLocation]),
+        ("offsets", list[int]),
+    ],
+)
+
+
+@dataclass
+class CacheState:
+    # T + 1 elements and cache_hash_size_cumsum[-1] == total_cache_hash_size
+    cache_hash_size_cumsum: list[int]
+    cache_index_table_map: list[int]
+    total_cache_hash_size: int
+
+
+def construct_cache_state(
+    row_list: list[int],
+    location_list: list[EmbeddingLocation],
+    feature_table_map: list[int],
+) -> CacheState:
+    _cache_hash_size_cumsum = [0]
+    total_cache_hash_size = 0
+    for num_embeddings, location in zip(row_list, location_list):
+        if location == EmbeddingLocation.MANAGED_CACHING:
+            total_cache_hash_size += num_embeddings
+        _cache_hash_size_cumsum.append(total_cache_hash_size)
+    # [T], -1: non-cached table
+    cache_hash_size_cumsum = []
+    # [total_cache_hash_size], linear cache index -> table index
+    cache_index_table_map = [-1] * total_cache_hash_size
+    unique_feature_table_map = {}
+    for t, t_ in enumerate(feature_table_map):
+        unique_feature_table_map[t_] = t
+    for t_, t in unique_feature_table_map.items():
+        start, end = _cache_hash_size_cumsum[t_], _cache_hash_size_cumsum[t_ + 1]
+        cache_index_table_map[start:end] = [t] * (end - start)
+    cache_hash_size_cumsum = [
+        (
+            _cache_hash_size_cumsum[t_]
+            if location_list[t_] == EmbeddingLocation.MANAGED_CACHING
+            else -1
+        )
+        for t_ in feature_table_map
+    ]
+    cache_hash_size_cumsum.append(total_cache_hash_size)
+    s = CacheState(
+        cache_hash_size_cumsum=cache_hash_size_cumsum,
+        cache_index_table_map=cache_index_table_map,
+        total_cache_hash_size=total_cache_hash_size,
+    )
+    return s
+
+
+# NOTE: This is also defined in fbgemm_gpu.tbe.utils, but declaring
+# target dependency on :split_embedding_utils will result in compatibility
+# breakage with Caffe2 module_factory because it will pull in numpy
+def round_up(a: int, b: int) -> int:
+    return int((a + b - 1) // b) * b
+
+
+def tensor_to_device(tensor: torch.Tensor, device: torch.device) -> Tensor:
+    if tensor.device == torch.device("meta"):
+        return torch.empty_like(tensor, device=device)
+    return tensor.to(device)
+
+
+def get_new_embedding_location(
+    device: torch.device, cache_load_factor: float
+) -> EmbeddingLocation:
+    """
+    Based on the cache_load_factor and device, return the embedding location intended
+    for the TBE weights.
+    """
+    # Only support CPU and GPU device
+    assert device.type == "cpu" or device.type == "cuda"
+    if cache_load_factor < 0 or cache_load_factor > 1:
+        raise ValueError(
+            f"cache_load_factor must be between 0.0 and 1.0, got {cache_load_factor}"
+        )
+
+    if device.type == "cpu":
+        return EmbeddingLocation.HOST
+    # UVM only
+    elif cache_load_factor == 0:
+        return EmbeddingLocation.MANAGED
+    # HBM only
+    elif cache_load_factor == 1.0:
+        return EmbeddingLocation.DEVICE
+    # UVM caching
+    else:
+        return EmbeddingLocation.MANAGED_CACHING
+
+
+def get_bounds_check_version_for_platform() -> int:
+    # NOTE: Use bounds_check_indices v2 on ROCm because ROCm has a
+    # constraint that the gridDim * blockDim has to be smaller than
+    # 2^32. The v1 kernel can be launched with gridDim * blockDim >
+    # 2^32 while the v2 kernel limits the gridDim size to 64 * # of
+    # SMs. Thus, its gridDim * blockDim is guaranteed to be smaller
+    # than 2^32
+    return 2 if (torch.cuda.is_available() and torch.version.hip) else 1
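For context on the cache-state bookkeeping defined above, here is a small hedged walk-through of construct_cache_state and round_up exactly as they appear in this diff; the table sizes and values are illustrative, not taken from the package:

from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    EmbeddingLocation,
    construct_cache_state,
    round_up,
)

# Two tables: table 0 (100 rows) is UVM-cached (MANAGED_CACHING),
# table 1 (50 rows) lives entirely in HBM (DEVICE).
rows = [100, 50]
locations = [EmbeddingLocation.MANAGED_CACHING, EmbeddingLocation.DEVICE]
feature_table_map = [0, 1]  # feature i reads table i

state = construct_cache_state(rows, locations, feature_table_map)
assert state.total_cache_hash_size == 100  # only cached tables contribute
assert state.cache_hash_size_cumsum == [0, -1, 100]  # -1 marks the non-cached table
assert state.cache_index_table_map == [0] * 100  # every linear cache slot -> table 0

# round_up rounds a up to the next multiple of b.
assert round_up(10, 4) == 12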