fbgemm-gpu-nightly-cpu 2025.3.27-cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29-cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/tbe/bench/tbe_data_config_loader.py
CHANGED

@@ -7,12 +7,56 @@
 
 # pyre-strict
 
+import dataclasses
+import logging
+import re
+from enum import Enum
+
 import click
 import torch
 import yaml
 
-
-from .
+# fmt:skip
+from fbgemm_gpu.tbe.bench.tbe_data_config import (
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+    TBEDataConfig,
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class TBEDataConfigHelperText(Enum):
+    # Config File
+    TBE_CONFIG = "TBE data configuration filepath. If provided, all other `--tbe-*` options are ignored."
+
+    # Table Parameters
+    TBE_NUM_TABLES = "Number of tables (T)"
+    TBE_NUM_EMBEDDINGS = "Number of embeddings (E)"
+    TBE_EMBEDDING_DIM = "Embedding dimensions (D)"
+    TBE_MIXED_DIM = "Use mixed dimensions"
+    TBE_WEIGHTED = "Flag to indicate if the table is weighted"
+
+    # Batch Parameters
+    TBE_BATCH_SIZE = "Batch size (B)"
+    TBE_BATCH_VBE_SIGMA = "Standard deviation of B for VBE"
+    TBE_BATCH_VBE_DIST = "VBE distribution (choices: 'uniform', 'normal')"
+    TBE_BATCH_VBE_RANKS = "Number of ranks for VBE"
+
+    # Indices Parameters
+    TBE_INDICES_HITTERS = "Heavy hitters for indices (comma-delimited list of floats)"
+    TBE_INDICES_ZIPF = "Zipf distribution parameters for indices generation (q, s)"
+    TBE_INDICES_DTYPE = "The dtype of the table indices (choices: '32', '64')"
+    TBE_OFFSETS_DTYPE = "The dtype of the table offsets (choices: '32', '64')"
+
+    # Pooling Parameters
+    TBE_POOLING_SIZE = "Bag size / pooling factor (L)"
+    TBE_POOLING_VL_SIGMA = "Standard deviation of L for variable bag size"
+    TBE_POOLING_VL_DIST = (
+        "Variable bag size distribution (choices: 'uniform', 'normal')"
+    )
+    TBE_EMBEDDING_SPECS = "Embedding Specs which is List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]"
+    TBE_FEATURE_TABLE_MAP = "Mapping of feature-table"
 
 
 class TBEDataConfigLoader:
@@ -20,119 +64,152 @@ class TBEDataConfigLoader:
     # pyre-ignore [2]
     def options(cls, func) -> click.Command:
         options = [
-            ####################################################################
             # Config File
-            ####################################################################
             click.option(
                 "--tbe-config",
                 type=str,
                 required=False,
-                help=
+                help=TBEDataConfigHelperText.TBE_CONFIG.value,
             ),
-            ####################################################################
             # Table Parameters
-            ####################################################################
             click.option(
                 "--tbe-num-tables",
                 type=int,
                 default=32,
-                help=
+                help=TBEDataConfigHelperText.TBE_NUM_TABLES.value,
             ),
             click.option(
                 "--tbe-num-embeddings",
                 type=int,
                 default=int(1e5),
-                help=
+                help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value,
+            ),
+            click.option(
+                "--tbe-num-embeddings-list",
+                type=str,
+                required=False,
+                default=None,
+                help="Comma-separated list of number of embeddings (Es)",
             ),
             click.option(
                 "--tbe-embedding-dim",
                 type=int,
                 default=128,
-                help=
+                help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value,
+            ),
+            click.option(
+                "--tbe-embedding-dim-list",
+                type=str,
+                required=False,
+                default=None,
+                help="Comma-separated list of number of Embedding dimensions (D)",
             ),
             click.option(
                 "--tbe-mixed-dim",
                 is_flag=True,
                 default=False,
-                help=
+                help=TBEDataConfigHelperText.TBE_MIXED_DIM.value,
             ),
             click.option(
                 "--tbe-weighted",
                 is_flag=True,
                 default=False,
-                help=
+                help=TBEDataConfigHelperText.TBE_WEIGHTED.value,
+            ),
+            click.option(
+                "--tbe-max-indices",
+                type=int,
+                required=False,
+                default=None,
+                help="(Optional) Maximum number of indices, will be calculated if not provided",
             ),
-            ####################################################################
             # Batch Parameters
-            ####################################################################
             click.option(
-                "--tbe-batch-size",
+                "--tbe-batch-size",
+                type=int,
+                default=512,
+                help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value,
+            ),
+            click.option(
+                "--tbe-batch-sizes-list",
+                type=str,
+                required=False,
+                default=None,
+                help="List Batch sizes per feature (Bs)",
             ),
             click.option(
                 "--tbe-batch-vbe-sigma",
                 type=int,
                 required=False,
-                help=
+                help=TBEDataConfigHelperText.TBE_BATCH_VBE_SIGMA.value,
             ),
             click.option(
                 "--tbe-batch-vbe-dist",
                 type=click.Choice(["uniform", "normal"]),
                 required=False,
-                help=
+                help=TBEDataConfigHelperText.TBE_BATCH_VBE_DIST.value,
             ),
             click.option(
                 "--tbe-batch-vbe-ranks",
                 type=int,
                 required=False,
-                help=
+                help=TBEDataConfigHelperText.TBE_BATCH_VBE_RANKS.value,
             ),
-            ####################################################################
             # Indices Parameters
-            ####################################################################
             click.option(
                 "--tbe-indices-hitters",
                 type=str,
                 default="",
-                help=
+                help=TBEDataConfigHelperText.TBE_INDICES_HITTERS.value,
             ),
             click.option(
                 "--tbe-indices-zipf",
                 type=(float, float),
                 default=(0.1, 0.1),
-                help=
+                help=TBEDataConfigHelperText.TBE_INDICES_ZIPF.value,
             ),
             click.option(
                 "--tbe-indices-dtype",
                 type=click.Choice(["32", "64"]),
                 default="64",
-                help=
+                help=TBEDataConfigHelperText.TBE_INDICES_DTYPE.value,
             ),
             click.option(
                 "--tbe-offsets-dtype",
                 type=click.Choice(["32", "64"]),
                 default="64",
-                help=
+                help=TBEDataConfigHelperText.TBE_OFFSETS_DTYPE.value,
             ),
-            ####################################################################
             # Pooling Parameters
-            ####################################################################
             click.option(
                 "--tbe-pooling-size",
                 type=int,
                 default=20,
-                help=
+                help=TBEDataConfigHelperText.TBE_POOLING_SIZE.value,
            ),
             click.option(
                 "--tbe-pooling-vl-sigma",
                 type=int,
                 required=False,
-                help=
+                help=TBEDataConfigHelperText.TBE_POOLING_VL_SIGMA.value,
             ),
             click.option(
                 "--tbe-pooling-vl-dist",
                 type=click.Choice(["uniform", "normal"]),
                 required=False,
-                help=
+                help=TBEDataConfigHelperText.TBE_POOLING_VL_DIST.value,
+            ),
+            click.option(
+                "--tbe-embedding-specs",
+                type=str,
+                required=False,
+                help=TBEDataConfigHelperText.TBE_EMBEDDING_SPECS.value,
+            ),
+            click.option(
+                "--tbe-feature-table-map",
+                type=str,
+                required=False,
+                help=TBEDataConfigHelperText.TBE_FEATURE_TABLE_MAP.value,
             ),
         ]
 
@@ -154,18 +231,62 @@ class TBEDataConfigLoader:
         params = context.params
 
         # Read table parameters
-        T = params["tbe_num_tables"]
-        E = params["tbe_num_embeddings"]
+        T = params["tbe_num_tables"]  # number of features
+        E = params["tbe_num_embeddings"]  # feature_rows
+        if params["tbe_num_embeddings_list"] is not None:
+            Es = [int(x) for x in params["tbe_num_embeddings_list"].split(",")]
+            T = len(Es)
+            E = sum(Es) // T  # average E
+        else:
+            Es = None
         D = params["tbe_embedding_dim"]
+        if params["tbe_embedding_dim_list"] is not None:
+            Ds = [int(x) for x in params["tbe_embedding_dim_list"].split(",")]
+            assert (
+                len(Ds) == T
+            ), f"Expected tbe_embedding_dim_list to have {T} elements, but got {len(Ds)}"
+            D = sum(Ds) // T  # average D
+        else:
+            Ds = None
+
         mixed_dim = params["tbe_mixed_dim"]
         weighted = params["tbe_weighted"]
+        if params["tbe_max_indices"] is not None:
+            max_indices = params["tbe_max_indices"]
+        else:
+            max_indices = None
 
         # Read batch parameters
         B = params["tbe_batch_size"]
         sigma_B = params["tbe_batch_vbe_sigma"]
         vbe_distribution = params["tbe_batch_vbe_dist"]
         vbe_num_ranks = params["tbe_batch_vbe_ranks"]
-
+        if params["tbe_batch_sizes_list"] is not None:
+            Bs = [int(x) for x in params["tbe_batch_sizes_list"].split(",")]
+            B = sum(Bs) // T  # average B
+        else:
+            B = params["tbe_batch_size"]
+            Bs = None
+        batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks, Bs)
+
+        # Parse embedding_specs: "(E,D),(E,D),..." or "(E,D,loc,dev),(E,D,loc,dev),..."
+        # Only the first two values (E, D) are extracted.
+        embedding_specs = None
+        feature_table_map = None
+        if params["tbe_embedding_specs"] is not None:
+            try:
+                tuples = re.findall(r"\(([^)]+)\)", params["tbe_embedding_specs"])
+                if tuples:
+                    embedding_specs = [
+                        (int(t.split(",")[0].strip()), int(t.split(",")[1].strip()))
+                        for t in tuples
+                    ]
+            except (ValueError, IndexError):
+                logging.warning("Failed to parse embedding_specs. Setting to None.")
+        if params["tbe_feature_table_map"] is not None:
+            feature_table_map = [
+                int(x) for x in params["tbe_feature_table_map"].split(",")
+            ]
 
         # Read indices parameters
         heavy_hitters = (
@@ -200,6 +321,11 @@ class TBEDataConfigLoader:
             indices_params,
             pooling_params,
             not torch.cuda.is_available(),
+            Es,
+            Ds,
+            max_indices,
+            embedding_specs,
+            feature_table_map,
         ).validate()
 
     @classmethod
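The new list-valued options above (`--tbe-num-embeddings-list`, `--tbe-embedding-dim-list`, `--tbe-batch-sizes-list`, `--tbe-embedding-specs`, `--tbe-feature-table-map`) are reduced to per-table lists plus averaged scalars inside the loader. The following standalone sketch is illustrative only (not part of the wheel; the option values are made up) and mirrors the parsing shown in the hunks above:

# Illustrative sketch of the list-option parsing added to TBEDataConfigLoader.
import re

num_embeddings_list = "1000,2000,3000"        # --tbe-num-embeddings-list
embedding_specs_str = "(1000,128),(2000,64)"  # --tbe-embedding-specs

Es = [int(x) for x in num_embeddings_list.split(",")]
T = len(Es)       # 3 tables
E = sum(Es) // T  # average E = 2000

# Only the first two values (E, D) of each tuple are kept, as in the diff.
embedding_specs = [
    (int(t.split(",")[0].strip()), int(t.split(",")[1].strip()))
    for t in re.findall(r"\(([^)]+)\)", embedding_specs_str)
]
print(T, E, embedding_specs)  # 3 2000 [(1000, 128), (2000, 64)]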
fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py
CHANGED

@@ -9,7 +9,7 @@
 
 import dataclasses
 import json
-from typing import Any,
+from typing import Any, Optional
 
 import torch
 
@@ -31,6 +31,7 @@ class IndicesParams:
     heavy_hitters: torch.Tensor
     # zipf*: parameters for the Zipf distribution (x+q)^{-s}
     zipf_q: float
+    # zipf_s is synonymous with alpha in the literature
     zipf_s: float
     # [Optional] dtype for indices tensor
     index_dtype: Optional[torch.dtype] = None
@@ -39,7 +40,7 @@ class IndicesParams:
 
     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data:
+    def from_dict(cls, data: dict[str, Any]):
         if not isinstance(data["heavy_hitters"], torch.Tensor):
             data["heavy_hitters"] = torch.tensor(
                 data["heavy_hitters"], dtype=torch.float32
@@ -53,7 +54,7 @@ class IndicesParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))
 
-    def dict(self) ->
+    def dict(self) -> dict[str, Any]:
         # https://stackoverflow.com/questions/73735974/convert-dataclass-of-dataclass-to-json-string
         tmp = dataclasses.asdict(self)
         # Convert tensor to list for JSON serialization
@@ -97,10 +98,12 @@ class BatchParams:
     vbe_distribution: Optional[str] = "normal"
     # Number of ranks for variable batch size generation
     vbe_num_ranks: Optional[int] = None
+    # List of target batch sizes, i.e. number of batch lookups per feature
+    Bs: Optional[list[int]] = None
 
     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data:
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)
 
     @classmethod
@@ -108,7 +111,7 @@ class BatchParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))
 
-    def dict(self) ->
+    def dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self)
 
     def json(self, format: bool = False) -> str:
@@ -116,7 +119,10 @@ class BatchParams:
 
     # pyre-ignore [3]
     def validate(self):
-
+        if self.Bs is not None:
+            assert all(b > 0 for b in self.Bs), "All elements in Bs must be positive"
+        else:
+            assert self.B > 0, "B must be positive"
         assert not self.sigma_B or self.sigma_B > 0, "sigma_B must be positive"
         assert (
             self.vbe_num_ranks is None or self.vbe_num_ranks > 0
@@ -136,10 +142,12 @@ class PoolingParams:
     sigma_L: Optional[int] = None
     # [Optional] Distribution of embedding sequence lengths (normal, uniform)
     length_distribution: Optional[str] = "normal"
+    # [Optional] List of target bag sizes, i.e. pooling factors per batch
+    Ls: Optional[list[float]] = None
 
     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data:
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)
 
     @classmethod
@@ -147,7 +155,7 @@ class PoolingParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))
 
-    def dict(self) ->
+    def dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self)
 
     def json(self, format: bool = False) -> str:
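`BatchParams` now carries an optional per-feature `Bs` list, and `validate()` asserts that every entry is positive when it is set (falling back to checking `B > 0` otherwise). A minimal usage sketch, assuming the positional constructor order used by the loader above (`B, sigma_B, vbe_distribution, vbe_num_ranks, Bs`) and illustrative values:

# Minimal sketch; values are illustrative, not defaults from the package.
from fbgemm_gpu.tbe.bench.tbe_data_config import BatchParams

Bs = [64, 128, 256]            # per-feature batch sizes
B = sum(Bs) // len(Bs)         # average B, as computed in the loader
params = BatchParams(B, None, "normal", None, Bs)
params.validate()              # asserts all(b > 0 for b in Bs)
print(params.json(format=True))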
fbgemm_gpu/tbe/bench/utils.py
CHANGED
@@ -6,15 +6,14 @@
 
 # pyre-strict
 
-import
+from typing import List, Tuple
 
 import numpy as np
 import torch
 
+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import SparseType
 
-logging.basicConfig(level=logging.DEBUG)
-
 
 def fill_random_scale_bias(
     emb: torch.nn.Module,
@@ -23,9 +22,9 @@ def fill_random_scale_bias(
 ) -> None:
     for t in range(T):
         # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
-
+        weights, scale_shift = emb.split_embedding_weights()[t]
         if scale_shift is not None:
-
+            E, R = scale_shift.shape
             assert R == 4
             scales = None
             shifts = None
@@ -46,3 +45,128 @@ def fill_random_scale_bias(
                     device=scale_shift.device,
                 )
             )
+
+
+def check_oom(
+    data_size: int,
+) -> Tuple[bool, str]:
+    free_memory, total_memory = torch.cuda.mem_get_info()
+    if data_size > free_memory:
+        warning = f"Expect to allocate {round(data_size / (1024 ** 3), 2)} GB, but available memory is {round(free_memory / (1024 ** 3), 2)} GB from {round(total_memory / (1024 ** 3), 2)} GB."
+        return (True, warning)
+    return (False, "")
+
+
+def generate_batch_size_per_feature_per_rank(
+    Bs: List[int], num_ranks: int
+) -> List[List[int]]:
+    """
+    Generate batch size per feature per rank for VBE, assuming the batch size
+    is evenly distributed across ranks.
+    Args:
+        Bs (List[int]): batch size per feature
+        num_ranks (int): number of ranks
+    Returns:
+        List[List[int]]: batch size per feature per rank
+    """
+    b_per_feature_per_rank = []
+    for B in Bs:
+        b_per_feature = []
+        for i in range(num_ranks):
+            if i != num_ranks - 1:
+                b_per_feature.append(int(B / num_ranks))
+            else:
+                b_per_feature.append(B - sum(b_per_feature))
+        b_per_feature_per_rank.append(b_per_feature)
+    return b_per_feature_per_rank
+
+
+def generate_merged_output_and_offsets(
+    Ds: List[int],
+    Bs: List[int],
+    output_dtype: torch.dtype,
+    device: torch.device,
+    num_ranks: int = 2,
+    num_tbe_ops: int = 2,
+) -> Tuple[List[List[int]], torch.Tensor, torch.Tensor]:
+    """
+    Generate merged vbe_output and vbe_output_offsets tensors for VBE.
+    The vbe_output is a tensor that will contain forward output from all VBE TBE ops.
+    The vbe_output_offsets is a tensor that will contain start offsets for the output to be written to.
+
+    Args:
+        Ds (List[int]): embedding dimension per feature
+        Bs (List[int]): batch size per feature
+        num_ranks (int): number of ranks
+        num_tbe_ops (int): number of TBE ops
+    Returns:
+        Tuple[List[List[int]], torch.Tensor, torch.Tensor]: batch_size_per_feature_per_rank, merged vbe_output and vbe_output_offsets tensors
+    """
+    # The first embedding ops is the embedding op created in the benchmark
+    emb_op = {}
+    emb_op[0] = {}
+    emb_op[0]["dim"] = Ds
+    emb_op[0]["Bs"] = Bs
+    emb_op[0]["output_size"] = sum([b * d for b, d in zip(Bs, Ds)])
+    emb_op[0]["batch_size_per_feature_per_rank"] = (
+        generate_batch_size_per_feature_per_rank(Bs, num_ranks)
+    )
+    num_features = len(Bs)
+    # create other embedding ops to allocate output and offsets tensors
+    # Using representative values for additional TBE ops in multi-op scenarios:
+    # - batch_size=32000: typical large batch size for production workloads
+    # - dim=512: common embedding dimension for large models
+    for i in range(1, num_tbe_ops):
+        emb_op[i] = {}
+        emb_op[i]["batch_size_per_feature_per_rank"] = (
+            generate_batch_size_per_feature_per_rank([32000], num_ranks)
+        )
+        emb_op[i]["Bs"] = [sum(B) for B in emb_op[i]["batch_size_per_feature_per_rank"]]
+        emb_op[i]["dim"] = [512]
+        emb_op[i]["output_size"] = sum(
+            [b * d for b, d in zip(emb_op[i]["Bs"], emb_op[i]["dim"])]
+        )
+    total_output = 0
+    ranks = [[] for _ in range(num_ranks)]
+    for e in emb_op.values():
+        b_per_rank_per_feature = list(zip(*e["batch_size_per_feature_per_rank"]))
+        assert len(b_per_rank_per_feature) == num_ranks
+        dims = e["dim"]
+        for r, b_r in enumerate(b_per_rank_per_feature):
+            for f, b in enumerate(b_r):
+                output_size_per_batch = b * dims[f]
+                ranks[r].append(output_size_per_batch)
+                total_output += output_size_per_batch
+    ranks[0].insert(0, 0)
+    offsets_ranks: List[List[int]] = [[] for _ in range(num_ranks)]
+    total_output_offsets = []
+    start = 0
+    for r in range(num_ranks):
+        offsets_ranks[r] = [
+            start + sum(ranks[r][: i + 1]) for i in range(len(ranks[r]))
+        ]
+        start = offsets_ranks[r][-1]
+        total_output_offsets.extend(offsets_ranks[r])
+    check_total_output_size = sum([e["output_size"] for e in emb_op.values()])
+    assert (
+        total_output == check_total_output_size
+    ), f"{total_output} != {check_total_output_size}{[e['output_size'] for e in emb_op.values()]}"
+    assert (
+        total_output == total_output_offsets[-1]
+    ), f"{total_output} != {total_output_offsets[-1]}"
+    out = torch.empty(total_output, dtype=output_dtype, device=device)
+    offsets = []
+    offsets.append(offsets_ranks[0][:num_features])
+    for r in range(1, num_ranks):
+        start = [offsets_ranks[r - 1][-1]]
+        the_rest = offsets_ranks[r][: num_features - 1] if num_features > 1 else []
+        start.extend(the_rest)
+        offsets.append(start)
+
+    out_offsets = torch.tensor(
+        offsets,
+        dtype=torch.int64,
+        device=device,
+    )
+    batch_size_per_feature_per_rank = emb_op[0]["batch_size_per_feature_per_rank"]
+    return (batch_size_per_feature_per_rank, out, out_offsets)
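Of the helpers added to `utils.py` above, `generate_batch_size_per_feature_per_rank` splits each feature's batch size evenly across ranks and folds the remainder into the last rank, while `check_oom` compares a requested allocation against `torch.cuda.mem_get_info()`. A small worked example, assuming the module path `fbgemm_gpu.tbe.bench.utils` matches the file shown:

# Illustrative only: split per-feature batch sizes [10, 7] across 4 ranks.
from fbgemm_gpu.tbe.bench.utils import generate_batch_size_per_feature_per_rank

print(generate_batch_size_per_feature_per_rank([10, 7], num_ranks=4))
# -> [[2, 2, 2, 4], [1, 1, 1, 4]]  (remainder goes to the last rank)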