fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0

fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py
@@ -0,0 +1,273 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from __future__ import annotations
+
+import functools
+from typing import Optional, Union
+
+import torch
+
+_HANDLED_FUNCTIONS = {}
+
+
+def implements(torch_function):
+    def decorator(func):
+        functools.update_wrapper(func, torch_function)
+        _HANDLED_FUNCTIONS[torch_function] = func
+        return func
+
+    return decorator
+
+
+class PartiallyMaterializedTensor:
+    """
+    A tensor-like object that represents a partially materialized tensor in memory.
+
+    Caller can use `narrow()` to get a view of the backing storage,
+    or use `full_tensor()` to get the full tensor (this could OOM).
+    """
+
+    def __init__(self, wrapped, is_virtual: bool = False) -> None:
+        """
+        Ensure caller loads the module before creating this object.
+
+        ```
+        load_torch_module(
+            "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings"
+        )
+        ```
+
+        Args:
+
+            wrapped: torch.classes.fbgemm.KVTensorWrapper
+        """
+        self._wrapped = wrapped
+        self._is_virtual = is_virtual
+        self._requires_grad = False
+
+    @property
+    def wrapped(self):
+        """
+        Get the wrapped extension class for C++ interop.
+        """
+        return self._wrapped
+
+    @property
+    def is_virtual(self):
+        """
+        Indicate whether this PMT is a virtual tensor.
+        This indicator is needed by checkpoint and publish, which must know
+        whether the PMT backs a KV-ZCH table or a normal embedding table;
+        for KV-ZCH, checkpoint and publish call all-gather to recalculate
+        the correct metadata of the ShardedTensor.
+        """
+        return self._is_virtual
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        if func not in _HANDLED_FUNCTIONS:
+            return NotImplemented
+        return _HANDLED_FUNCTIONS[func](cls, *args, **kwargs)
+
+    @implements(torch.narrow)
+    def narrow(self, dim: int, start: int, length: int) -> torch.Tensor:
+        """
+        This loads a narrowed view of the backing storage.
+
+        Returns:
+            a torch tensor
+        """
+        return self._wrapped.narrow(dim, start, length)
+
+    def set_weights_and_ids(self, weights: torch.Tensor, ids: torch.Tensor) -> None:
+        self._wrapped.set_weights_and_ids(weights, ids)
+
+    def get_weights_by_ids(self, ids: torch.Tensor) -> torch.Tensor:
+        return self._wrapped.get_weights_by_ids(ids)
+
+    def __reduce__(self):
+        return (
+            PartiallyMaterializedTensor,
+            (self._wrapped,),
+        )
+
+    def full_tensor(self) -> torch.Tensor:
+        """
+        This loads the full tensor into memory (may OOM).
+
+        Returns:
+            a torch tensor
+        """
+        return self.narrow(0, 0, self.size(0))
+
+    @implements(torch.detach)
+    def detach(self) -> PartiallyMaterializedTensor:
+        self._requires_grad = False
+        return self
+
+    def to(self, *args, **kwargs) -> PartiallyMaterializedTensor:
+        return self
+
+    def is_floating_point(self):
+        # this class only deals with embedding vectors
+        return True
+
+    @implements(torch._has_compatible_shallow_copy_type)
+    def _has_compatible_shallow_copy_type(*args, **kwargs):
+        return False
+
+    def requires_grad_(self, requires_grad=True) -> PartiallyMaterializedTensor:
+        self._requires_grad = requires_grad
+        return self
+
+    @property
+    def requires_grad(self) -> bool:
+        return self._requires_grad
+
+    @property
+    def grad(self) -> Optional[torch.Tensor]:
+        return None
+
+    @property
+    def is_leaf(self) -> bool:
+        return True
+
+    @property
+    def shape(self) -> torch.Size:
+        """
+        Shape of the full tensor.
+        """
+        return torch.Size(self._wrapped.shape)
+
+    def size(self, dim: Optional[int] = None) -> Union[int, torch.Size]:
+        sz = self.shape
+        if dim is None:
+            return sz
+        if dim >= len(sz) or dim < 0:
+            raise IndexError(
+                f"Dimension out of range (expected to be in range [0, {len(sz)}), but got {dim})"
+            )
+        return sz[dim]
+
+    def is_contiguous(self):
+        return True
+
+    def is_pinned(self):
+        return False
+
+    @property
+    def dtype(self) -> torch.dtype:
+        if isinstance(self._wrapped, torch.Tensor):
+            return self._wrapped.dtype
+        mapping = {"c10::Half": "half"}
+        dtype_str: str = self._wrapped.dtype_str
+        dtype_str = mapping.get(dtype_str, dtype_str)
+
+        dtype = getattr(torch, dtype_str)
+        assert isinstance(dtype, torch.dtype)
+        return dtype
+
+    @property
+    def device(self) -> torch.device:
+        if isinstance(self._wrapped, torch.Tensor):
+            return self._wrapped.device
+        device_str: str = self._wrapped.device_str
+        device = torch.device(device_str)
+        assert isinstance(device, torch.device)
+        return device
+
+    @property
+    def layout(self) -> torch.layout:
+        if isinstance(self._wrapped, torch.Tensor):
+            return self._wrapped.layout
+        layout_str_mapping = {
+            "SparseCsr": "sparse_csr",
+            "Strided": "strided",
+            "SparseCsc": "sparse_csc",
+            "Jagged": "jagged",
+        }
+        layout_str: str = self._wrapped.layout_str
+        layout_str = layout_str_mapping[layout_str]
+        layout = getattr(torch, layout_str)
+        assert isinstance(layout, torch.layout)
+        return layout
+
+    @property
+    def __class__(self):
+        # this is a hack to avoid assertion error in torch.nn.Module.register_parameter()
+        return torch.nn.Parameter
+
+    @property
+    def grad_fn(self):
+        return None
+
+    def view(self, *args, **kwargs):
+        return self
+
+    def is_meta(*args, **kwargs):
+        return False
+
+    def copy_(self, src, non_blocking=False):
+        # noop
+        pass
+
+    def numel(self):
+        return torch.tensor(self.shape).prod().item()
+
+    def nelement(self):
+        return torch.tensor(self.shape).prod().item()
+
+    def element_size(self):
+        return torch.tensor([], dtype=self.dtype).element_size()
+
+    def __deepcopy__(self, memo):
+        # torch.classes.fbgemm.KVTensorWrapper doesn't support deepcopy
+        new_obj = PartiallyMaterializedTensor(self._wrapped)
+        memo[id(self)] = new_obj
+        return new_obj
+
+    def required_grad(self) -> bool:
+        return True
+
+    @property
+    def is_quantized(self) -> bool:
+        return False
+
+    @implements(torch.equal)
+    def __eq__(self, tensor1, tensor2, **kwargs):
+        if not isinstance(tensor2, PartiallyMaterializedTensor):
+            return False
+
+        return torch.equal(tensor1.full_tensor(), tensor2.full_tensor())
+
+    def get_kvtensor_serializable_metadata(self) -> list[str]:
+        return self._wrapped.get_kvtensor_serializable_metadata()
+
+    def __hash__(self):
+        return id(self)
+
+    @property
+    def is_mps(self):
+        return False
+
+    @property
+    def is_sparse(self):
+        return False
+
+    @implements(torch.isclose)
+    def isclose(self, tensor1, tensor2, rtol=1e-05, atol=1e-08, equal_nan=False):
+        return torch.isclose(
+            tensor1.full_tensor(),
+            tensor2.full_tensor(),
+            rtol=rtol,
+            atol=atol,
+            equal_nan=equal_nan,
+        )
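For orientation, here is a minimal sketch (not part of the diff) of how the `implements`/`__torch_function__` machinery above behaves. `FakeKVWrapper` is a hypothetical stand-in for `torch.classes.fbgemm.KVTensorWrapper`, which would normally require loading the ssd_split_table_batched_embeddings extension; the example assumes the wheel's `fbgemm_gpu` package imports cleanly.

```
import torch

from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
    PartiallyMaterializedTensor,
)


class FakeKVWrapper:
    """Hypothetical stand-in for torch.classes.fbgemm.KVTensorWrapper."""

    def __init__(self, t: torch.Tensor) -> None:
        self._t = t
        self.shape = list(t.shape)

    def narrow(self, dim: int, start: int, length: int) -> torch.Tensor:
        return self._t.narrow(dim, start, length)


weights = torch.randn(8, 4)
a = PartiallyMaterializedTensor(FakeKVWrapper(weights))
b = PartiallyMaterializedTensor(FakeKVWrapper(weights.clone()))

chunk = a.narrow(0, 0, 2)  # view of two rows; nothing else is materialized
full = a.full_tensor()     # narrow(0, 0, size(0)); may OOM for a real KV table

# torch.equal is registered via @implements, so __torch_function__ routes it
# to the handler, which compares the fully materialized tensors.
assert torch.equal(a, b)
```

Only functions registered via `implements` (here `torch.narrow`, `torch.detach`, `torch._has_compatible_shallow_copy_type`, `torch.equal`, and `torch.isclose`) dispatch through `__torch_function__`; every other torch function returns `NotImplemented`.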
fbgemm_gpu/tbe/stats/__init__.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from .bench_params_reporter import TBEBenchmarkParamsReporter  # noqa F401
fbgemm_gpu/tbe/stats/bench_params_reporter.py
@@ -0,0 +1,339 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import io
+import json
+import logging
+import os
+from typing import Optional
+
+import fbgemm_gpu  # noqa F401
+import torch  # usort:skip
+
+from fbgemm_gpu.tbe.bench.tbe_data_config import (
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+    TBEDataConfig,
+)
+
+open_source: bool = False
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+if open_source:
+    from fbgemm_gpu.utils import FileStore
+
+else:
+    try:
+        from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore
+
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_estimator"
+        )
+    except Exception:
+        pass
+
+
+class TBEBenchmarkParamsReporter:
+    """
+    TBEBenchmarkParamsReporter is responsible for extracting and reporting the configuration data of TBE processes.
+    """
+
+    def __init__(
+        self,
+        report_interval: int,
+        report_iter_start: int = 0,
+        report_iter_end: int = -1,
+        bucket: Optional[str] = None,
+        path_prefix: Optional[str] = None,
+    ) -> None:
+        """
+        Initializes the TBEBenchmarkParamsReporter with the specified parameters.
+
+        Args:
+            report_interval (int): The interval at which reports are generated.
+            report_iter_start (int): The start of the iteration range to capture. Defaults to 0.
+            report_iter_end (int): The end of the iteration range to capture. Defaults to -1 (last iteration).
+            bucket (Optional[str], optional): The storage bucket for reports. Defaults to None.
+            path_prefix (Optional[str], optional): The path prefix for report storage. Defaults to None.
+        """
+
+        assert report_interval > 0, "report_interval must be greater than 0"
+        assert (
+            report_iter_start >= 0
+        ), "report_iter_start must be greater than or equal to 0"
+        assert (
+            report_iter_end >= -1
+        ), "report_iter_end must be greater than or equal to -1"
+        assert (
+            report_iter_end == -1 or report_iter_start <= report_iter_end
+        ), "report_iter_start must be less than or equal to report_iter_end"
+
+        self.report_interval = report_interval
+        self.report_iter_start = report_iter_start
+        self.report_iter_end = report_iter_end
+
+        if path_prefix is not None and path_prefix.endswith("/"):
+            path_prefix = path_prefix[:-1]
+
+        self.path_prefix = path_prefix
+
+        default_bucket = "/tmp" if open_source else "tlparse_reports"
+        bucket = (
+            bucket
+            if bucket is not None
+            else os.environ.get("FBGEMM_TBE_REPORTING_BUCKET", default_bucket)
+        )
+        self.filestore = FileStore(bucket)
+
+        if self.path_prefix is not None and not self.filestore.exists(self.path_prefix):
+            self.filestore.create_directory(self.path_prefix)
+
+        self.logger: logging.Logger = logging.getLogger(__name__)
+        self.logger.setLevel(logging.INFO)
+
+    @classmethod
+    def create(cls) -> "TBEBenchmarkParamsReporter":
+        """
+        This method returns an instance of TBEBenchmarkParamsReporter based on environment variables.
+
+        If the `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` environment variable is set to a value greater than 0, it creates an instance that:
+        - Reports input parameters (TBEDataConfig).
+        - Writes the output as a JSON file.
+
+        Additionally, the following environment variables are considered:
+        - `FBGEMM_REPORT_INPUT_PARAMS_ITER_START`: Specifies the start of the iteration range to capture.
+        - `FBGEMM_REPORT_INPUT_PARAMS_ITER_END`: Specifies the end of the iteration range to capture.
+        - `FBGEMM_REPORT_INPUT_PARAMS_BUCKET`: Specifies the bucket for reporting.
+        - `FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX`: Specifies the path prefix for reporting.
+
+        Returns:
+            TBEBenchmarkParamsReporter: An instance configured based on the environment variables.
+        """
+        report_interval = int(
+            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "1")
+        )
+        report_iter_start = int(
+            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0")
+        )
+        report_iter_end = int(
+            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1")
+        )
+        bucket = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_BUCKET", "")
+        path_prefix = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX", "")
+
+        return cls(
+            report_interval=report_interval,
+            report_iter_start=report_iter_start,
+            report_iter_end=report_iter_end,
+            bucket=bucket,
+            path_prefix=path_prefix,
+        )
+
+    def extract_params(
+        self,
+        feature_rows: torch.Tensor,
+        feature_dims: torch.Tensor,
+        indices: torch.Tensor,
+        offsets: torch.Tensor,
+        per_sample_weights: Optional[torch.Tensor] = None,
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+    ) -> TBEDataConfig:
+        """
+        Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig.
+
+        Args:
+            feature_rows (torch.Tensor): Number of rows in each feature.
+            feature_dims (torch.Tensor): Number of dimensions in each feature.
+            indices (torch.Tensor): The input indices tensor.
+            offsets (torch.Tensor): The input offsets tensor.
+            per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
+            batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None.
+
+        Returns:
+            TBEDataConfig: The configuration data for TBE benchmarking.
+        """
+
+        Es = feature_rows.tolist()
+        Ds = feature_dims.tolist()
+
+        assert len(Es) == len(
+            Ds
+        ), "feature_rows and feature_dims must have the same length"
+
+        # Transfer indices back to CPU for EEG analysis
+        indices_cpu = indices.cpu()
+
+        # Set T to be the number of features we are looking at
+        T = len(Ds)
+        # Set E to be the mean of the rowcounts to avoid biasing
+        E = (
+            Es[0]
+            if len(set(Es)) == 1
+            else torch.ceil(
+                torch.mean(torch.tensor(feature_rows, dtype=torch.float))
+            ).item()
+        )
+        # Set mixed_dim to be True if there are multiple dims
+        mixed_dim = len(set(Ds)) > 1
+        # Set D to be the mean of the dims to avoid biasing
+        D = (
+            Ds[0]
+            if not mixed_dim
+            else torch.ceil(
+                torch.mean(torch.tensor(feature_dims, dtype=torch.float))
+            ).item()
+        )
+
+        # Compute indices distribution parameters
+        heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution(
+            indices_cpu
+        )
+        indices_params = IndicesParams(
+            heavy_hitters, q, s, indices.dtype, offsets.dtype
+        )
+
+        # Compute batch parameters
+        batch_params = BatchParams(
+            B=int((offsets.numel() - 1) // T),
+            sigma_B=(
+                int(
+                    torch.ceil(
+                        torch.std(
+                            torch.tensor(
+                                [
+                                    b
+                                    for bs in batch_size_per_feature_per_rank
+                                    for b in bs
+                                ]
+                            ).float()
+                        )
+                    )
+                )
+                if batch_size_per_feature_per_rank
+                else None
+            ),
+            vbe_distribution=("normal" if batch_size_per_feature_per_rank else None),
+            vbe_num_ranks=(
+                len(batch_size_per_feature_per_rank)
+                if batch_size_per_feature_per_rank
+                else None
+            ),
+        )
+
+        # Compute pooling parameters
+        bag_sizes = offsets[1:] - offsets[:-1]
+        mixed_bag_sizes = len(set(bag_sizes.tolist())) > 1
+        pooling_params = PoolingParams(
+            L=(
+                int(torch.ceil(torch.mean(bag_sizes.float())))
+                if mixed_bag_sizes
+                else int(bag_sizes[0])
+            ),
+            sigma_L=(
+                int(torch.ceil(torch.std(bag_sizes.float())))
+                if mixed_bag_sizes
+                else None
+            ),
+            length_distribution=("normal" if mixed_bag_sizes else None),
+        )
+
+        return TBEDataConfig(
+            T=T,
+            E=E,
+            D=D,
+            mixed_dim=mixed_dim,
+            weighted=(per_sample_weights is not None),
+            batch_params=batch_params,
+            indices_params=indices_params,
+            pooling_params=pooling_params,
+            use_cpu=(not torch.cuda.is_available()),
+        )
+
+    def report_stats(
+        self,
+        feature_rows: torch.Tensor,
+        feature_dims: torch.Tensor,
+        iteration: int,
+        indices: torch.Tensor,
+        offsets: torch.Tensor,
+        op_id: str = "",
+        per_sample_weights: Optional[torch.Tensor] = None,
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+    ) -> None:
+        """
+        Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore.
+
+        Args:
+            feature_rows (torch.Tensor): Number of rows in each feature.
+            feature_dims (torch.Tensor): Number of dimensions in each feature.
+            iteration (int): The current iteration number.
+            indices (torch.Tensor): The input indices tensor.
+            offsets (torch.Tensor): The input offsets tensor.
+            op_id (str, optional): The operation identifier. Defaults to an empty string.
+            per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
+            batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None.
+        """
+        if (
+            (iteration - self.report_iter_start) % self.report_interval == 0
+            and (iteration >= self.report_iter_start)
+            and (self.report_iter_end == -1 or iteration <= self.report_iter_end)
+        ):
+            # If indices tensor is empty (indices.numel() == 0), skip reporting
+            # TODO: Remove this once we have a better way to handle empty indices tensors
+            if indices.numel() == 0:
+                return
+
+            # Extract TBE config
+            config = self.extract_params(
+                feature_rows=feature_rows,
+                feature_dims=feature_dims,
+                indices=indices,
+                offsets=offsets,
+                per_sample_weights=per_sample_weights,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+            )
+
+            # Ad-hoc fix for adding Es and Ds to JSON output
+            # TODO: Remove this once we moved Es and Ds to be part of TBEDataConfig
+            adhoc_config = config.dict()
+            adhoc_config["Es"] = feature_rows.tolist()
+            adhoc_config["Ds"] = feature_dims.tolist()
+            if batch_size_per_feature_per_rank:
+                adhoc_config["Bs"] = [
+                    sum(batch_size_per_feature_per_rank[f])
+                    for f in range(len(adhoc_config["Es"]))
+                ]
+
+            bag_sizes = (offsets[1:] - offsets[:-1]).tolist()
+            adhoc_config["Ls"] = []
+            pointer_counter = 0
+            if batch_size_per_feature_per_rank:
+                for batch_size in adhoc_config["Bs"]:
+                    current_L = 0
+                    for _i in range(batch_size):
+                        current_L += bag_sizes[pointer_counter]
+                        pointer_counter += 1
+                    adhoc_config["Ls"].append(current_L / batch_size)
+            else:
+                batch_size = int(len(bag_sizes) // len(adhoc_config["Es"]))
+
+                for _j in range(len(adhoc_config["Es"])):
+                    current_L = 0
+                    for _i in range(batch_size):
+                        current_L += bag_sizes[pointer_counter]
+                        pointer_counter += 1
+                    adhoc_config["Ls"].append(current_L / batch_size)
+
+            # Write the TBE config to FileStore
+            self.filestore.write(
+                f"{self.path_prefix}/tbe-{op_id}-config-estimation-{iteration}.json",
+                io.BytesIO(json.dumps(adhoc_config, indent=2).encode()),
+            )
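A minimal usage sketch for the reporter above (not part of the diff), assuming the wheel's native ops are loadable — `extract_params` calls `torch.ops.fbgemm.tbe_estimate_indices_distribution` internally. The tensor values, bucket, and path prefix are illustrative only; `TBEBenchmarkParamsReporter.create()` would configure the same thing from the `FBGEMM_REPORT_INPUT_PARAMS_*` environment variables.

```
import torch

from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter

reporter = TBEBenchmarkParamsReporter(
    report_interval=2,        # report every other iteration
    report_iter_start=0,
    bucket="/tmp",            # open-source FileStore root (illustrative)
    path_prefix="tbe_reports",
)

# Two tables (T=2), batch size B=2, pooling factor L=2, so offsets has T*B+1 entries.
feature_rows = torch.tensor([100, 200])  # rows per table (Es)
feature_dims = torch.tensor([8, 16])     # dims per table (Ds)
indices = torch.tensor([3, 7, 1, 4, 9, 2, 5, 8])
offsets = torch.tensor([0, 2, 4, 6, 8])

# Iteration 0 passes the gating check and writes
# /tmp/tbe_reports/tbe--config-estimation-0.json (op_id defaults to "");
# iteration 1 is skipped because (1 - 0) % 2 != 0.
for it in range(2):
    reporter.report_stats(
        feature_rows=feature_rows,
        feature_dims=feature_dims,
        iteration=it,
        indices=indices,
        offsets=offsets,
    )
```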
fbgemm_gpu/tbe/utils/__init__.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from .common import get_device, round_up, to_device  # noqa: F401
+from .offsets import b_indices, get_table_batched_offsets_from_dense  # noqa: F401
+from .quantize import dequantize_embs, fake_quantize_embs, quantize_embs  # noqa: F401
+from .requests import generate_requests, TBERequest  # noqa: F401
fbgemm_gpu/tbe/utils/common.py
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import TypeVar
+
+import torch
+
+Deviceable = TypeVar(
+    "Deviceable", torch.nn.EmbeddingBag, torch.nn.Embedding, torch.Tensor
+)
+
+
+def round_up(a: int, b: int) -> int:
+    return int((a + b - 1) // b) * b
+
+
+def get_device() -> torch.device:
+    if torch.cuda.is_available():
+        # pyre-fixme[7]: Expected `device` but got `Union[int, device]`.
+        return torch.cuda.current_device()
+    elif torch.mtia.is_available():
+        # pyre-fixme[7]: Expected `device` but got `Union[int, device]`.
+        return torch.mtia.current_device()
+    else:
+        return torch.device("cpu")
+
+
+def to_device(t: Deviceable, use_cpu: bool) -> Deviceable:
+    if use_cpu:
+        # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor, torch.nn.EmbeddingBag]`.
+        return t.cpu()
+    elif torch.cuda.is_available():
+        # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor, torch.nn.EmbeddingBag]`.
+        return t.cuda()
+    else:
+        # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor, torch.nn.EmbeddingBag]`.
+        return t.to(device="mtia")
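A small sketch (not part of the diff) exercising the helpers above; the assertions restate the `round_up` arithmetic, and the `to_device` call is guarded so it stays on CPU when CUDA is absent rather than falling through to MTIA.

```
import torch

from fbgemm_gpu.tbe.utils import get_device, round_up, to_device

# round_up(a, b) aligns a to the next multiple of b: (a + b - 1) // b * b.
assert round_up(10, 4) == 12
assert round_up(12, 4) == 12

# to_device keeps objects on CPU when use_cpu=True; otherwise it prefers
# CUDA and falls back to MTIA (t.to(device="mtia")).
emb = torch.nn.EmbeddingBag(100, 8, mode="sum")
emb = to_device(emb, use_cpu=not torch.cuda.is_available())

# On CUDA hosts this is a device index (int), per the pyre-fixme notes above.
print(get_device())
```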