fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
import click
|
|
11
|
+
|
|
12
|
+
from fbgemm_gpu.split_embedding_configs import SparseType
|
|
13
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import BoundsCheckMode
|
|
14
|
+
|
|
15
|
+
from .bench_config import TBEBenchmarkingHelperText
|
|
16
|
+
from .tbe_data_config_loader import TBEDataConfigHelperText
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TbeBenchClickInterface:
    """Reusable groups of click options for TBE benchmark commands.

    Each public classmethod is a decorator: it attaches a fixed set of
    ``click.option`` parameters to the decorated command function and returns
    it, so the command receives every option in the group as a keyword argument.
    """

    @staticmethod
    # pyre-ignore [2]
    def _apply_options(func, options) -> click.Command:
        """Apply a list of click option decorators to ``func``.

        The list is applied in reverse so that the first option in the list
        ends up as the outermost decorator, i.e. the same result as writing
        the options top-to-bottom as stacked ``@click.option`` decorators.
        """
        for option in reversed(options):
            func = option(func)
        return func

    @classmethod
    # pyre-ignore [2]
    def common_options(cls, func) -> click.Command:
        """Attach general benchmark options.

        Adds ``--alpha``, ``--batch-size``, ``--weights-precision``, ``--stoc``,
        ``--iters``, ``--warmup-runs``, ``--reuse`` and
        ``--flush-gpu-cache-size-mb``.
        """
        options = [
            click.option(
                "--alpha",
                default=1.0,
                help="The alpha value used for the benchmark, default is 1.0. Recommended value: alpha=1.15 for training and alpha=1.09 for inference",
            ),
            click.option(
                "--batch-size",
                default=512,
                help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value + " Default is 512.",
            ),
            click.option(
                "--weights-precision",
                type=SparseType,
                default=SparseType.FP32,
                help="The precision type for weights, default is FP32.",
            ),
            click.option(
                "--stoc",
                is_flag=True,
                default=False,
                help="Flag to enable stochastic rounding, default is False.",
            ),
            click.option(
                "--iters",
                default=100,
                help=TBEBenchmarkingHelperText.BENCH_ITERATIONS.value
                + " Default is 100.",
            ),
            click.option(
                "--warmup-runs",
                default=0,
                help=(
                    TBEBenchmarkingHelperText.BENCH_WARMUP_ITERATIONS.value
                    + " Default is 0."
                ),
            ),
            click.option(  # Note: Original default for uvm benchmark is 0.1
                "--reuse",
                default=0.0,
                help="The inter-batch indices reuse rate for the benchmark, default is 0.0.",
            ),
            click.option(
                "--flush-gpu-cache-size-mb",
                default=0,
                help=TBEBenchmarkingHelperText.BENCH_FLUSH_GPU_CACHE_SIZE.value,
            ),
        ]
        return cls._apply_options(func, options)

    @classmethod
    # pyre-ignore [2]
    def table_options(cls, func) -> click.Command:
        """Attach options describing the embedding tables.

        Adds ``--bag-size``, ``--embedding-dim``, ``--mixed``,
        ``--num-embeddings``, ``--num-tables`` and ``--tables``.
        """
        options = [
            click.option(
                "--bag-size",
                default=20,
                help=TBEDataConfigHelperText.TBE_POOLING_SIZE.value + " Default is 20.",
            ),
            click.option(
                "--embedding-dim",
                default=128,
                help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value
                + " Default is 128.",
            ),
            click.option(
                "--mixed",
                is_flag=True,
                default=False,
                help=TBEDataConfigHelperText.TBE_MIXED_DIM.value + " Default is False.",
            ),
            click.option(
                "--num-embeddings",
                default=int(1e5),
                help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value
                + " Default is 1e5.",
            ),
            click.option(
                "--num-tables",
                default=32,
                help=TBEDataConfigHelperText.TBE_NUM_TABLES.value + " Default is 32.",
            ),
            click.option(
                "--tables",
                type=str,
                default=None,
                help="Comma-separated list of table numbers. Default is None.",
            ),
        ]
        return cls._apply_options(func, options)

    @classmethod
    # pyre-ignore [2]
    def device_options(cls, func) -> click.Command:
        """Attach options controlling placement, pooling and bounds checking.

        Adds ``--cache-precision``, ``--managed``, ``--row-wise/--no-row-wise``,
        ``--weighted``, ``--pooling`` and ``--bounds-check-mode``.
        """
        options = [
            click.option(
                "--cache-precision",
                type=SparseType,
                default=None,
                help="The precision type for cache, default is None.",
            ),
            click.option(
                "--managed",
                type=click.Choice(
                    ["device", "managed", "managed_caching"], case_sensitive=False
                ),
                default="device",
                help="The managed option for embedding location. Choices are 'device', 'managed', or 'managed_caching'. Default is 'device'.",
            ),
            click.option(
                "--row-wise/--no-row-wise",
                default=True,
                help="Flag to enable or disable row-wise optimization, default is enabled. Use --no-row-wise to disable.",
            ),
            click.option(
                "--weighted",
                is_flag=True,
                default=False,
                help=TBEDataConfigHelperText.TBE_WEIGHTED.value + " Default is False.",
            ),
            click.option(
                "--pooling",
                type=click.Choice(["sum", "mean", "none"], case_sensitive=False),
                default="sum",
                help="The pooling method to use. Choices are 'sum', 'mean', or 'none'. Default is 'sum'.",
            ),
            click.option(
                "--bounds-check-mode",
                # Raw integer; not range-checked here, the consumer converts it.
                type=int,
                default=BoundsCheckMode.NONE.value,
                help="The bounds check mode, default is NONE. Options are: FATAL (0) - Raise an exception (CPU) or device-side assert (CUDA), WARNING (1) - Log the first out-of-bounds instance per kernel, and set to zero, IGNORE (2) - Set to zero, NONE (3) - No bounds checks, V2_IGNORE (4) - IGNORE with V2 enabled, V2_WARNING (5) - WARNING with V2 enabled, V2_FATAL (6) - FATAL with V2 enabled.",
            ),
        ]
        return cls._apply_options(func, options)

    @classmethod
    # pyre-ignore [2]
    def vbe_options(cls, func) -> click.Command:
        """Attach variable-batch-size (VBE) bag-size options.

        Adds ``--bag-size-list`` and ``--bag-size-sigma-list``; both are raw
        comma-separated strings that the command is expected to parse.
        """
        options = [
            click.option(
                "--bag-size-list",
                type=str,
                default="20",
                help="A comma-separated list of bag sizes for each table, default is '20'.",
            ),
            click.option(
                "--bag-size-sigma-list",
                type=str,
                # NOTE: the default is the literal string "None", not Python None;
                # presumably the consumer checks for that string — confirm.
                default="None",
                help="A comma-separated list of bag size standard deviations for generating bag sizes (one std per table). If set, the benchmark will treat --bag-size-list as a list of bag size means. Default is 'None'.",
            ),
        ]
        return cls._apply_options(func, options)
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
import click
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from fbgemm_gpu.tbe.bench import IndicesParams
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@click.group()
def cli() -> None:
    # Root command group: `estimate` and `generate` attach themselves via
    # @cli.command(). Deliberately has no docstring, because click would
    # render it as the group's --help text.
    return None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@cli.command()
@click.option("--indices", required=True, help="Indices tensor file (*.pt)")
def estimate(indices: str) -> None:
    """
    Estimate the distribution of indices given a tensor file

    Parameters:
        indices (str): Indices tensor file (*.pt)

    Returns:
        None

    Example:
        estimate --indices="indices.pt"
    """
    # Keep the path (str) and the loaded tensor under separate names rather
    # than rebinding the `indices` parameter to a different type.
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all an indices file needs, and refuses arbitrary pickled code
    # from an untrusted file.
    indices_tensor = torch.load(indices, weights_only=True)
    heavy_hitters, q, s, max_index, num_indices = (
        torch.ops.fbgemm.tbe_estimate_indices_distribution(indices_tensor)
    )

    params = IndicesParams(
        heavy_hitters=heavy_hitters,
        zipf_q=q,
        zipf_s=s,
        index_dtype=indices_tensor.dtype,
    )

    # Print the JSON parameters and the scalar summary on separate lines
    # (previously they were space-joined onto the JSON's last line).
    print(params.json(format=True))
    print(f"max_index={max_index}\nnum_indices={num_indices}")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@cli.command()
@click.option(
    "--hitters",
    type=str,
    default="",
    help="TBE heavy hitter indices (comma-delimited list of floats)",
)
@click.option(
    "--zipf",
    type=(float, float),
    default=(0.1, 0.1),
    help="Zipf distribution parameters for indices generation (q, s)",
)
@click.option(
    "-e",
    "--max-index",
    type=int,
    default=20,
    help="Max index value (< E)",
)
@click.option(
    "-n",
    "--num-indices",
    type=int,
    default=20,
    help="Target number of indices to generate",
)
@click.option(
    "--output",
    type=str,
    required=True,
    help="Tensor filepath (*.pt) to save the generated indices",
)
def generate(
    hitters: str,
    zipf: tuple[float, float],
    max_index: int,
    num_indices: int,
    output: str,
) -> None:
    """
    Generates a tensor of indices given the indices distribution parameters

    Parameters:
        hitters (str): heavy hitter indices (comma-delimited list of floats)

        zipf (Tuple[float, float]): Zipf distribution parameters for indices generation (q, s)

        max_index (int): Max index value (E)

        num_indices (int): Target number of indices to generate

        output (str): Tensor filepath (*.pt) to save the generated indices

    Returns:
        None

    Example:
        generate --hitters="2,4,6" --zipf 1.1 1.1 --max-index=10 --num-indices=100 --output="generated_indices.pt"
    """

    def _require(condition: bool, message: str) -> None:
        # Raise explicitly instead of using `assert`, which is stripped under
        # `python -O`. AssertionError is kept so the error type is unchanged.
        if not condition:
            raise AssertionError(message)

    _require(max_index > 0, "Max index value (E) must be greater than 0")
    _require(num_indices > 0, "Target number of indices must be greater than 0")
    _require(zipf[0] > 0, "Zipf parameter q must be greater than 0.0")
    _require(zipf[1] > 0, "Zipf parameter s must be greater than 0.0")
    _require(output != "", "Output file path must be provided")

    try:
        _hitters: list[float] = (
            [float(x) for x in hitters.split(",")] if hitters else []
        )
    except ValueError as e:
        # float() is the only call here that can fail on user input.
        raise AssertionError(
            f'Error: {e}. Please ensure to use comma-delimited list of floats, e.g., --hitters="2,4,6". '
        ) from e

    heavy_hitters = torch.tensor(_hitters)
    _require(heavy_hitters.numel() <= 20, "The number of heavy hitters should be <= 20")

    indices = torch.ops.fbgemm.tbe_generate_indices_from_distribution(
        heavy_hitters, zipf[0], zipf[1], max_index, num_indices
    )

    print(f"Generated indices: {indices}")
    torch.save(indices, output)
    print(f"Saved indices to: {output}")
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
if __name__ == "__main__":
    # Script entry point: hand control to the click command group.
    cli()
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
import dataclasses
|
|
11
|
+
from typing import Any, Optional
|
|
12
|
+
|
|
13
|
+
import click
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
from fbgemm_gpu.split_embedding_configs import SparseType
|
|
17
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
|
18
|
+
BoundsCheckMode,
|
|
19
|
+
EmbeddingLocation,
|
|
20
|
+
PoolingMode,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclasses.dataclass(frozen=True)
class EmbeddingOpsCommonConfig:
    """Immutable bundle of embedding-op settings used by TBE benchmarks.

    Built from ``--emb-*`` command-line flags by
    ``EmbeddingOpsCommonConfigLoader.load`` in this module.
    """

    # Precision of the embedding weights
    weights_dtype: SparseType
    # Precision of the embedding cache
    cache_dtype: Optional[SparseType]
    # Precision of the embedding output
    output_dtype: SparseType
    # Enable stochastic rounding when performing quantization
    stochastic_rounding: bool
    # Pooling operation to perform
    pooling_mode: PoolingMode
    # Use host-mapped UVM buffers
    uvm_host_mapped: bool
    # Memory location of the embeddings
    embedding_location: EmbeddingLocation
    # Bounds check mode
    bounds_check_mode: BoundsCheckMode

    def validate(self) -> "EmbeddingOpsCommonConfig":
        """Return ``self`` unchanged.

        Performs no checks at present. Returning ``self`` lets it be chained
        directly after construction.
        """
        return self

    def split_args(self) -> dict[str, Any]:
        """Return a subset of this config as a keyword-argument dict.

        The keys (e.g. ``weights_precision``) look like split-TBE constructor
        argument names; presumably the dict is splatted into such a
        constructor — TODO confirm.

        NOTE(review): ``cache_dtype`` and ``embedding_location`` are not
        included; presumably callers pass those separately — confirm.
        """
        return {
            "weights_precision": self.weights_dtype,
            "stochastic_rounding": self.stochastic_rounding,
            "output_dtype": self.output_dtype,
            "pooling_mode": self.pooling_mode,
            "bounds_check_mode": self.bounds_check_mode,
            "uvm_host_mapped": self.uvm_host_mapped,
        }
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class EmbeddingOpsCommonConfigLoader:
    """Declares the ``--emb-*`` click options and turns them into an
    ``EmbeddingOpsCommonConfig``."""

    @classmethod
    # pyre-ignore [2]
    def options(cls, func) -> click.Command:
        """Attach the ``--emb-*`` options read by ``load`` to a click command."""
        options = [
            click.option(
                "--emb-weights-dtype",
                type=SparseType,
                default=SparseType.FP32,
                help="Precision of the embedding weights",
            ),
            click.option(
                "--emb-cache-dtype",
                type=SparseType,
                default=None,
                help="Precision of the embedding cache",
            ),
            click.option(
                "--emb-output-dtype",
                type=SparseType,
                default=SparseType.FP32,
                help="Precision of the embedding output",
            ),
            click.option(
                "--emb-stochastic-rounding",
                is_flag=True,
                default=False,
                help="Enable stochastic rounding when performing quantization",
            ),
            click.option(
                "--emb-pooling-mode",
                type=click.Choice(["sum", "mean", "none"], case_sensitive=False),
                default="sum",
                help="Pooling operation to perform",
            ),
            click.option(
                "--emb-uvm-host-mapped",
                is_flag=True,
                default=False,
                help="Use host-mapped UVM buffers",
            ),
            click.option(
                "--emb-location",
                default="device",
                type=click.Choice(EmbeddingLocation.str_values(), case_sensitive=False),
                help="Memory location of the embeddings",
            ),
            click.option(
                "--emb-bounds-check",
                type=int,
                default=BoundsCheckMode.WARNING.value,
                # Adjacent literals are concatenated implicitly, so the first
                # piece needs its own trailing separator.
                help="Bounds check mode. "
                f"Available modes: FATAL={BoundsCheckMode.FATAL.value}, "
                f"WARNING={BoundsCheckMode.WARNING.value}, "
                f"IGNORE={BoundsCheckMode.IGNORE.value}, "
                f"NONE={BoundsCheckMode.NONE.value}",
            ),
        ]

        # Reverse so the first listed option becomes the outermost decorator.
        for option in reversed(options):
            func = option(func)
        return func

    @classmethod
    def load(cls, context: click.Context) -> EmbeddingOpsCommonConfig:
        """Build an ``EmbeddingOpsCommonConfig`` from ``context.params``.

        Reads the parameters declared by ``options``. If the requested location
        is ``DEVICE`` but CUDA is not available, the location falls back to
        ``HOST``.
        """
        params = context.params

        weights_dtype = params["emb_weights_dtype"]
        cache_dtype = params["emb_cache_dtype"]
        output_dtype = params["emb_output_dtype"]
        stochastic_rounding = params["emb_stochastic_rounding"]
        pooling_mode = PoolingMode.from_str(str(params["emb_pooling_mode"]))
        uvm_host_mapped = params["emb_uvm_host_mapped"]
        bounds_check_mode = BoundsCheckMode(params["emb_bounds_check"])

        embedding_location = EmbeddingLocation.from_str(str(params["emb_location"]))
        if (
            embedding_location is EmbeddingLocation.DEVICE
            and not torch.cuda.is_available()
        ):
            # No GPU present: fall back to host memory so the run can proceed.
            embedding_location = EmbeddingLocation.HOST

        # Keyword arguments keep this robust to field reordering in the dataclass.
        return EmbeddingOpsCommonConfig(
            weights_dtype=weights_dtype,
            cache_dtype=cache_dtype,
            output_dtype=output_dtype,
            stochastic_rounding=stochastic_rounding,
            pooling_mode=pooling_mode,
            uvm_host_mapped=uvm_host_mapped,
            embedding_location=embedding_location,
            bounds_check_mode=bounds_check_mode,
        ).validate()
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import statistics
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from typing import Callable
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
# Configure the root logger at DEBUG level when this module is imported.
# NOTE(review): this is an import-time side effect on the process-wide root
# logger; consider moving it into the benchmark entry point.
logging.basicConfig(level="DEBUG")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class EvalCompressionBenchmarkOutput:
    """Timing results produced by ``benchmark_eval_compression``.

    Every field is a median over requests, in seconds (the benchmark scales
    CUDA event milliseconds by ``1.0e-3``). Despite their names, ``avg`` and
    ``compressed_avg`` are medians of per-request totals.
    """

    # Baseline: median of per-request forward + backward time.
    avg: float
    # Baseline: median forward time.
    fwd: float
    # Baseline: median backward time.
    bwd: float
    # Compressed: median of per-request forward + reindex + backward time.
    compressed_avg: float
    # Compressed: median forward time (before the reindex step).
    compressed_fwd: float
    # Compressed: median time of the reshape + index_select_dim0 step.
    reindex: float
    # Compressed: median backward time.
    compressed_bwd: float
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _elapsed_seconds(
    start_event: torch.cuda.Event, end_event: torch.cuda.Event
) -> float:
    """Synchronize the device, then return the time between two recorded events in seconds."""
    torch.cuda.synchronize()
    # Event.elapsed_time() reports milliseconds.
    return start_event.elapsed_time(end_event) * 1.0e-3


def benchmark_eval_compression(
    baseline_requests: list[tuple[torch.Tensor, torch.Tensor]],
    compressed_requests: list[tuple[torch.Tensor, torch.Tensor]],
    baseline_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    compressed_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    reindex: torch.Tensor,
    embedding_dim: int,
) -> EvalCompressionBenchmarkOutput:
    """Time a baseline lookup path against a compressed one using CUDA events.

    Baseline: for each ``(indices, offsets)`` in ``baseline_requests``, run
    ``baseline_func`` forward, then ``backward`` with a random gradient.

    Compressed: for each pair in ``compressed_requests``, run ``compressed_func``
    forward. Then reshape its output to ``(-1, embedding_dim)`` and pass it
    through ``torch.ops.fbgemm.index_select_dim0(out, reindex)``. Finally run
    ``backward`` with a random gradient.

    Each phase is timed separately; gradient creation is not timed. Both
    functions must return tensors that are part of an autograd graph, since
    ``backward`` is called on them. Requires CUDA.

    Returns:
        EvalCompressionBenchmarkOutput holding per-request medians in seconds.

    Raises:
        statistics.StatisticsError: if either request list is empty.
    """
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # --- Baseline path: forward + backward ---
    times: list[float] = []
    fwd_times: list[float] = []
    bwd_times: list[float] = []
    for indices, offsets in baseline_requests:
        start_event.record()
        out = baseline_func(indices, offsets)
        end_event.record()
        fwd_time = _elapsed_seconds(start_event, end_event)

        grad = torch.rand_like(out)
        start_event.record()
        out.backward(grad)
        end_event.record()
        bwd_time = _elapsed_seconds(start_event, end_event)

        fwd_times.append(fwd_time)
        bwd_times.append(bwd_time)
        times.append(fwd_time + bwd_time)

    # Medians, not means, even though the output field is named "avg".
    avg = statistics.median(times)
    fwd = statistics.median(fwd_times)
    bwd = statistics.median(bwd_times)

    # --- Compressed path: forward + reindex + backward ---
    torch.cuda.synchronize()
    compressed_times: list[float] = []
    compressed_fwd_times: list[float] = []
    reindex_times: list[float] = []
    compressed_bwd_times: list[float] = []
    for indices, offsets in compressed_requests:
        start_event.record()
        out = compressed_func(indices, offsets)
        end_event.record()
        fwd_time = _elapsed_seconds(start_event, end_event)

        start_event.record()
        out = out.reshape(-1, embedding_dim)
        out = torch.ops.fbgemm.index_select_dim0(out, reindex)
        end_event.record()
        reindex_time = _elapsed_seconds(start_event, end_event)

        grad = torch.rand_like(out)
        start_event.record()
        out.backward(grad)
        end_event.record()
        bwd_time = _elapsed_seconds(start_event, end_event)

        compressed_fwd_times.append(fwd_time)
        reindex_times.append(reindex_time)
        compressed_bwd_times.append(bwd_time)
        compressed_times.append(fwd_time + reindex_time + bwd_time)

    compressed_avg = statistics.median(compressed_times)
    compressed_fwd = statistics.median(compressed_fwd_times)
    # Separate name: the `reindex` parameter is a tensor and must not be
    # rebound to a float.
    reindex_median = statistics.median(reindex_times)
    compressed_bwd = statistics.median(compressed_bwd_times)

    return EvalCompressionBenchmarkOutput(
        avg=avg,
        fwd=fwd,
        bwd=bwd,
        compressed_avg=compressed_avg,
        compressed_fwd=compressed_fwd,
        reindex=reindex_median,
        compressed_bwd=compressed_bwd,
    )
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
try:
    # Optional dependency: metrics can only be emitted when the aibench
    # observer package imports successfully.
    from aibench_observer.utils.observer import emitMetric
except Exception:
    # Any import-time failure (not only ImportError) disables reporting.
    haveAIBench = False
else:
    haveAIBench = True
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class BenchmarkReporter:
    """Optionally forwards benchmark metrics to the aibench observer."""

    # When False, emit_metric() does nothing.
    report: bool
    # Logger that records the result of each emitted metric.
    # NOTE(review): the default is the process-wide root logger
    # (logging.getLogger() with no name), so __post_init__ changes the root
    # logger's level for the whole process.
    logger: logging.Logger = logging.getLogger()

    def __post_init__(self) -> None:
        # Set the logger threshold to INFO so the logger.info() call in
        # emit_metric() passes this logger's own level check.
        self.logger.setLevel(logging.INFO)

    # pyre-ignore[2]
    def emit_metric(self, **kwargs) -> None:
        """Call ``emitMetric(**kwargs)`` and log its return value at INFO level.

        Does nothing unless ``report`` is True and the aibench observer was
        importable. ``kwargs`` are passed through unchanged.
        """
        if self.report and haveAIBench:
            self.logger.info(emitMetric(**kwargs))
|