fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
|
@@ -33,7 +33,7 @@ class PartiallyMaterializedTensor:
|
|
|
33
33
|
or use `full_tensor()` to get the full tensor (this could OOM).
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
|
-
def __init__(self, wrapped) -> None:
|
|
36
|
+
def __init__(self, wrapped, is_virtual: bool = False) -> None:
|
|
37
37
|
"""
|
|
38
38
|
Ensure caller loads the module before creating this object.
|
|
39
39
|
|
|
@@ -48,6 +48,7 @@ class PartiallyMaterializedTensor:
|
|
|
48
48
|
wrapped: torch.classes.fbgemm.KVTensorWrapper
|
|
49
49
|
"""
|
|
50
50
|
self._wrapped = wrapped
|
|
51
|
+
self._is_virtual = is_virtual
|
|
51
52
|
self._requires_grad = False
|
|
52
53
|
|
|
53
54
|
@property
|
|
@@ -57,6 +58,17 @@ class PartiallyMaterializedTensor:
|
|
|
57
58
|
"""
|
|
58
59
|
return self._wrapped
|
|
59
60
|
|
|
61
|
+
@property
def is_virtual(self) -> bool:
    """
    Indicate whether this PMT is a virtual tensor.

    This indicator is needed for checkpoint or publish: they need to know
    whether the PMT is for kvzch or for a normal embedding table. For kvzch,
    checkpoint and publish need to call all-gather to recalculate the correct
    metadata of the ShardedTensor.
    """
    return self._is_virtual
|
|
71
|
+
|
|
60
72
|
@classmethod
|
|
61
73
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
62
74
|
if kwargs is None:
|
|
@@ -75,6 +87,18 @@ class PartiallyMaterializedTensor:
|
|
|
75
87
|
"""
|
|
76
88
|
return self._wrapped.narrow(dim, start, length)
|
|
77
89
|
|
|
90
|
+
def set_weights_and_ids(self, weights: torch.Tensor, ids: torch.Tensor) -> None:
    """
    Forward ``weights`` and ``ids`` to the wrapped object's
    ``set_weights_and_ids`` method.
    """
    self._wrapped.set_weights_and_ids(weights, ids)
|
|
92
|
+
|
|
93
|
+
def get_weights_by_ids(self, ids: torch.Tensor) -> torch.Tensor:
    """
    Return whatever the wrapped object's ``get_weights_by_ids`` method
    returns for ``ids``.
    """
    return self._wrapped.get_weights_by_ids(ids)
|
|
95
|
+
|
|
96
|
+
def __reduce__(self):
    """
    Describe how to rebuild this object for pickling/copying.

    Returns the class together with the constructor arguments
    ``(wrapped, is_virtual)``. The virtual flag is included so that a
    restored object reports the same ``is_virtual`` value as the original
    instead of silently falling back to the constructor default (False).
    """
    return (
        PartiallyMaterializedTensor,
        (self._wrapped, self._is_virtual),
    )
|
|
101
|
+
|
|
78
102
|
def full_tensor(self) -> torch.Tensor:
|
|
79
103
|
"""
|
|
80
104
|
This loads the full tensor into memory (may OOM).
|
|
@@ -141,6 +165,8 @@ class PartiallyMaterializedTensor:
|
|
|
141
165
|
|
|
142
166
|
@property
|
|
143
167
|
def dtype(self) -> torch.dtype:
|
|
168
|
+
if isinstance(self._wrapped, torch.Tensor):
|
|
169
|
+
return self._wrapped.dtype
|
|
144
170
|
mapping = {"c10::Half": "half"}
|
|
145
171
|
dtype_str: str = self._wrapped.dtype_str
|
|
146
172
|
dtype_str = mapping.get(dtype_str, dtype_str)
|
|
@@ -151,6 +177,8 @@ class PartiallyMaterializedTensor:
|
|
|
151
177
|
|
|
152
178
|
@property
|
|
153
179
|
def device(self) -> torch.device:
|
|
180
|
+
if isinstance(self._wrapped, torch.Tensor):
|
|
181
|
+
return self._wrapped.device
|
|
154
182
|
device_str: str = self._wrapped.device_str
|
|
155
183
|
device = torch.device(device_str)
|
|
156
184
|
assert isinstance(device, torch.device)
|
|
@@ -158,11 +186,11 @@ class PartiallyMaterializedTensor:
|
|
|
158
186
|
|
|
159
187
|
@property
|
|
160
188
|
def layout(self) -> torch.layout:
|
|
161
|
-
|
|
189
|
+
if isinstance(self._wrapped, torch.Tensor):
|
|
190
|
+
return self._wrapped.layout
|
|
162
191
|
layout_str_mapping = {
|
|
163
192
|
"SparseCsr": "sparse_csr",
|
|
164
193
|
"Strided": "strided",
|
|
165
|
-
"SparseCsr": "sparse_csr",
|
|
166
194
|
"SparseCsc": "sparse_csc",
|
|
167
195
|
"Jagged": "jagged",
|
|
168
196
|
}
|
|
@@ -220,6 +248,9 @@ class PartiallyMaterializedTensor:
|
|
|
220
248
|
|
|
221
249
|
return torch.equal(tensor1.full_tensor(), tensor2.full_tensor())
|
|
222
250
|
|
|
251
|
+
def get_kvtensor_serializable_metadata(self) -> list[str]:
    """
    Return the result of the wrapped object's
    ``get_kvtensor_serializable_metadata`` method.
    """
    return self._wrapped.get_kvtensor_serializable_metadata()
|
|
253
|
+
|
|
223
254
|
def __hash__(self) -> int:
    # Hash by object identity: two distinct instances never share a hash
    # key on the basis of their contents.
    return id(self)
|
|
225
256
|
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
from .bench_params_reporter import TBEBenchmarkParamsReporter # noqa F401
|
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
import io
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
from typing import List, Optional, Tuple
|
|
15
|
+
|
|
16
|
+
import fbgemm_gpu # noqa F401
|
|
17
|
+
import torch # usort:skip
|
|
18
|
+
|
|
19
|
+
from fbgemm_gpu.tbe.bench.tbe_data_config import (
|
|
20
|
+
BatchParams,
|
|
21
|
+
IndicesParams,
|
|
22
|
+
PoolingParams,
|
|
23
|
+
TBEDataConfig,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
open_source: bool = False
|
|
27
|
+
# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
|
|
28
|
+
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
|
|
29
|
+
|
|
30
|
+
if open_source:
|
|
31
|
+
from fbgemm_gpu.utils import FileStore
|
|
32
|
+
|
|
33
|
+
else:
|
|
34
|
+
try:
|
|
35
|
+
from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore
|
|
36
|
+
|
|
37
|
+
torch.ops.load_library(
|
|
38
|
+
"//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_estimator"
|
|
39
|
+
)
|
|
40
|
+
except Exception:
|
|
41
|
+
pass
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TBEBenchmarkParamsReporter:
    """
    TBEBenchmarkParamsReporter is responsible for extracting and reporting the
    configuration data of TBE processes.

    The extracted configuration (a TBEDataConfig) is serialized as JSON and
    written through a FileStore.
    """

    def __init__(
        self,
        report_interval: int,
        report_iter_start: int = 0,
        report_iter_end: int = -1,
        bucket: Optional[str] = None,
        path_prefix: Optional[str] = None,
    ) -> None:
        """
        Initializes the TBEBenchmarkParamsReporter with the specified parameters.

        Args:
            report_interval (int): The interval at which reports are generated.
                Must be greater than 0.
            report_iter_start (int): The start of the iteration range to capture.
                Defaults to 0.
            report_iter_end (int): The end of the iteration range to capture.
                Defaults to -1 (last iteration).
            bucket (Optional[str], optional): The storage bucket for reports.
                When None, the `FBGEMM_TBE_REPORTING_BUCKET` environment
                variable is used, falling back to a built-in default.
            path_prefix (Optional[str], optional): The path prefix for report
                storage. One trailing "/" is stripped. Defaults to None.

        Raises:
            AssertionError: If the interval or iteration range is invalid.
        """

        assert report_interval > 0, "report_interval must be greater than 0"
        assert (
            report_iter_start >= 0
        ), "report_iter_start must be greater than or equal to 0"
        assert (
            report_iter_end >= -1
        ), "report_iter_end must be greater than or equal to -1"
        assert (
            report_iter_end == -1 or report_iter_start <= report_iter_end
        ), "report_iter_start must be less than or equal to report_iter_end"

        self.report_interval = report_interval
        self.report_iter_start = report_iter_start
        self.report_iter_end = report_iter_end

        if path_prefix is not None and path_prefix.endswith("/"):
            path_prefix = path_prefix[:-1]

        self.path_prefix = path_prefix

        default_bucket = "/tmp" if open_source else "tlparse_reports"
        bucket = (
            bucket
            if bucket is not None
            else os.environ.get("FBGEMM_TBE_REPORTING_BUCKET", default_bucket)
        )
        self.filestore = FileStore(bucket)

        if self.path_prefix is not None and not self.filestore.exists(self.path_prefix):
            self.filestore.create_directory(self.path_prefix)

        self.logger: logging.Logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

    @classmethod
    def create(cls) -> "TBEBenchmarkParamsReporter":
        """
        This method returns an instance of TBEBenchmarkParamsReporter based on
        environment variables.

        The following environment variables are considered:
        - `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`: The reporting interval
          (defaults to 1).
        - `FBGEMM_REPORT_INPUT_PARAMS_ITER_START`: Specifies the start of the
          iteration range to capture (defaults to 0).
        - `FBGEMM_REPORT_INPUT_PARAMS_ITER_END`: Specifies the end of the
          iteration range to capture (defaults to -1).
        - `FBGEMM_REPORT_INPUT_PARAMS_BUCKET`: Specifies the bucket for reporting.
        - `FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX`: Specifies the path prefix
          for reporting.

        Returns:
            TBEBenchmarkParamsReporter: An instance configured based on the
            environment variables.
        """
        report_interval = int(
            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "1")
        )
        report_iter_start = int(
            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0")
        )
        report_iter_end = int(
            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1")
        )
        # An unset or empty variable must reach __init__ as None; passing ""
        # would bypass the default bucket selection there and would make the
        # report path start with "/".
        bucket = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_BUCKET") or None
        path_prefix = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX") or None

        return cls(
            report_interval=report_interval,
            report_iter_start=report_iter_start,
            report_iter_end=report_iter_end,
            bucket=bucket,
            path_prefix=path_prefix,
        )

    @staticmethod
    def _has_mixed_values(values: torch.Tensor) -> bool:
        """Return True when `values` holds more than one distinct value."""
        # Compare by value: putting tensor elements into a Python set does not
        # deduplicate them, because tensors hash by identity.
        return torch.unique(values).numel() > 1

    @staticmethod
    def _ceil_std(values: torch.Tensor) -> int:
        """Return the standard deviation of `values`, rounded up to an int."""
        # torch.std of fewer than two samples is NaN, which cannot be
        # converted to int; report zero spread instead.
        if values.numel() < 2:
            return 0
        return int(torch.ceil(torch.std(values.float())))

    def _report_path(self, op_id: str, iteration: int) -> str:
        """Build the FileStore path of the report for `op_id` at `iteration`."""
        filename = f"tbe-{op_id}-config-estimation-{iteration}.json"
        # Without a prefix, write the bare file name rather than "None/..."
        # or a path with a leading "/".
        if not self.path_prefix:
            return filename
        return f"{self.path_prefix}/{filename}"

    def extract_Ls(
        self,
        bag_sizes: List[int],
        Bs: List[int],
    ) -> List[float]:
        """
        Computes the average bag size of each consecutive group of bags.

        Args:
            bag_sizes (List[int]): Flat list of bag sizes.
            Bs (List[int]): Number of consecutive entries of `bag_sizes` that
                belong to each group.

        Returns:
            List[float]: One average per entry of `Bs`; 0 for an empty group.
        """
        Ls: List[float] = []
        start = 0
        for b in Bs:
            end = start + b
            Ls.append(sum(bag_sizes[start:end]) / b if b > 0 else 0)
            start = end
        return Ls

    def extract_params(
        self,
        feature_rows: torch.Tensor,
        feature_dims: torch.Tensor,
        indices: torch.Tensor,
        offsets: torch.Tensor,
        per_sample_weights: Optional[torch.Tensor] = None,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
        Es: Optional[List[int]] = None,
        Ds: Optional[List[int]] = None,
        embedding_specs: Optional[List[Tuple[int, int]]] = None,
        feature_table_map: Optional[List[int]] = None,
    ) -> TBEDataConfig:
        """
        Extracts parameters from the embedding operation, input indices, and
        offsets to create a TBEDataConfig.

        Args:
            feature_rows (torch.Tensor): Number of rows in each feature.
            feature_dims (torch.Tensor): Number of dimensions in each feature.
            indices (torch.Tensor): The input indices tensor.
            offsets (torch.Tensor): The input offsets tensor.
            per_sample_weights (Optional[torch.Tensor], optional): Weights for
                each sample. Only its presence is recorded. Defaults to None.
            batch_size_per_feature_per_rank (Optional[List[List[int]]], optional):
                Batch sizes per feature per rank. Defaults to None.
            Es (Optional[List[int]], optional): Ignored; the value is always
                recomputed from `feature_rows`.
            Ds (Optional[List[int]], optional): Ignored; the value is always
                recomputed from `feature_dims`.
            embedding_specs (Optional[List[Tuple[int, int]]], optional): Passed
                through to the TBEDataConfig. Defaults to None.
            feature_table_map (Optional[List[int]], optional): Passed through
                to the TBEDataConfig. Defaults to None.

        Returns:
            TBEDataConfig: The configuration data for TBE benchmarking.
        """

        Es = feature_rows.tolist()
        Ds = feature_dims.tolist()

        assert len(Es) == len(
            Ds
        ), "feature_rows and feature_dims must have the same length"

        # Transfer indices back to CPU for EEG analysis
        indices_cpu = indices.cpu()

        # Set T to be the number of features we are looking at
        T = len(Ds)
        # Set E to be the mean of the rowcounts to avoid biasing
        E = (
            Es[0]
            if len(set(Es)) == 1
            else int(torch.ceil(torch.mean(torch.tensor(Es, dtype=torch.float))))
        )
        # Set mixed_dim to be True if there are multiple dims
        mixed_dim = len(set(Ds)) > 1
        # Set D to be the mean of the dims to avoid biasing
        D = (
            Ds[0]
            if not mixed_dim
            else int(torch.ceil(torch.mean(torch.tensor(Ds, dtype=torch.float))))
        )

        # Compute indices distribution parameters
        heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution(
            indices_cpu
        )
        indices_params = IndicesParams(
            heavy_hitters, q, s, indices.dtype, offsets.dtype
        )

        # Compute batch parameters
        B = int((offsets.numel() - 1) // T)
        if batch_size_per_feature_per_rank:
            Bs = [sum(b_per_rank) for b_per_rank in batch_size_per_feature_per_rank]
            sigma_B = self._ceil_std(
                torch.tensor(
                    [b for bs in batch_size_per_feature_per_rank for b in bs]
                ).float()
            )
            vbe_distribution = "normal"
            # NOTE(review): this is the length of the outer list, which the
            # parameter name suggests is indexed by feature rather than by
            # rank - confirm this is the intended value.
            vbe_num_ranks = len(batch_size_per_feature_per_rank)
        else:
            Bs = [B] * T
            sigma_B = None
            vbe_distribution = None
            vbe_num_ranks = None

        batch_params = BatchParams(
            B=B,
            sigma_B=sigma_B,
            vbe_distribution=vbe_distribution,
            vbe_num_ranks=vbe_num_ranks,
            Bs=Bs,
        )

        # Compute pooling parameters
        bag_sizes = offsets[1:] - offsets[:-1]
        if batch_size_per_feature_per_rank is None:
            _B = int(bag_sizes.numel() // T)
            assert _B == Bs[0], f"Expected constant batch size {Bs[0]} but got {_B}"
        mixed_bag_sizes = self._has_mixed_values(bag_sizes)
        pooling_params = PoolingParams(
            L=(
                int(torch.ceil(torch.mean(bag_sizes.float())))
                if mixed_bag_sizes
                else int(bag_sizes[0])
            ),
            sigma_L=(self._ceil_std(bag_sizes) if mixed_bag_sizes else None),
            length_distribution=("normal" if mixed_bag_sizes else None),
            Ls=self.extract_Ls(bag_sizes.tolist(), Bs),
        )

        return TBEDataConfig(
            T=T,
            E=E,
            D=D,
            mixed_dim=mixed_dim,
            weighted=(per_sample_weights is not None),
            batch_params=batch_params,
            indices_params=indices_params,
            pooling_params=pooling_params,
            use_cpu=(not torch.cuda.is_available()),
            Es=Es,
            Ds=Ds,
            embedding_specs=embedding_specs,
            feature_table_map=feature_table_map,
        )

    def report_stats(
        self,
        feature_rows: torch.Tensor,
        feature_dims: torch.Tensor,
        iteration: int,
        indices: torch.Tensor,
        offsets: torch.Tensor,
        op_id: str = "",
        per_sample_weights: Optional[torch.Tensor] = None,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
        embedding_specs: Optional[List[Tuple[int, int]]] = None,
        feature_table_map: Optional[List[int]] = None,
    ) -> None:
        """
        Reports the configuration of the embedding operation and input data,
        then writes the TBE configuration to the filestore.

        Nothing is written when `iteration` is outside the configured range,
        is not on the reporting interval, or when `indices` is empty.

        Args:
            feature_rows (torch.Tensor): Number of rows in each feature.
            feature_dims (torch.Tensor): Number of dimensions in each feature.
            iteration (int): The current iteration number.
            indices (torch.Tensor): The input indices tensor.
            offsets (torch.Tensor): The input offsets tensor.
            op_id (str, optional): The operation identifier. Defaults to an empty string.
            per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
            batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None.
            embedding_specs (Optional[List[Tuple[int, int]]]): Embedding specs. Defaults to None.
            feature_table_map (Optional[List[int]], optional): Feature table map. Defaults to None.
        """
        if iteration < self.report_iter_start:
            return
        if self.report_iter_end != -1 and iteration > self.report_iter_end:
            return
        if (iteration - self.report_iter_start) % self.report_interval != 0:
            return

        # If indices tensor is empty (indices.numel() == 0), skip reporting
        # TODO: Remove this once we have a better way to handle empty indices tensors
        if indices.numel() == 0:
            return

        # Extract TBE config
        config = self.extract_params(
            feature_rows=feature_rows,
            feature_dims=feature_dims,
            indices=indices,
            offsets=offsets,
            per_sample_weights=per_sample_weights,
            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
            Es=feature_rows.tolist(),
            Ds=feature_dims.tolist(),
            embedding_specs=embedding_specs,
            feature_table_map=feature_table_map,
        )

        # Write the TBE config to FileStore
        self.filestore.write(
            self._report_path(op_id, iteration),
            io.BytesIO(json.dumps(config.dict(), indent=2).encode()),
        )
|
fbgemm_gpu/tbe/utils/offsets.py
CHANGED
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
|
|
7
7
|
# pyre-strict
|
|
8
8
|
|
|
9
|
-
from typing import Callable, Optional
|
|
9
|
+
from typing import Callable, Optional
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import torch
|
|
@@ -21,9 +21,9 @@ def get_table_batched_offsets_from_dense(
|
|
|
21
21
|
L: Optional[int] = None,
|
|
22
22
|
total_B: Optional[int] = None,
|
|
23
23
|
use_cpu: bool = False,
|
|
24
|
-
) ->
|
|
24
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
25
25
|
if L is None and total_B is None:
|
|
26
|
-
|
|
26
|
+
T, B, L = merged_indices.size()
|
|
27
27
|
total_B = T * B
|
|
28
28
|
# pyre-fixme[6]: For 1st argument expected `Union[Sequence[SupportsIndex],
|
|
29
29
|
# SupportsIndex]` but got `Optional[int]`.
|
|
@@ -37,8 +37,8 @@ def get_table_batched_offsets_from_dense(
|
|
|
37
37
|
)
|
|
38
38
|
|
|
39
39
|
|
|
40
|
-
def get_offsets_from_dense(indices: torch.Tensor) ->
|
|
41
|
-
|
|
40
|
+
def get_offsets_from_dense(indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
41
|
+
B, L = indices.size()
|
|
42
42
|
return (
|
|
43
43
|
indices.contiguous().view(-1),
|
|
44
44
|
torch.tensor(
|
|
@@ -54,7 +54,7 @@ def b_indices(
|
|
|
54
54
|
use_cpu: bool = False,
|
|
55
55
|
do_pooling: bool = True,
|
|
56
56
|
) -> torch.Tensor:
|
|
57
|
-
|
|
57
|
+
indices, offsets = get_offsets_from_dense(x)
|
|
58
58
|
if do_pooling:
|
|
59
59
|
return b(
|
|
60
60
|
to_device(indices, use_cpu),
|
fbgemm_gpu/tbe/utils/quantize.py
CHANGED
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
# pyre-strict
|
|
8
8
|
# pyre-ignore-all-errors[61]
|
|
9
9
|
|
|
10
|
-
from typing import Optional
|
|
10
|
+
from typing import Optional
|
|
11
11
|
|
|
12
12
|
import torch
|
|
13
13
|
|
|
@@ -22,7 +22,7 @@ def quantize_embs(
|
|
|
22
22
|
weight: torch.Tensor,
|
|
23
23
|
weight_ty: SparseType,
|
|
24
24
|
fp8_config: Optional[FP8QuantizationConfig] = None,
|
|
25
|
-
) ->
|
|
25
|
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
26
26
|
weight = weight.detach()
|
|
27
27
|
if weight_ty == SparseType.FP32:
|
|
28
28
|
q_weight = weight.float()
|
|
@@ -91,7 +91,7 @@ def dequantize_embs(
|
|
|
91
91
|
th_scale_shift: torch.Tensor = scale_shift.view(torch.float16).to(torch.float32)
|
|
92
92
|
|
|
93
93
|
if weight_ty == SparseType.INT4:
|
|
94
|
-
|
|
94
|
+
E, D_2 = th_weights.shape
|
|
95
95
|
D = D_2 * 2
|
|
96
96
|
|
|
97
97
|
def comp(i: int) -> torch.Tensor:
|
|
@@ -109,7 +109,7 @@ def dequantize_embs(
|
|
|
109
109
|
return to_device(torch.tensor(comps), use_cpu)
|
|
110
110
|
|
|
111
111
|
elif weight_ty == SparseType.INT2:
|
|
112
|
-
|
|
112
|
+
E, D_4 = th_weights.shape
|
|
113
113
|
D = D_4 * 4
|
|
114
114
|
|
|
115
115
|
# pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
|
|
@@ -129,7 +129,7 @@ def dequantize_embs(
|
|
|
129
129
|
return to_device(torch.tensor(comps), use_cpu)
|
|
130
130
|
|
|
131
131
|
elif weight_ty == SparseType.INT8:
|
|
132
|
-
|
|
132
|
+
E, D = th_weights.shape
|
|
133
133
|
comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
|
|
134
134
|
torch.float32
|
|
135
135
|
) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
|
|
@@ -177,7 +177,7 @@ def fake_quantize_embs(
|
|
|
177
177
|
)
|
|
178
178
|
|
|
179
179
|
if weight_ty == SparseType.INT4:
|
|
180
|
-
|
|
180
|
+
E, D_2 = th_weights.shape
|
|
181
181
|
D = D_2 * 2
|
|
182
182
|
|
|
183
183
|
def comp(i: int) -> torch.Tensor:
|
|
@@ -195,7 +195,7 @@ def fake_quantize_embs(
|
|
|
195
195
|
dequant_weights.copy_(to_device(comps, use_cpu))
|
|
196
196
|
|
|
197
197
|
elif weight_ty == SparseType.INT2:
|
|
198
|
-
|
|
198
|
+
E, D_4 = th_weights.shape
|
|
199
199
|
D = D_4 * 4
|
|
200
200
|
|
|
201
201
|
# pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
|
|
@@ -215,7 +215,7 @@ def fake_quantize_embs(
|
|
|
215
215
|
dequant_weights.copy_(to_device(comps, use_cpu))
|
|
216
216
|
|
|
217
217
|
elif weight_ty == SparseType.INT8:
|
|
218
|
-
|
|
218
|
+
E, D = th_weights.shape
|
|
219
219
|
comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
|
|
220
220
|
torch.float32
|
|
221
221
|
) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
|