fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
# pyre-ignore-all-errors[56]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
from typing import Optional, Union
|
|
14
|
+
|
|
15
|
+
import torch # usort:skip
|
|
16
|
+
from torch import Tensor # usort:skip
|
|
17
|
+
from fbgemm_gpu.split_embedding_configs import SparseType
|
|
18
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
|
19
|
+
BoundsCheckMode,
|
|
20
|
+
CacheAlgorithm,
|
|
21
|
+
DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
|
|
22
|
+
EmbeddingLocation,
|
|
23
|
+
PoolingMode,
|
|
24
|
+
RecordCacheMetrics,
|
|
25
|
+
)
|
|
26
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
|
|
27
|
+
inputs_to_device,
|
|
28
|
+
IntNBitTableBatchedEmbeddingBagsCodegen,
|
|
29
|
+
random_quant_scaled_tensor,
|
|
30
|
+
rounded_row_size_in_bytes,
|
|
31
|
+
)
|
|
32
|
+
from fbgemm_gpu.utils.loader import load_torch_module
|
|
33
|
+
|
|
34
|
+
try:
    load_torch_module(
        "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference",
    )
except Exception as e:
    # Loading the operator library is best-effort: importing this module must
    # keep working even when the load fails. Record the failure at debug level so
    # that a later "unknown class DramKVEmbeddingInferenceWrapper" error raised in
    # KVEmbeddingInference.__init__ can be traced back to this point.
    import logging

    logging.getLogger(__name__).debug(
        "Failed to load //deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference: %s",
        e,
    )
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
    """
    KV Table-batched version of nn.EmbeddingBag(sparse=False)
    Inference version, with support for FP32/FP16/FP8/INT8/INT4/INT2 weights

    Embedding rows live in a key-value store
    (``torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper``). Once
    ``initialize_kv_embedding_cache()`` has run, ``forward()`` does three things:
    it linearizes the input indices into the store's global row-id space,
    fetches the referenced rows, and runs the regular int-nbit lookup over those
    fetched rows. Before that initialization, ``forward()`` uses the parent
    class's weight buffers.
    """

    def __init__(  # noqa C901
        self,
        embedding_specs: list[
            tuple[str, int, int, SparseType, EmbeddingLocation]
        ],  # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
        feature_table_map: Optional[list[int]] = None,  # [T]
        index_remapping: Optional[list[Tensor]] = None,
        pooling_mode: PoolingMode = PoolingMode.SUM,
        device: Optional[Union[str, int, torch.device]] = None,
        bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
        weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
        pruning_hash_load_factor: float = 0.5,
        use_array_for_index_remapping: bool = True,
        output_dtype: SparseType = SparseType.FP16,
        cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
        cache_load_factor: float = 0.2,
        cache_sets: int = 0,
        cache_reserved_memory: float = 0.0,
        enforce_hbm: bool = False,  # place all weights/momentums in HBM when using cache
        record_cache_metrics: Optional[RecordCacheMetrics] = None,
        gather_uvm_cache_stats: Optional[bool] = False,
        row_alignment: Optional[int] = None,
        fp8_exponent_bits: Optional[int] = None,
        fp8_exponent_bias: Optional[int] = None,
        cache_assoc: int = 32,
        scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
        cacheline_alignment: bool = True,
        uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
        reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparam at beginning of each row.
        feature_names_per_table: Optional[list[list[str]]] = None,
        indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
        embedding_cache_mode: bool = False,  # True for zero initialization, False for randomized initialization
    ) -> None:  # noqa C901
        """
        Build the module and create the (not yet initialized) KV store.

        Every argument except ``embedding_cache_mode`` is forwarded unchanged to
        ``IntNBitTableBatchedEmbeddingBagsCodegen.__init__``.
        ``embedding_cache_mode`` is passed to the KV store wrapper; per the
        original comments it disables random initialization of store rows.

        The store itself is only configured later, by
        ``initialize_kv_embedding_cache()``.
        """
        super().__init__(
            embedding_specs=embedding_specs,
            feature_table_map=feature_table_map,
            index_remapping=index_remapping,
            pooling_mode=pooling_mode,
            device=device,
            bounds_check_mode=bounds_check_mode,
            weight_lists=weight_lists,
            pruning_hash_load_factor=pruning_hash_load_factor,
            use_array_for_index_remapping=use_array_for_index_remapping,
            output_dtype=output_dtype,
            cache_algorithm=cache_algorithm,
            cache_load_factor=cache_load_factor,
            cache_sets=cache_sets,
            cache_reserved_memory=cache_reserved_memory,
            enforce_hbm=enforce_hbm,
            record_cache_metrics=record_cache_metrics,
            gather_uvm_cache_stats=gather_uvm_cache_stats,
            row_alignment=row_alignment,
            fp8_exponent_bits=fp8_exponent_bits,
            fp8_exponent_bias=fp8_exponent_bias,
            cache_assoc=cache_assoc,
            scale_bias_size_in_bytes=scale_bias_size_in_bytes,
            cacheline_alignment=cacheline_alignment,
            uvm_host_mapped=uvm_host_mapped,
            reverse_qparam=reverse_qparam,
            feature_names_per_table=feature_names_per_table,
            indices_dtype=indices_dtype,
        )
        self.register_buffer(
            "weights_ids",
            torch.tensor(0, device=self.current_device, dtype=torch.int64),
        )

        # Construction parameters of the KV store wrapper.
        num_shards = 32
        uniform_init_lower: float = -0.01
        uniform_init_upper: float = 0.01

        # pyre-fixme[4]: Attribute must be annotated.
        self.kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
            num_shards,
            uniform_init_lower,
            uniform_init_upper,
            embedding_cache_mode,  # in embedding_cache_mode, we disable random init
        )

        # Per-table (rows, dims, sparse type as int), handed to kv_embedding_cache.init().
        self.specs: list[tuple[int, int, int]] = [
            (rows, dims, sparse_type.as_int())
            for (_, rows, dims, sparse_type, _) in self.embedding_specs
        ]
        # table shard offset if inference sharding is enabled, otherwise, should be all zeros
        self.table_sharding_offset: list[int] = [0] * len(self.embedding_specs)
        self.kv_embedding_cache_initialized = False
        # Both cumsums are placeholders until initialize_kv_embedding_cache()
        # fills them in.
        self.hash_size_cumsum: torch.Tensor = torch.zeros(
            0,
            device=self.current_device,
            dtype=torch.int64,
        )
        self.feature_hash_size_cumsum: torch.Tensor = torch.zeros(
            0,
            device=self.current_device,
            dtype=torch.int64,
        )

    def construct_hash_size_cumsum(self) -> list[int]:
        """
        Return the exclusive prefix sum of table row counts.

        The result has ``len(self.embedding_specs) + 1`` entries:
        ``[0, rows_0, rows_0 + rows_1, ...]``. Entry ``i`` is the first global row
        id of table ``i``, and the last entry is the total number of rows.
        """
        hash_size_cumsum = [0]
        for spec in self.embedding_specs:
            rows = spec[1]
            hash_size_cumsum.append(hash_size_cumsum[-1] + rows)
        return hash_size_cumsum

    def calculate_indices_and_weights_offsets(
        self, indices: Tensor, offsets: Tensor
    ) -> tuple[Tensor, Tensor]:
        """
        Re-index a batch against the rows fetched from the KV store.

        This method assumes the fetched weight buffer holds one row per input
        index, laid out feature by feature in input order. For each feature ``t``
        it produces:

        * ``new_indices[start:end] = 0 .. end - start - 1``, the position of
          each row inside feature ``t``'s segment;
        * ``new_weights_offsets[t]``, the byte offset of that segment, i.e. the
          summed byte size of all earlier features' rows.

        Args:
            indices: the batch's indices; only their count and dtype are used.
            offsets: the batch's offsets. ``offsets.size(0) - 1`` is assumed to
                be ``T * B``.

        Returns:
            ``(new_indices, new_weights_offsets)`` on ``self.current_device``.
            ``new_indices`` has ``indices.dtype`` and ``new_weights_offsets``
            has ``self.weights_offsets.dtype``.
        """
        if self.pooling_mode is not PoolingMode.NONE:
            T = self.weights_offsets.numel()
        else:
            T = self.D_offsets.numel() - 1
        # Each feature owns B consecutive bags in `offsets`.
        B = (offsets.size(0) - 1) // T

        new_indices = torch.zeros(
            indices.size(0), device=self.current_device, dtype=indices.dtype
        )
        new_weights_offsets = torch.zeros(
            T, device=self.current_device, dtype=self.weights_offsets.dtype
        )
        total_bytes_added = 0
        for t in range(T):
            new_weights_offsets[t] = total_bytes_added
            start, end = int(offsets[t * B]), int(offsets[(t + 1) * B])
            index_size = end - start
            new_indices[start:end] = torch.arange(
                index_size, device=new_indices.device, dtype=new_indices.dtype
            )
            table_id = self.feature_table_map[t]
            total_bytes_added += index_size * rounded_row_size_in_bytes(
                self.embedding_specs[table_id][2],  # dim
                self.embedding_specs[table_id][3],  # weight_ty
                self.row_alignment,
                self.scale_bias_size_in_bytes,
            )
        return new_indices, new_weights_offsets

    def linearize_cache_indices(
        self,
        indices: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Linearize cache indices for KV cache.

        Each index of feature ``t`` is shifted by
        ``self.feature_hash_size_cumsum[t]``, the first global row id of the
        table that feature ``t`` maps to. This way, ids from different tables do
        not collide in the store.

        The result is always int64. ``indices`` is widened to int64 *before*
        the addition. Without that, an int32 input would keep the addition in
        int32, and ids would wrap once the cumulative row count exceeds
        2**31 - 1.

        Requires ``initialize_kv_embedding_cache()`` to have populated
        ``feature_hash_size_cumsum``.
        """
        linearized_indices = torch.zeros(
            indices.numel(),
            device=indices.device,
            dtype=torch.int64,
        )

        T = self.feature_hash_size_cumsum.numel() - 1
        B = (offsets.size(0) - 1) // T

        indices_int64 = indices.to(torch.int64)
        for t in range(T):
            start, end = int(offsets[t * B]), int(offsets[(t + 1) * B])
            linearized_indices[start:end] = (
                indices_int64[start:end] + self.feature_hash_size_cumsum[t]
            )

        return linearized_indices

    def forward(
        self,
        indices: Tensor,
        offsets: Tensor,
        per_sample_weights: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Run the int-nbit embedding lookup for one batch.

        When the KV store is initialized, the referenced rows are fetched from
        it, and the lookup runs over that compact buffer with re-derived
        indices and per-feature byte offsets. Otherwise the parent's device or
        host weights are used as-is.
        """
        assert (
            self.weight_initialized
        ), "weight needs to be initialized before forward function"

        indices, offsets, per_sample_weights = inputs_to_device(
            indices, offsets, per_sample_weights, self.bounds_check_warning
        )

        # Inherited attribute. NOTE(review): pop() raises on an empty list;
        # confirm it is always populated before forward().
        lxu_cache_locations = self.lxu_cache_locations_list.pop()

        weights_offsets = self.weights_offsets
        weights = self.weights_host if self.host_size > 0 else self.weights_dev

        if self.kv_embedding_cache_initialized:
            indices = self.linearize_cache_indices(
                indices,
                offsets,
            )

            weights = self.kv_embedding_cache.get_embeddings(indices)

            # Point the lookup at the fetched rows instead of the full tables.
            indices, weights_offsets = self.calculate_indices_and_weights_offsets(
                indices, offsets
            )

        return torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(
            dev_weights=weights,
            uvm_weights=self.weights_uvm,
            weights_placements=self.weights_placements,
            weights_offsets=weights_offsets,
            weights_tys=self.weights_tys,
            D_offsets=self.D_offsets,
            total_D=self.total_D,
            max_int2_D=self.max_int2_D,
            max_int4_D=self.max_int4_D,
            max_int8_D=self.max_int8_D,
            max_float16_D=self.max_float16_D,
            max_float32_D=self.max_float32_D,
            indices=indices,
            offsets=offsets,
            pooling_mode=int(self.pooling_mode),
            indice_weights=per_sample_weights,
            output_dtype=self.output_dtype,
            lxu_cache_weights=self.lxu_cache_weights,
            lxu_cache_locations=lxu_cache_locations,
            row_alignment=self.row_alignment,
            max_float8_D=self.max_float8_D,
            fp8_exponent_bits=self.fp8_exponent_bits,
            fp8_exponent_bias=self.fp8_exponent_bias,
        )

    def fill_random_weights(self) -> None:
        """
        Fill the buffer with random weights, table by table

        The method first initializes the KV store if needed. It then writes a
        random quantized row for every row id ``0 .. num_embeddings - 1`` of
        every table, and finally marks the module as weight-initialized.
        """
        self.initialize_kv_embedding_cache()
        for i, (_, num_embeddings, embedding_dim, weight_ty, _) in enumerate(
            self.embedding_specs
        ):
            # Row width in bytes. It uses the same alignment and scale/bias size
            # the store was initialized with, and the same values that
            # calculate_indices_and_weights_offsets() uses for its offsets.
            row_size_in_bytes = rounded_row_size_in_bytes(
                embedding_dim,
                weight_ty,
                self.row_alignment,
                self.scale_bias_size_in_bytes,
            )
            # arange(n) yields 0..n-1, the same as the deprecated
            # torch.range(0, n - 1). Unlike range, it also accepts n == 0.
            indices = torch.arange(num_embeddings, dtype=torch.int64)
            weights = random_quant_scaled_tensor(
                shape=torch.Size([num_embeddings, row_size_in_bytes]),
                device=self.current_device,
            )
            self.embedding_inplace_update_per_table(
                i,
                indices,
                weights,
            )
        self.weight_initialized = True

    @torch.jit.export
    def init_tbe_config(self, table_sharding_offset: list[int]) -> None:
        """
        Initialize the dynamic TBE table configs, e.g. sharded table offsets, etc.
        Should be called before loading weights.

        ``table_sharding_offset[i]`` is subtracted from incoming row ids of
        table ``i`` in ``embedding_inplace_update_per_table``. It is expected to
        have one entry per table.
        """
        self.table_sharding_offset = table_sharding_offset

    @torch.jit.export
    def embedding_inplace_update(
        self,
        update_table_indices: list[int],
        update_row_indices: list[list[int]],
        update_weights: list[Tensor],
    ) -> None:
        """
        Apply a batch of row updates.

        Entry ``i`` writes ``update_weights[i]`` to rows ``update_row_indices[i]``
        of table ``update_table_indices[i]``, with no update timestamp.
        """
        # function is not used for now on the inference side
        for i in range(len(update_table_indices)):
            self.embedding_inplace_update_per_table(
                update_table_indices[i],
                torch.tensor(
                    update_row_indices[i], device=self.current_device, dtype=torch.int64
                ),
                update_weights[i],
                None,
            )

    @torch.jit.export
    def embedding_inplace_update_per_table(
        self,
        table_id: int,
        update_row_indices: Tensor,
        update_weights: Tensor,
        inplace_update_ts_sec: Optional[int] = None,
    ) -> None:
        """
        Write ``update_weights`` into the KV store for rows of one table.

        Incoming row ids are converted to fused store ids as
        ``row + hash_size_cumsum[table_id] - table_sharding_offset[table_id]``.
        An empty update is a no-op.

        NOTE(review): ``hash_size_cumsum`` is only populated by
        ``initialize_kv_embedding_cache()``; calling this earlier fails on the
        index lookup below.
        """
        assert table_id >= 0 and table_id < len(
            self.embedding_specs
        ), f"table index {table_id} is out of range {len(self.embedding_specs)}"
        # pyre-ignore [29]
        table_offset = self.hash_size_cumsum[table_id]
        sharding_offset = self.table_sharding_offset[table_id]

        row_size = update_row_indices.numel()
        if row_size == 0:
            return

        # convert global weight index to fused local weight index
        row_indices = update_row_indices + table_offset - sharding_offset
        # set weight by id
        self.kv_embedding_cache.set_embeddings(
            row_indices, update_weights, inplace_update_ts_sec
        )

    @torch.jit.export
    def log_inplace_update_stats(
        self,
    ) -> None:
        """Delegate to the KV store's ``log_inplace_update_stats()``."""
        self.kv_embedding_cache.log_inplace_update_stats()

    @torch.jit.export
    def embedding_trigger_evict(
        self,
        inplace_update_ts_sec: int,
    ) -> None:
        """Delegate to the KV store's ``trigger_evict(inplace_update_ts_sec)``."""
        self.kv_embedding_cache.trigger_evict(inplace_update_ts_sec)

    @torch.jit.export
    def embedding_wait_evict_completion(
        self,
    ) -> None:
        """Delegate to the KV store's ``wait_evict_completion()``."""
        self.kv_embedding_cache.wait_evict_completion()

    @torch.jit.export
    def initialize_kv_embedding_cache(self) -> None:
        """
        One-time setup of the KV store; later calls are no-ops.

        The setup does four things:

        * sets ``row_alignment`` to 8;
        * builds ``hash_size_cumsum`` (per table) and
          ``feature_hash_size_cumsum`` (per feature, via ``feature_table_map``,
          with the total row count appended);
        * calls ``kv_embedding_cache.init``;
        * sets ``kv_embedding_cache_initialized``.
        """
        if not self.kv_embedding_cache_initialized:
            self.initialize_logical_weights_placements_and_offsets()

            self.row_alignment = 8  # in order to use mempool implementation for kv embedding it needs to be divisible by 8

            hash_size_cumsum = self.construct_hash_size_cumsum()
            self.hash_size_cumsum = torch.tensor(
                hash_size_cumsum,
                dtype=torch.int64,
                device=self.current_device,
            )

            self.feature_hash_size_cumsum = torch.tensor(
                [hash_size_cumsum[t] for t in self.feature_table_map]
                + [hash_size_cumsum[-1]],
                dtype=torch.int64,
                device=self.current_device,
            )

            self.kv_embedding_cache.init(
                self.specs,
                self.row_alignment,
                self.scale_bias_size_in_bytes,
                self.hash_size_cumsum,
            )
            self.kv_embedding_cache_initialized = True
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
|
|
9
|
+
from typing import Optional, Union
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_unique_indices_v2(
    linear_indices: torch.Tensor,
    max_indices: int,
    compute_count: bool = False,
    compute_inverse_indices: bool = False,
) -> Union[
    tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
    tuple[
        torch.Tensor,
        torch.Tensor,
        Optional[torch.Tensor],
    ],
    tuple[torch.Tensor, torch.Tensor],
]:
    """
    Call ``torch.ops.fbgemm.get_unique_indices_with_inverse`` and return only
    the outputs the caller asked for.

    Both flags are forwarded to the op, whose result is laid out as
    ``(unique_indices, length, count, inverse_indices)``. The returned tuple
    always starts with ``(unique_indices, length)``. It then appends ``count``
    when ``compute_count`` is set and ``inverse_indices`` when
    ``compute_inverse_indices`` is set, in that order.
    """
    result = torch.ops.fbgemm.get_unique_indices_with_inverse(
        linear_indices, max_indices, compute_count, compute_inverse_indices
    )
    if not compute_count:
        if compute_inverse_indices:
            # (unique_indices, length, inverse_indices)
            return result[0], result[1], result[3]
        # (unique_indices, length)
        return result[:-2]
    if not compute_inverse_indices:
        # (unique_indices, length, count)
        return result[:-1]
    # (unique_indices, length, count, inverse_indices)
    return result
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
# Load the prelude
|
|
11
|
+
from .common import ASSOC # noqa: F401
|
|
12
|
+
|
|
13
|
+
# Load the inference and training ops
|
|
14
|
+
from .inference import SSDIntNBitTableBatchedEmbeddingBags # noqa: F401
|
|
15
|
+
from .training import SSDTableBatchedEmbeddingBags # noqa: F401
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
# pyre-ignore-all-errors[56]
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from fbgemm_gpu.utils.loader import load_torch_module
|
|
14
|
+
|
|
15
|
+
try:
    load_torch_module(
        "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings"
    )
except Exception as e:
    # Loading the operator library is best-effort: importing this module must
    # keep working even when the load fails. Record the failure at debug level so
    # a later "missing operator" error can be traced back to this point.
    import logging

    logging.getLogger(__name__).debug(
        "Failed to load //deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings: %s",
        e,
    )

# NOTE(review): presumably the cache associativity (ways per set) used by the
# SSD TBE code; it is re-exported from the ssd package. Confirm before relying on it.
ASSOC = 32
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def pad4(value: int) -> int:
    """
    Compute the smallest multiple of 4 that is greater than or equal to the given value.

    Parameters:
        value (int): The integer to align (must be non-negative). The value is
            passed through ``int()`` first, so e.g. floats are truncated.

    Returns:
        int: The aligned value.

    Raises:
        ValueError: If the input is negative, or if ``int()`` rejects it
            (e.g. a non-numeric string).
        TypeError: If ``int()`` cannot convert the input at all (e.g. ``None``).
    """
    as_int = int(value)
    if as_int < 0:
        raise ValueError("pad4 expects a non-negative value, got " + str(value))
    # Adding 3 and clearing the two low bits rounds up to the next multiple of 4.
    return (as_int + 3) & ~3
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def tensor_pad4(value: torch.Tensor) -> torch.Tensor:
    """
    Element-wise ``pad4``: round every entry of ``value`` up to a multiple of 4.

    Relies on bitwise AND, so ``value`` is expected to hold integers.
    """
    bumped = value + 3
    low_bits_cleared = ~3
    return bumped & low_bits_cleared
|