fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff shows the content of a publicly released package version as it appears in its public registry, and is provided for informational purposes only.
This version of fbgemm-gpu-genai-nightly has been flagged as a potentially problematic release.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/tbe/bench/tbe_data_config_loader.py
@@ -0,0 +1,289 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import dataclasses
from enum import Enum

import click
import torch
import yaml

from fbgemm_gpu.tbe.bench.tbe_data_config import (
    BatchParams,
    IndicesParams,
    PoolingParams,
    TBEDataConfig,
)


@dataclasses.dataclass(frozen=True)
class TBEDataConfigHelperText(Enum):
    # Config File
    TBE_CONFIG = "TBE data configuration filepath. If provided, all other `--tbe-*` options are ignored."

    # Table Parameters
    TBE_NUM_TABLES = "Number of tables (T)"
    TBE_NUM_EMBEDDINGS = "Number of embeddings (E)"
    TBE_EMBEDDING_DIM = "Embedding dimensions (D)"
    TBE_MIXED_DIM = "Use mixed dimensions"
    TBE_WEIGHTED = "Flag to indicate if the table is weighted"

    # Batch Parameters
    TBE_BATCH_SIZE = "Batch size (B)"
    TBE_BATCH_VBE_SIGMA = "Standard deviation of B for VBE"
    TBE_BATCH_VBE_DIST = "VBE distribution (choices: 'uniform', 'normal')"
    TBE_BATCH_VBE_RANKS = "Number of ranks for VBE"

    # Indices Parameters
    TBE_INDICES_HITTERS = "Heavy hitters for indices (comma-delimited list of floats)"
    TBE_INDICES_ZIPF = "Zipf distribution parameters for indices generation (q, s)"
    TBE_INDICES_DTYPE = "The dtype of the table indices (choices: '32', '64')"
    TBE_OFFSETS_DTYPE = "The dtype of the table offsets (choices: '32', '64')"

    # Pooling Parameters
    TBE_POOLING_SIZE = "Bag size / pooling factor (L)"
    TBE_POOLING_VL_SIGMA = "Standard deviation of L for variable-length pooling"
    TBE_POOLING_VL_DIST = "Variable-length pooling distribution (choices: 'uniform', 'normal')"


class TBEDataConfigLoader:
    @classmethod
    # pyre-ignore [2]
    def options(cls, func) -> click.Command:
        options = [
            # Config File
            click.option(
                "--tbe-config",
                type=str,
                required=False,
                help=TBEDataConfigHelperText.TBE_CONFIG.value,
            ),
            # Table Parameters
            click.option(
                "--tbe-num-tables",
                type=int,
                default=32,
                help=TBEDataConfigHelperText.TBE_NUM_TABLES.value,
            ),
            click.option(
                "--tbe-num-embeddings",
                type=int,
                default=int(1e5),
                help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value,
            ),
            click.option(
                "--tbe-num-embeddings-list",
                type=str,
                required=False,
                default=None,
                help="Comma-separated list of numbers of embeddings (Es)",
            ),
            click.option(
                "--tbe-embedding-dim",
                type=int,
                default=128,
                help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value,
            ),
            click.option(
                "--tbe-embedding-dim-list",
                type=str,
                required=False,
                default=None,
                help="Comma-separated list of embedding dimensions (Ds)",
            ),
            click.option(
                "--tbe-mixed-dim",
                is_flag=True,
                default=False,
                help=TBEDataConfigHelperText.TBE_MIXED_DIM.value,
            ),
            click.option(
                "--tbe-weighted",
                is_flag=True,
                default=False,
                help=TBEDataConfigHelperText.TBE_WEIGHTED.value,
            ),
            click.option(
                "--tbe-max-indices",
                type=int,
                required=False,
                default=None,
                help="(Optional) Maximum number of indices; calculated if not provided",
            ),
            # Batch Parameters
            click.option(
                "--tbe-batch-size",
                type=int,
                default=512,
                help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value,
            ),
            click.option(
                "--tbe-batch-sizes-list",
                type=str,
                required=False,
                default=None,
                help="Comma-separated list of batch sizes per feature (Bs)",
            ),
            click.option(
                "--tbe-batch-vbe-sigma",
                type=int,
                required=False,
                help=TBEDataConfigHelperText.TBE_BATCH_VBE_SIGMA.value,
            ),
            click.option(
                "--tbe-batch-vbe-dist",
                type=click.Choice(["uniform", "normal"]),
                required=False,
                help=TBEDataConfigHelperText.TBE_BATCH_VBE_DIST.value,
            ),
            click.option(
                "--tbe-batch-vbe-ranks",
                type=int,
                required=False,
                help=TBEDataConfigHelperText.TBE_BATCH_VBE_RANKS.value,
            ),
            # Indices Parameters
            click.option(
                "--tbe-indices-hitters",
                type=str,
                default="",
                help=TBEDataConfigHelperText.TBE_INDICES_HITTERS.value,
            ),
            click.option(
                "--tbe-indices-zipf",
                type=(float, float),
                default=(0.1, 0.1),
                help=TBEDataConfigHelperText.TBE_INDICES_ZIPF.value,
            ),
            click.option(
                "--tbe-indices-dtype",
                type=click.Choice(["32", "64"]),
                default="64",
                help=TBEDataConfigHelperText.TBE_INDICES_DTYPE.value,
            ),
            click.option(
                "--tbe-offsets-dtype",
                type=click.Choice(["32", "64"]),
                default="64",
                help=TBEDataConfigHelperText.TBE_OFFSETS_DTYPE.value,
            ),
            # Pooling Parameters
            click.option(
                "--tbe-pooling-size",
                type=int,
                default=20,
                help=TBEDataConfigHelperText.TBE_POOLING_SIZE.value,
            ),
            click.option(
                "--tbe-pooling-vl-sigma",
                type=int,
                required=False,
                help=TBEDataConfigHelperText.TBE_POOLING_VL_SIGMA.value,
            ),
            click.option(
                "--tbe-pooling-vl-dist",
                type=click.Choice(["uniform", "normal"]),
                required=False,
                help=TBEDataConfigHelperText.TBE_POOLING_VL_DIST.value,
            ),
        ]

        for option in reversed(options):
            func = option(func)
        return func

    @classmethod
    def load_from_file(cls, filepath: str) -> TBEDataConfig:
        with open(filepath, "r") as f:
            if filepath.endswith(".yaml") or filepath.endswith(".yml"):
                data = yaml.safe_load(f)
                return TBEDataConfig.from_dict(data).validate()
            else:
                return TBEDataConfig.from_json(f.read()).validate()

    @classmethod
    def load_from_context(cls, context: click.Context) -> TBEDataConfig:
        params = context.params

        # Read table parameters
        T = params["tbe_num_tables"]
        E = params["tbe_num_embeddings"]
        if params["tbe_num_embeddings_list"] is not None:
            Es = [int(x) for x in params["tbe_num_embeddings_list"].split(",")]
        else:
            Es = None
        D = params["tbe_embedding_dim"]
        if params["tbe_embedding_dim_list"] is not None:
            Ds = [int(x) for x in params["tbe_embedding_dim_list"].split(",")]
        else:
            Ds = None

        mixed_dim = params["tbe_mixed_dim"]
        weighted = params["tbe_weighted"]
        if params["tbe_max_indices"] is not None:
            max_indices = params["tbe_max_indices"]
        else:
            max_indices = None

        # Read batch parameters
        B = params["tbe_batch_size"]
        sigma_B = params["tbe_batch_vbe_sigma"]
        vbe_distribution = params["tbe_batch_vbe_dist"]
        vbe_num_ranks = params["tbe_batch_vbe_ranks"]
        if params["tbe_batch_sizes_list"] is not None:
            Bs = [int(x) for x in params["tbe_batch_sizes_list"].split(",")]
        else:
            Bs = None
        batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks, Bs)

        # Read indices parameters
        heavy_hitters = (
            torch.tensor([float(x) for x in params["tbe_indices_hitters"].split(",")])
            if params["tbe_indices_hitters"]
            else torch.tensor([])
        )
        zipf_q, zipf_s = params["tbe_indices_zipf"]
        index_dtype = (
            torch.int32 if int(params["tbe_indices_dtype"]) == 32 else torch.int64
        )
        offset_dtype = (
            torch.int32 if int(params["tbe_offsets_dtype"]) == 32 else torch.int64
        )
        indices_params = IndicesParams(
            heavy_hitters, zipf_q, zipf_s, index_dtype, offset_dtype
        )

        # Read pooling parameters
        L = params["tbe_pooling_size"]
        sigma_L = params["tbe_pooling_vl_sigma"]
        length_distribution = params["tbe_pooling_vl_dist"]
        pooling_params = PoolingParams(L, sigma_L, length_distribution)

        return TBEDataConfig(
            T,
            E,
            D,
            mixed_dim,
            weighted,
            batch_params,
            indices_params,
            pooling_params,
            not torch.cuda.is_available(),
            Es,
            Ds,
            max_indices,
        ).validate()

    @classmethod
    def load(cls, context: click.Context) -> TBEDataConfig:
        tbe_config_filepath = context.params["tbe_config"]
        if tbe_config_filepath is not None:
            return cls.load_from_file(tbe_config_filepath)
        else:
            return cls.load_from_context(context)
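A minimal sketch of wiring this loader into a CLI. The command name and callback body below are hypothetical; `options` and `load` are the classmethods defined in the file above.

# Hypothetical consumer of TBEDataConfigLoader; not part of the package.
import click

from fbgemm_gpu.tbe.bench.tbe_data_config_loader import TBEDataConfigLoader


@click.command()
@TBEDataConfigLoader.options  # attaches every --tbe-* option to the command
@click.pass_context
# pyre-ignore [2]
def show_config(context: click.Context, **kwargs) -> None:
    # load() prefers --tbe-config (a YAML/JSON file) and otherwise assembles a
    # validated TBEDataConfig from the individual --tbe-* values in ctx.params.
    config = TBEDataConfigLoader.load(context)
    print(config)


if __name__ == "__main__":
    show_config()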
fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py
@@ -0,0 +1,170 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import dataclasses
import json
from typing import Any, Optional

import torch


def str_to_int_dtype(dtype: str) -> torch.dtype:
    if dtype == "torch.int32":
        return torch.int32
    elif dtype == "torch.int64":
        return torch.int64
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")


@dataclasses.dataclass(frozen=True, eq=False)
class IndicesParams:
    # Heavy hitters for the Zipf distribution, i.e. a probability density map
    # for the hottest indices. There should not ever be more than 100
    # elements, and currently it is limited to 20 entries (kHeavyHittersMaxSize)
    heavy_hitters: torch.Tensor
    # zipf*: parameters for the Zipf distribution (x+q)^{-s}
    zipf_q: float
    # zipf_s is synonymous with alpha in the literature
    zipf_s: float
    # [Optional] dtype for indices tensor
    index_dtype: Optional[torch.dtype] = None
    # [Optional] dtype for offsets tensor
    offset_dtype: Optional[torch.dtype] = None

    @classmethod
    # pyre-ignore [3]
    def from_dict(cls, data: dict[str, Any]):
        if not isinstance(data["heavy_hitters"], torch.Tensor):
            data["heavy_hitters"] = torch.tensor(
                data["heavy_hitters"], dtype=torch.float32
            )
        data["index_dtype"] = str_to_int_dtype(data["index_dtype"])
        data["offset_dtype"] = str_to_int_dtype(data["offset_dtype"])
        return cls(**data)

    @classmethod
    # pyre-ignore [3]
    def from_json(cls, data: str):
        return cls.from_dict(json.loads(data))

    def dict(self) -> dict[str, Any]:
        # https://stackoverflow.com/questions/73735974/convert-dataclass-of-dataclass-to-json-string
        tmp = dataclasses.asdict(self)
        # Convert tensor to list for JSON serialization
        tmp["heavy_hitters"] = self.heavy_hitters.tolist()
        tmp["index_dtype"] = str(self.index_dtype)
        tmp["offset_dtype"] = str(self.offset_dtype)
        return tmp

    def json(self, format: bool = False) -> str:
        return json.dumps(self.dict(), indent=(2 if format else -1), sort_keys=True)

    # pyre-ignore [2]
    def __eq__(self, other) -> bool:
        return (
            (self.zipf_q, self.zipf_s, self.index_dtype, self.offset_dtype)
            == (other.zipf_q, other.zipf_s, other.index_dtype, other.offset_dtype)
        ) and bool((self.heavy_hitters - other.heavy_hitters).abs().max() < 1e-6)

    # pyre-ignore [3]
    def validate(self):
        assert self.zipf_q > 0, "zipf_q must be positive"
        assert self.zipf_s > 0, "zipf_s must be positive"
        assert self.index_dtype is None or self.index_dtype in [
            torch.int32,
            torch.int64,
        ], "index_dtype must be one of [torch.int32, torch.int64]"
        assert self.offset_dtype is None or self.offset_dtype in [
            torch.int32,
            torch.int64,
        ], "offset_dtype must be one of [torch.int32, torch.int64]"
        return self


@dataclasses.dataclass(frozen=True)
class BatchParams:
    # Target batch size, i.e. number of batch lookups per table
    B: int
    # [Optional] Standard deviation of B (for variable batch size configuration)
    sigma_B: Optional[int] = None
    # [Optional] Distribution of batch sizes (normal, uniform)
    vbe_distribution: Optional[str] = "normal"
    # Number of ranks for variable batch size generation
    vbe_num_ranks: Optional[int] = None
    # List of target batch sizes, i.e. number of batch lookups per table
    Bs: Optional[list[int]] = None

    @classmethod
    # pyre-ignore [3]
    def from_dict(cls, data: dict[str, Any]):
        return cls(**data)

    @classmethod
    # pyre-ignore [3]
    def from_json(cls, data: str):
        return cls.from_dict(json.loads(data))

    def dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)

    def json(self, format: bool = False) -> str:
        return json.dumps(self.dict(), indent=(2 if format else -1), sort_keys=True)

    # pyre-ignore [3]
    def validate(self):
        if self.Bs is not None:
            assert all(b > 0 for b in self.Bs), "All elements in Bs must be positive"
        else:
            assert self.B > 0, "B must be positive"
        assert not self.sigma_B or self.sigma_B > 0, "sigma_B must be positive"
        assert (
            self.vbe_num_ranks is None or self.vbe_num_ranks > 0
        ), "vbe_num_ranks must be positive"
        assert self.vbe_distribution is None or self.vbe_distribution in [
            "normal",
            "uniform",
        ], "vbe_distribution must be one of [normal, uniform]"
        return self


@dataclasses.dataclass(frozen=True)
class PoolingParams:
    # Target bag size, i.e. pooling factor, or number of indices per batch lookup
    L: int
    # [Optional] Standard deviation of L (for variable bag size configuration)
    sigma_L: Optional[int] = None
    # [Optional] Distribution of embedding sequence lengths (normal, uniform)
    length_distribution: Optional[str] = "normal"

    @classmethod
    # pyre-ignore [3]
    def from_dict(cls, data: dict[str, Any]):
        return cls(**data)

    @classmethod
    # pyre-ignore [3]
    def from_json(cls, data: str):
        return cls.from_dict(json.loads(data))

    def dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)

    def json(self, format: bool = False) -> str:
        return json.dumps(self.dict(), indent=(2 if format else -1), sort_keys=True)

    # pyre-ignore [3]
    def validate(self):
        assert self.L > 0, "L must be positive"
        assert not self.sigma_L or self.sigma_L > 0, "sigma_L must be positive"
        assert self.length_distribution is None or self.length_distribution in [
            "normal",
            "uniform",
        ], "length_distribution must be one of [normal, uniform]"
        return self
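A short serialization round-trip using the models above. The values are made up, and the import path assumes this hunk is fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py, matching the +170 entry in the file list.

# Hypothetical usage; not part of the package.
import torch

from fbgemm_gpu.tbe.bench.tbe_data_config_param_models import (
    BatchParams,
    IndicesParams,
)

indices = IndicesParams(
    heavy_hitters=torch.tensor([0.5, 0.2, 0.1]),
    zipf_q=0.1,
    zipf_s=1.05,
    index_dtype=torch.int64,
    offset_dtype=torch.int64,
).validate()

# json() stringifies the tensor and the dtypes; from_json() reverses both, and
# the custom __eq__ compares heavy_hitters within a 1e-6 tolerance.
assert IndicesParams.from_json(indices.json()) == indices

batch = BatchParams(B=512, sigma_B=32, vbe_distribution="normal", vbe_num_ranks=4)
assert BatchParams.from_json(batch.json()).validate() == batch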
fbgemm_gpu/tbe/bench/utils.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging

import numpy as np
import torch

from fbgemm_gpu.split_embedding_configs import SparseType

logging.basicConfig(level=logging.DEBUG)


def fill_random_scale_bias(
    emb: torch.nn.Module,
    T: int,
    weights_precision: SparseType,
) -> None:
    for t in range(T):
        # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
        (weights, scale_shift) = emb.split_embedding_weights()[t]
        if scale_shift is not None:
            (E, R) = scale_shift.shape
            assert R == 4
            scales = None
            shifts = None
            if weights_precision == SparseType.INT8:
                scales = np.random.uniform(0.001, 0.01, size=(E,)).astype(np.float16)
                shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)
            elif weights_precision == SparseType.INT4:
                scales = np.random.uniform(0.01, 0.1, size=(E,)).astype(np.float16)
                shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)
            elif weights_precision == SparseType.INT2:
                scales = np.random.uniform(0.1, 1, size=(E,)).astype(np.float16)
                shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)
            scale_shift.copy_(
                torch.tensor(
                    np.stack([scales, shifts], axis=1)
                    .astype(np.float16)
                    .view(np.uint8),
                    device=scale_shift.device,
                )
            )
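The copy at the end of fill_random_scale_bias packs one fp16 scale and one fp16 shift per row, reinterpreted as four uint8 bytes (hence the R == 4 assert). A standalone sketch of that layout and its inverse, with illustrative names:

# Hypothetical illustration of the fp16 -> uint8 scale/shift packing.
import numpy as np
import torch

E = 8  # number of embedding rows
scales = np.random.uniform(0.001, 0.01, size=(E,)).astype(np.float16)
shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)

# Same packing as above: (E, 2) fp16 viewed as (E, 4) uint8.
packed = np.stack([scales, shifts], axis=1).astype(np.float16).view(np.uint8)
assert packed.shape == (E, 4)  # 2 x fp16 = 4 bytes per row

# Inverse: reinterpret the four bytes per row back into two fp16 values.
unpacked = torch.from_numpy(packed).view(torch.float16)
assert np.array_equal(unpacked[:, 0].numpy(), scales)
assert np.array_equal(unpacked[:, 1].numpy(), shifts)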
fbgemm_gpu/tbe/cache/__init__.py
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from .kv_embedding_ops_inference import KVEmbeddingInference  # noqa: F401
from .split_embeddings_cache_ops import get_unique_indices_v2  # noqa: F401