fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff shows the published contents of two package versions as they appear in their respective public registries, and is provided for informational purposes only.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/tbe/bench/benchmark_click_interface.py

@@ -0,0 +1,189 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import click
+
+# fmt:skip
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import BoundsCheckMode
+
+# fmt:skip
+from .bench_config import TBEBenchmarkingHelperText  # usort:skip
+from .tbe_data_config_loader import TBEDataConfigHelperText  # usort:skip
+
+
+class TbeBenchClickInterface:
+    @classmethod
+    # pyre-ignore [2]
+    def common_options(cls, func) -> click.Command:
+        options = [
+            click.option(
+                "--alpha",
+                default=1.0,
+                help="The alpha value used for the benchmark, default is 1.0. Recommended value: alpha=1.15 for training and alpha=1.09 for inference",
+            ),
+            click.option(
+                "--batch-size",
+                default=512,
+                help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value + " Default is 512.",
+            ),
+            click.option(
+                "--weights-precision",
+                type=SparseType,
+                default=SparseType.FP32,
+                help="The precision type for weights, default is FP32.",
+            ),
+            click.option(
+                "--stoc",
+                is_flag=True,
+                default=False,
+                help="Flag to enable stochastic rounding, default is False.",
+            ),
+            click.option(
+                "--iters",
+                default=100,
+                help=TBEBenchmarkingHelperText.BENCH_ITERATIONS.value
+                + " Default is 100.",
+            ),
+            click.option(
+                "--warmup-runs",
+                default=0,
+                help=(
+                    TBEBenchmarkingHelperText.BENCH_WARMUP_ITERATIONS.value
+                    + " Default is 0."
+                ),
+            ),
+            click.option(  # Note: Original default for uvm benchmark is 0.1
+                "--reuse",
+                default=0.0,
+                help="The inter-batch indices reuse rate for the benchmark, default is 0.0.",
+            ),
+            click.option(
+                "--flush-gpu-cache-size-mb",
+                default=0,
+                help=TBEBenchmarkingHelperText.BENCH_FLUSH_GPU_CACHE_SIZE.value,
+            ),
+        ]
+
+        for option in reversed(options):
+            func = option(func)
+        return func
+
+    @classmethod
+    # pyre-ignore [2]
+    def table_options(cls, func) -> click.Command:
+        options = [
+            click.option(
+                "--bag-size",
+                default=20,
+                help=TBEDataConfigHelperText.TBE_POOLING_SIZE.value + " Default is 20.",
+            ),
+            click.option(
+                "--embedding-dim",
+                default=128,
+                help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value
+                + " Default is 128.",
+            ),
+            click.option(
+                "--mixed",
+                is_flag=True,
+                default=False,
+                help=TBEDataConfigHelperText.TBE_MIXED_DIM.value + " Default is False.",
+            ),
+            click.option(
+                "--num-embeddings",
+                default=int(1e5),
+                help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value
+                + " Default is 1e5.",
+            ),
+            click.option(
+                "--num-tables",
+                default=32,
+                help=TBEDataConfigHelperText.TBE_NUM_TABLES.value + " Default is 32.",
+            ),
+            click.option(
+                "--tables",
+                type=str,
+                default=None,
+                help="Comma-separated list of table numbers. Default is None.",
+            ),
+        ]
+
+        for option in reversed(options):
+            func = option(func)
+        return func
+
+    @classmethod
+    # pyre-ignore [2]
+    def device_options(cls, func) -> click.Command:
+        options = [
+            click.option(
+                "--cache-precision",
+                type=SparseType,
+                default=None,
+                help="The precision type for cache, default is None.",
+            ),
+            click.option(
+                "--managed",
+                type=click.Choice(
+                    ["device", "managed", "managed_caching"], case_sensitive=False
+                ),
+                default="device",
+                help="The managed option for embedding location. Choices are 'device', 'managed', or 'managed_caching'. Default is 'device'.",
+            ),
+            click.option(
+                "--row-wise/--no-row-wise",
+                default=True,
+                help="Flag to enable or disable row-wise optimization, default is enabled. Use --no-row-wise to disable.",
+            ),
+            click.option(
+                "--weighted",
+                is_flag=True,
+                default=False,
+                help=TBEDataConfigHelperText.TBE_WEIGHTED.value + " Default is False.",
+            ),
+            click.option(
+                "--pooling",
+                type=click.Choice(["sum", "mean", "none"], case_sensitive=False),
+                default="sum",
+                help="The pooling method to use. Choices are 'sum', 'mean', or 'none'. Default is 'sum'.",
+            ),
+            click.option(
+                "--bounds-check-mode",
+                type=int,
+                default=BoundsCheckMode.NONE.value,
+                help="The bounds check mode, default is NONE. Options are: FATAL (0) - Raise an exception (CPU) or device-side assert (CUDA), WARNING (1) - Log the first out-of-bounds instance per kernel, and set to zero, IGNORE (2) - Set to zero, NONE (3) - No bounds checks, V2_IGNORE (4) - IGNORE with V2 enabled, V2_WARNING (5) - WARNING with V2 enabled, V2_FATAL (6) - FATAL with V2 enabled.",
+            ),
+        ]
+
+        for option in reversed(options):
+            func = option(func)
+        return func
+
+    @classmethod
+    # pyre-ignore [2]
+    def vbe_options(cls, func) -> click.Command:
+        options = [
+            click.option(
+                "--bag-size-list",
+                type=str,
+                default="20",
+                help="A comma-separated list of bag sizes for each table, default is '20'.",
+            ),
+            click.option(
+                "--bag-size-sigma-list",
+                type=str,
+                default="None",
+                help="A comma-separated list of bag size standard deviations for generating bag sizes (one std per table). If set, the benchmark will treat --bag-size-list as a list of bag size means. Default is 'None'.",
+            ),
+        ]
+
+        for option in reversed(options):
+            func = option(func)
+        return func
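Each mixin method wraps `func` with its whole option list; iterating in `reversed` order compensates for decorators composing bottom-up, so the options show up in `--help` in their declared order. A minimal sketch of stacking these mixins onto a command, assuming the module is importable from the wheel as `fbgemm_gpu.tbe.bench.benchmark_click_interface`; the `device_bench` command itself is illustrative, not something shipped in the package:

```python
import click

from fbgemm_gpu.tbe.bench.benchmark_click_interface import TbeBenchClickInterface


@click.command()
@TbeBenchClickInterface.common_options
@TbeBenchClickInterface.table_options
def device_bench(**kwargs) -> None:
    # All mixin options arrive as keyword arguments after click's name
    # normalization, e.g. --batch-size -> kwargs["batch_size"].
    print(f"B={kwargs['batch_size']} D={kwargs['embedding_dim']} alpha={kwargs['alpha']}")


if __name__ == "__main__":
    device_bench()
```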
fbgemm_gpu/tbe/bench/eeg_cli.py

@@ -0,0 +1,138 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+import click
+import torch
+
+# fmt:skip
+from fbgemm_gpu.tbe.bench import IndicesParams
+
+
+@click.group()
+def cli() -> None:
+    pass
+
+
+@cli.command()
+@click.option("--indices", required=True, help="Indices tensor file (*.pt)")
+def estimate(indices: str) -> None:
+    """
+    Estimate the distribution of indices given a tensor file
+
+    Parameters:
+        indices (str): Indices tensor file (*.pt)
+
+    Returns:
+        None
+
+    Example:
+        estimate --indices="indices.pt"
+    """
+
+    indices = torch.load(indices)
+    heavy_hitters, q, s, max_index, num_indices = (
+        torch.ops.fbgemm.tbe_estimate_indices_distribution(indices)
+    )
+
+    params = IndicesParams(
+        heavy_hitters=heavy_hitters, zipf_q=q, zipf_s=s, index_dtype=indices.dtype
+    )
+
+    print(params.json(format=True), f"max_index={max_index}\nnum_indices={num_indices}")
+
+
+@cli.command()
+@click.option(
+    "--hitters",
+    type=str,
+    default="",
+    help="TBE heavy hitter indices (comma-delimited list of floats)",
+)
+@click.option(
+    "--zipf",
+    type=(float, float),
+    default=(0.1, 0.1),
+    help="Zipf distribution parameters for indices generation (q, s)",
+)
+@click.option(
+    "-e",
+    "--max-index",
+    type=int,
+    default=20,
+    help="Max index value (< E)",
+)
+@click.option(
+    "-n",
+    "--num-indices",
+    type=int,
+    default=20,
+    help="Target number of indices to generate",
+)
+@click.option(
+    "--output",
+    type=str,
+    required=True,
+    help="Tensor filepath (*.pt) to save the generated indices",
+)
+def generate(
+    hitters: str,
+    zipf: tuple[float, float],
+    max_index: int,
+    num_indices: int,
+    output: str,
+) -> None:
+    """
+    Generates a tensor of indices given the indices distribution parameters
+
+    Parameters:
+        hitters (str): heavy hitter indices (comma-delimited list of floats)
+
+        zipf (Tuple[float, float]): Zipf distribution parameters for indices generation (q, s)
+
+        max_index (int): Max index value (E)
+
+        num_indices (int): Target number of indices to generate
+
+        output (str): Tensor filepath (*.pt) to save the generated indices
+
+    Returns:
+        None
+
+    Example:
+        generate --hitters="2,4,6" --zipf="1.1,1.1" --max-index=10 --num-indices=100 --output="generated_indices.pt"
+    """
+    assert max_index > 0, "Max index value (E) must be greater than 0"
+    assert num_indices > 0, "Target number of indices must be greater than 0"
+    assert zipf[0] > 0, "Zipf parameter q must be greater than 0.0"
+    assert zipf[1] > 0, "Zipf parameter s must be greater than 0.0"
+    assert output != "", "Output file path must be provided"
+
+    try:
+        _hitters: list[float] = (
+            [float(x) for x in hitters.split(",")] if hitters else []
+        )
+    except Exception as e:
+        raise AssertionError(
+            f'Error: {e}. Please ensure to use comma-delimited list of floats, e.g., --hitters="2,4,6". '
+        )
+
+    heavy_hitters = torch.tensor(_hitters)
+    assert heavy_hitters.numel() <= 20, "The number of heavy hitters should be <= 20"
+
+    indices = torch.ops.fbgemm.tbe_generate_indices_from_distribution(
+        heavy_hitters, zipf[0], zipf[1], max_index, num_indices
+    )
+
+    print(f"Generated indices: {indices}")
+    torch.save(indices, output)
+    print(f"Saved indices to: {output}")
+
+
+if __name__ == "__main__":
+    cli()
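The two commands round-trip: `generate` produces indices from distribution parameters, and `estimate` recovers those parameters from a saved tensor. A minimal sketch of driving the CLI programmatically with click's `CliRunner`, assuming fbgemm_gpu is installed so the `torch.ops.fbgemm` custom ops both commands rely on are registered; note that `--zipf` is declared as `type=(float, float)`, so on the command line it takes two separate values:

```python
from click.testing import CliRunner

from fbgemm_gpu.tbe.bench.eeg_cli import cli

runner = CliRunner()

# Generate 100 indices in [0, 10) with heavy hitters 2, 4, 6 and Zipf(q=1.1, s=1.1).
result = runner.invoke(
    cli,
    [
        "generate",
        "--hitters", "2,4,6",
        "--zipf", "1.1", "1.1",  # tuple option: two values, not one comma string
        "--max-index", "10",
        "--num-indices", "100",
        "--output", "generated_indices.pt",
    ],
)
print(result.output)

# Estimate the distribution parameters back from the saved tensor.
print(runner.invoke(cli, ["estimate", "--indices", "generated_indices.pt"]).output)
```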
fbgemm_gpu/tbe/bench/embedding_ops_common_config.py

@@ -8,11 +8,12 @@
 # pyre-strict
 
 import dataclasses
-from typing import Any,
+from typing import Any, Optional
 
 import click
 import torch
 
+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
@@ -44,7 +45,7 @@ class EmbeddingOpsCommonConfig:
     def validate(self):
         return self
 
-    def split_args(self) ->
+    def split_args(self) -> dict[str, Any]:
         return {
             "weights_precision": self.weights_dtype,
             "stochastic_rounding": self.stochastic_rounding,
@@ -99,9 +100,7 @@ class EmbeddingOpsCommonConfigLoader:
         click.option(
             "--emb-location",
            default="device",
-            type=click.Choice(
-                ["device", "managed", "managed_caching"], case_sensitive=False
-            ),
+            type=click.Choice(EmbeddingLocation.str_values(), case_sensitive=False),
            help="Memory location of the embeddings",
         ),
         click.option(
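The last hunk replaces a hard-coded choice list with values derived from the `EmbeddingLocation` enum, so new locations automatically become valid CLI choices. A minimal sketch of the pattern, assuming `str_values()` returns lower-cased member names; the `Location` enum below is illustrative, not fbgemm_gpu's actual `EmbeddingLocation` definition:

```python
import enum

import click


class Location(enum.Enum):
    # Illustrative stand-in for fbgemm_gpu's EmbeddingLocation.
    DEVICE = 0
    MANAGED = 1
    MANAGED_CACHING = 2

    @classmethod
    def str_values(cls) -> list[str]:
        # e.g. ["device", "managed", "managed_caching"]
        return [member.name.lower() for member in cls]


@click.command()
@click.option(
    "--emb-location",
    default="device",
    type=click.Choice(Location.str_values(), case_sensitive=False),
    help="Memory location of the embeddings",
)
def main(emb_location: str) -> None:
    # Map the validated CLI string back to the enum member.
    print(Location[emb_location.upper()])


if __name__ == "__main__":
    main()
```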
@@ -10,7 +10,7 @@
|
|
|
10
10
|
import logging
|
|
11
11
|
import statistics
|
|
12
12
|
from dataclasses import dataclass
|
|
13
|
-
from typing import Callable
|
|
13
|
+
from typing import Callable
|
|
14
14
|
|
|
15
15
|
import torch
|
|
16
16
|
|
|
@@ -29,8 +29,8 @@ class EvalCompressionBenchmarkOutput:
|
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
def benchmark_eval_compression(
|
|
32
|
-
baseline_requests:
|
|
33
|
-
compressed_requests:
|
|
32
|
+
baseline_requests: list[tuple[torch.Tensor, torch.Tensor]],
|
|
33
|
+
compressed_requests: list[tuple[torch.Tensor, torch.Tensor]],
|
|
34
34
|
baseline_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
|
35
35
|
compressed_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
|
36
36
|
reindex: torch.Tensor,
|