fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
# pyre-ignore-all-errors[56]
|
|
10
|
+
|
|
11
|
+
import itertools
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import tempfile
|
|
15
|
+
from math import log2
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
import torch # usort:skip
|
|
19
|
+
|
|
20
|
+
from fbgemm_gpu.split_embedding_configs import SparseType
|
|
21
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
|
22
|
+
CacheAlgorithm,
|
|
23
|
+
DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
|
|
24
|
+
EmbeddingLocation,
|
|
25
|
+
PoolingMode,
|
|
26
|
+
)
|
|
27
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
|
|
28
|
+
align_to_cacheline,
|
|
29
|
+
rounded_row_size_in_bytes,
|
|
30
|
+
unpadded_row_size_in_bytes,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
from torch import distributed as dist, nn, Tensor # usort:skip
|
|
34
|
+
from torch.autograd.profiler import record_function
|
|
35
|
+
|
|
36
|
+
from .common import ASSOC
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
    """
    SSD Table-batched version of nn.EmbeddingBag(sparse=False)
    Inference version, with FP32/FP16/FP8/INT8/INT4/INT2 supports

    Embedding rows live in a backing store (``self.ssd_db``): an
    ``EmbeddingRocksDBWrapper`` by default, or an
    ``EmbeddingParameterServerWrapper`` when ``ps_hosts`` is given. A
    set-associative cache (``lxu_cache_state`` / ``lru_state`` /
    ``lxu_cache_weights``) is filled by :meth:`prefetch` and read by
    :meth:`forward` through
    ``torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function``.
    """

    embedding_specs: list[tuple[str, int, int, SparseType]]
    # Class-wide counter of instances that derived their own tbe_unique_id
    # (only advanced on the parameter-server path when tbe_unique_id == -1).
    _local_instance_index: int = -1

    def __init__(
        self,
        embedding_specs: list[
            tuple[str, int, int, SparseType]
        ],  # tuple of (feature_names, rows, dims, SparseType)
        feature_table_map: Optional[list[int]] = None,  # [T]
        pooling_mode: PoolingMode = PoolingMode.SUM,
        output_dtype: SparseType = SparseType.FP16,
        row_alignment: Optional[int] = None,
        fp8_exponent_bits: Optional[int] = None,
        fp8_exponent_bias: Optional[int] = None,
        cache_assoc: int = 32,
        scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
        cache_sets: int = 0,
        ssd_storage_directory: str = "/tmp",
        ssd_shards: int = 1,
        ssd_memtable_flush_period: int = -1,
        ssd_memtable_flush_offset: int = -1,
        ssd_l0_files_per_compact: int = 4,
        ssd_rate_limit_mbps: int = 0,
        ssd_size_ratio: int = 10,
        ssd_compaction_trigger: int = 8,
        ssd_write_buffer_size: int = 2 * 1024 * 1024 * 1024,
        ssd_max_write_buffer_num: int = 16,
        ssd_cache_location: EmbeddingLocation = EmbeddingLocation.MANAGED,
        ssd_uniform_init_lower: float = -0.01,
        ssd_uniform_init_upper: float = 0.01,
        # Parameter Server Configs
        ps_hosts: Optional[tuple[tuple[str, int]]] = None,
        ps_max_key_per_request: Optional[int] = None,
        ps_client_thread_num: Optional[int] = None,
        ps_max_local_index_length: Optional[int] = None,
        tbe_unique_id: int = -1,  # unique id for this embedding, if not set, will derive based on current rank and tbe index id
    ) -> None:  # noqa C901
        """
        Build the metadata buffers, the cache and the backing store.

        Args:
            embedding_specs: One ``(feature_names, rows, dims, SparseType)``
                tuple per table. Must be non-empty.
            feature_table_map: Table index for each feature. Defaults to one
                feature per table. Every table must be referenced at least
                once.
            pooling_mode: Passed to the lookup op as ``int(pooling_mode)``.
            output_dtype: Converted with ``as_int()`` and passed to the
                lookup op.
            row_alignment: Alignment handed to ``rounded_row_size_in_bytes``
                and to the lookup op. Defaults to 16 (1 if the resolved
                device were a CPU device).
            fp8_exponent_bits: Overrides the ``exponent_bits`` entry of
                ``SparseType.FP8.default_config()``. Only consulted when at
                least one table is FP8.
            fp8_exponent_bias: Same as above for ``exponent_bias``.
            cache_assoc: Must be 32.
            scale_bias_size_in_bytes: Forwarded to the row-size helpers and
                used to split quantized rows in
                :meth:`split_embedding_weights`.
            cache_sets: Number of cache sets. Must be > 0, so the default of
                0 has to be overridden by the caller.
            ssd_storage_directory: Directory (created if missing) in which a
                fresh temporary directory is made for the backing store.
            ssd_shards: Passed twice to ``EmbeddingRocksDBWrapper`` (second
                and third positional argument).
            ssd_memtable_flush_period, ssd_memtable_flush_offset,
            ssd_l0_files_per_compact, ssd_rate_limit_mbps, ssd_size_ratio,
            ssd_compaction_trigger, ssd_write_buffer_size,
            ssd_max_write_buffer_num, ssd_uniform_init_lower,
            ssd_uniform_init_upper: Forwarded unchanged to
                ``EmbeddingRocksDBWrapper``; their meaning is defined there.
            ssd_cache_location: ``EmbeddingLocation.MANAGED`` or
                ``EmbeddingLocation.DEVICE``; selects how
                ``lxu_cache_weights`` is allocated.
            ps_hosts: When non-empty, an ``EmbeddingParameterServerWrapper``
                is used instead of RocksDB. Element 0 and element 1 of each
                entry are passed as two separate lists.
            ps_max_key_per_request: Defaults to 500 when ``None``.
            ps_client_thread_num: Defaults to 32 when ``None``.
            ps_max_local_index_length: Defaults to 54 when ``None``.
            tbe_unique_id: Id passed to the parameter-server wrapper. When
                -1 it is derived from ``dist.get_rank()`` and a per-class
                instance counter.

        Raises:
            AssertionError: If any of the argument checks below fails.
        """
        super().__init__()

        assert cache_assoc == 32, "Only 32-way cache is supported now"

        self.scale_bias_size_in_bytes = scale_bias_size_in_bytes
        self.pooling_mode = pooling_mode
        self.embedding_specs = embedding_specs
        T_ = len(self.embedding_specs)
        assert T_ > 0

        # torch.cuda.current_device() returns a device index (int), so the
        # device is always built from that index.
        self.current_device: torch.device = torch.device(
            torch.cuda.current_device()
        )
        self.use_cpu: bool = self.current_device.type == "cpu"

        self.feature_table_map: list[int] = (
            feature_table_map if feature_table_map is not None else list(range(T_))
        )
        T = len(self.feature_table_map)
        assert T_ <= T
        table_has_feature = [False] * T_
        for t in self.feature_table_map:
            table_has_feature[t] = True
        assert all(table_has_feature), "Each table must have at least one feature!"

        self.output_dtype: int = output_dtype.as_int()
        # (feature_names, rows, dims, weights_tys) = zip(*embedding_specs)
        # Pyre workaround
        rows: list[int] = [e[1] for e in embedding_specs]
        dims: list[int] = [e[2] for e in embedding_specs]
        weights_tys: list[SparseType] = [e[3] for e in embedding_specs]

        # Prefix sum of per-feature dims; the last entry is the total width.
        D_offsets = [dims[t] for t in self.feature_table_map]
        D_offsets = [0] + list(itertools.accumulate(D_offsets))
        self.total_D: int = D_offsets[-1]
        self.register_buffer(
            "D_offsets",
            torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
        )

        if row_alignment is None:
            self.row_alignment: int = 1 if self.use_cpu else 16
        else:
            self.row_alignment = row_alignment

        for dim, weight_ty in zip(dims, weights_tys):
            if not weight_ty.is_float():
                assert (
                    dim % (8 / weight_ty.bit_rate()) == 0
                ), f"For quantized types we need to at least pack at byte granularity, dim: {dim}, weight_ty: {weight_ty}"

        def max_ty_D(ty: SparseType) -> int:
            # Largest dim among tables stored as `ty`; 0 if there is none.
            return max(
                [dim for dim, weight_ty in zip(dims, weights_tys) if weight_ty == ty],
                default=0,
            )

        self.max_int2_D: int = max_ty_D(SparseType.INT2)
        self.max_int4_D: int = max_ty_D(SparseType.INT4)
        self.max_int8_D: int = max_ty_D(SparseType.INT8)
        self.max_float8_D: int = max_ty_D(SparseType.FP8)
        self.max_float16_D: int = max_ty_D(SparseType.FP16)
        self.max_float32_D: int = max_ty_D(SparseType.FP32)

        # Cache rows are sized with a fixed alignment of 16, independent of
        # self.row_alignment.
        cached_dims = [
            rounded_row_size_in_bytes(
                embedding_spec[2], embedding_spec[3], 16, self.scale_bias_size_in_bytes
            )
            for embedding_spec in self.embedding_specs
        ]
        self.max_D_cache: int = max(cached_dims, default=0)

        # Cacheline-aligned byte offset of each table.
        offsets = []
        uvm_size = 0
        for _, num_embeddings, embedding_dim, weight_ty in embedding_specs:
            row_bytes = rounded_row_size_in_bytes(
                embedding_dim, weight_ty, self.row_alignment, scale_bias_size_in_bytes
            )
            state_size = align_to_cacheline(num_embeddings * row_bytes)
            offsets.append(uvm_size)
            uvm_size += state_size

        self.weights_physical_offsets: list[int] = offsets

        weights_tys_int = [weights_tys[t].as_int() for t in self.feature_table_map]
        self.register_buffer(
            "weights_tys",
            torch.tensor(
                weights_tys_int, device=self.current_device, dtype=torch.uint8
            ),
        )
        self.weight_initialized: bool = True

        assert self.D_offsets.numel() == T + 1
        hash_size_cumsum = [0] + list(itertools.accumulate(rows))
        total_rows = hash_size_cumsum[-1]
        # log2(0) raises ValueError, so an empty hash space is special-cased.
        self.total_hash_size_bits: int = (
            0 if total_rows == 0 else int(log2(float(total_rows)) + 1)
        )
        self.total_hash_size: int = total_rows
        # The last element is to easily access # of rows of each table by
        # hash_size_cumsum[t + 1] - hash_size_cumsum[t]
        hash_size_cumsum = [hash_size_cumsum[t] for t in self.feature_table_map] + [
            hash_size_cumsum[-1]
        ]
        self.register_buffer(
            "hash_size_cumsum",
            torch.tensor(
                hash_size_cumsum, device=self.current_device, dtype=torch.int64
            ),
        )

        assert cache_sets > 0
        element_size = 1
        cache_size = cache_sets * ASSOC * element_size * self.max_D_cache
        logging.info(
            f"Using cache for SSD with admission algorithm "
            f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_shards} shards, "
            f"SSD storage directory: {ssd_storage_directory}, "
            f"Memtable Flush Period: {ssd_memtable_flush_period}, "
            f"Memtable Flush Offset: {ssd_memtable_flush_offset}, "
            f"Desired L0 files per compaction: {ssd_l0_files_per_compact}, "
            f"{cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
            f"output dtype: {output_dtype}"
        )
        # Every cache slot starts at -1.
        self.register_buffer(
            "lxu_cache_state",
            torch.zeros(cache_sets, ASSOC, dtype=torch.int64).fill_(-1),
        )
        self.register_buffer(
            "lru_state", torch.zeros(cache_sets, ASSOC, dtype=torch.int64)
        )

        assert ssd_cache_location in (
            EmbeddingLocation.MANAGED,
            EmbeddingLocation.DEVICE,
        )
        if ssd_cache_location == EmbeddingLocation.MANAGED:
            self.register_buffer(
                "lxu_cache_weights",
                torch.ops.fbgemm.new_managed_tensor(
                    torch.zeros(1, device=self.current_device, dtype=torch.uint8),
                    [cache_sets * ASSOC, self.max_D_cache],
                ),
            )
        else:
            self.register_buffer(
                "lxu_cache_weights",
                torch.zeros(
                    cache_sets * ASSOC,
                    self.max_D_cache,
                    device=self.current_device,
                    dtype=torch.uint8,
                ),
            )

        assert (
            cache_size
            == self.lxu_cache_weights.numel()
            * self.lxu_cache_weights.element_size()
        ), "The precomputed cache_size does not match the actual cache size"

        os.makedirs(ssd_storage_directory, exist_ok=True)

        ssd_directory = tempfile.mkdtemp(
            prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
        )
        if not ps_hosts:
            # pyre-fixme[4]: Attribute must be annotated.
            # pyre-ignore[16]
            self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
                ssd_directory,
                ssd_shards,
                ssd_shards,
                ssd_memtable_flush_period,
                ssd_memtable_flush_offset,
                ssd_l0_files_per_compact,
                self.max_D_cache,
                ssd_rate_limit_mbps,
                ssd_size_ratio,
                ssd_compaction_trigger,
                ssd_write_buffer_size,
                ssd_max_write_buffer_num,
                ssd_uniform_init_lower,
                ssd_uniform_init_upper,
                8,  # row_storage_bitwidth
                0,  # ssd_block_cache_size
            )
        else:
            # create tbe unique id using rank index | local instance index
            if tbe_unique_id == -1:
                SSDIntNBitTableBatchedEmbeddingBags._local_instance_index += 1
                # Only the low 3 bits carry the instance index.
                assert (
                    SSDIntNBitTableBatchedEmbeddingBags._local_instance_index < 8
                ), f"{SSDIntNBitTableBatchedEmbeddingBags._local_instance_index}, more than 8 TBE instance is created in one rank, the tbe unique id won't be unique in this case."
                tbe_unique_id = (
                    dist.get_rank() << 3
                    | SSDIntNBitTableBatchedEmbeddingBags._local_instance_index
                )
            logging.info(f"tbe_unique_id: {tbe_unique_id}")
            # pyre-fixme[4]: Attribute must be annotated.
            # pyre-ignore[16]
            self.ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
                [host[0] for host in ps_hosts],
                [host[1] for host in ps_hosts],
                tbe_unique_id,
                (
                    ps_max_local_index_length
                    if ps_max_local_index_length is not None
                    else 54
                ),
                ps_client_thread_num if ps_client_thread_num is not None else 32,
                ps_max_key_per_request if ps_max_key_per_request is not None else 500,
                0,  # ssd_block_cache_size
                self.max_D_cache,
            )

        # Side stream for write-back, created with the first value returned
        # by priority_range().
        # pyre-fixme[20]: Argument `self` expected.
        (low_priority, _) = torch.cuda.Stream.priority_range()
        self.ssd_stream = torch.cuda.Stream(priority=low_priority)
        self.ssd_set_start = torch.cuda.Event()
        self.ssd_set_end = torch.cuda.Event()

        # pyre-fixme[4]: Attribute must be annotated.
        # pyre-ignore[16]
        self.timestep_counter = torch.classes.fbgemm.AtomicCounter()
        # pyre-fixme[4]: Attribute must be annotated.
        # pyre-ignore[16]
        self.timestep_prefetch_size = torch.classes.fbgemm.AtomicCounter()

        # Empty / single-element weight tensors passed to the lookup op.
        self.weights_dev: torch.Tensor = torch.empty(
            0,
            device=self.current_device,
            dtype=torch.uint8,
        )
        self.register_buffer(
            "weights_uvm",
            torch.tensor((0,), device=self.current_device, dtype=torch.uint8),
        )
        self.register_buffer(
            "weights_host",
            torch.empty(0),
        )

        self.register_buffer(
            "weights_placements",
            torch.tensor(
                [EmbeddingLocation.MANAGED_CACHING for _ in range(T_)],
                dtype=torch.int32,
            ),
        )
        weights_offsets = [0] + list(
            itertools.accumulate([row * dim for (row, dim) in zip(rows, dims)])
        )
        self.register_buffer(
            "weights_offsets",
            torch.tensor(
                weights_offsets[:-1],
                device=self.current_device,
                dtype=torch.int64,
            ),
        )

        if self.max_float8_D > 0:
            default_config = SparseType.FP8.default_config()
            self.fp8_exponent_bits: int = (
                default_config.get("exponent_bits")
                if fp8_exponent_bits is None
                else fp8_exponent_bits
            )
            self.fp8_exponent_bias: int = (
                default_config.get("exponent_bias")
                if fp8_exponent_bias is None
                else fp8_exponent_bias
            )
        else:
            # No FP8 table: -1 is passed to the lookup op for both values.
            self.fp8_exponent_bits = -1
            self.fp8_exponent_bias = -1

    @torch.jit.export
    def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor:
        """
        Populate the cache for a batch and write evicted rows back.

        Increments both ``timestep_counter`` and ``timestep_prefetch_size``,
        asks ``ssd_cache_populate_actions`` which rows to insert / evict,
        reads the inserted rows from ``self.ssd_db``, copies them into
        ``lxu_cache_weights`` and, on ``self.ssd_stream``, hands the evicted
        rows to ``self.ssd_db.set_cuda``.

        Args:
            indices: Lookup indices; converted with ``.long()``.
            offsets: Bag offsets; converted with ``.long()``.

        Returns:
            The linearized cache indices from
            ``torch.ops.fbgemm.linearize_cache_indices``.
        """
        (indices, offsets) = indices.long(), offsets.long()
        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
            self.hash_size_cumsum,
            indices,
            offsets,
        )
        self.timestep_counter.increment()
        self.timestep_prefetch_size.increment()
        (
            inserted_indices,
            evicted_indices,
            assigned_cache_slots,
            actions_count_gpu,
            _,
            _,
            _,
            _,
        ) = torch.ops.fbgemm.ssd_cache_populate_actions(
            linear_cache_indices,
            self.total_hash_size,
            self.lxu_cache_state,
            self.timestep_counter.get(),
            1,  # for now assume prefetch_dist == 1
            self.lru_state,
        )
        actions_count_cpu = torch.empty(
            actions_count_gpu.shape, pin_memory=True, dtype=actions_count_gpu.dtype
        )
        actions_count_cpu.copy_(actions_count_gpu, non_blocking=True)
        assigned_cache_slots = assigned_cache_slots.long()
        # NOTE: clamp_ is in-place, so the clamped slots are also what
        # masked_index_put receives below. Snapshot the rows currently in
        # those slots before they are overwritten.
        evicted_rows = self.lxu_cache_weights[
            assigned_cache_slots.clamp_(min=0).long(), :
        ]
        inserted_rows = torch.empty(
            evicted_rows.shape,
            dtype=self.lxu_cache_weights.dtype,
            pin_memory=True,
        )

        current_stream = torch.cuda.current_stream()

        # Ensure the previous iterations l3_db.set(..) has completed.
        current_stream.wait_event(self.ssd_set_end)
        inserted_indices_cpu = torch.empty(
            inserted_indices.shape, pin_memory=True, dtype=inserted_indices.dtype
        )
        inserted_indices_cpu.copy_(inserted_indices, non_blocking=True)
        self.ssd_db.get_cuda(
            inserted_indices_cpu,
            inserted_rows,
            actions_count_cpu,
        )
        current_stream.record_event(self.ssd_set_start)
        # TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
        # Should we allocate on HBM?
        inserted_rows_gpu = inserted_rows.to(self.current_device, non_blocking=True)

        # self.lxu_cache_weights[assigned_cache_slots, :] = inserted_rows.cuda(non_blocking=True)
        torch.ops.fbgemm.masked_index_put(
            self.lxu_cache_weights,
            assigned_cache_slots,
            inserted_rows_gpu,
            actions_count_gpu,
        )

        with torch.cuda.stream(self.ssd_stream):
            # Do not start the write-back before ssd_set_start was recorded
            # on the stream that issued the read above.
            self.ssd_stream.wait_event(self.ssd_set_start)
            evicted_rows_cpu = torch.empty(
                evicted_rows.shape, pin_memory=True, dtype=evicted_rows.dtype
            )
            evicted_rows_cpu.copy_(evicted_rows, non_blocking=True)
            evicted_indices_cpu = torch.empty(
                evicted_indices.shape, pin_memory=True, dtype=evicted_indices.dtype
            )
            evicted_indices_cpu.copy_(evicted_indices, non_blocking=True)
            # These tensors were allocated on another stream but are read on
            # ssd_stream.
            evicted_rows.record_stream(self.ssd_stream)
            evicted_indices.record_stream(self.ssd_stream)
            self.ssd_db.set_cuda(
                evicted_indices_cpu,
                evicted_rows_cpu,
                actions_count_cpu,
                self.timestep_counter.get(),
            )
            # TODO: is this needed?
            # Need a way to synchronize
            #  actions_count_cpu.record_stream(self.ssd_stream)
            self.ssd_stream.record_event(self.ssd_set_end)
        return linear_cache_indices

    def forward(
        self,
        indices: Tensor,
        offsets: Tensor,
        per_sample_weights: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Look up and pool embeddings for a batch.

        If ``timestep_prefetch_size`` is not positive, :meth:`prefetch` is
        run here first; otherwise the indices are only linearized. The
        counter is decremented afterwards in both cases.

        Args:
            indices: Lookup indices.
            offsets: Bag offsets.
            per_sample_weights: Optional weights, passed to the lookup op as
                ``indice_weights``.

        Returns:
            The result of
            ``torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function``.
        """
        if self.timestep_prefetch_size.get() <= 0:
            with record_function("## prefetch ##"):
                linear_cache_indices = self.prefetch(indices, offsets)
        else:
            linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
                self.hash_size_cumsum,
                indices,
                offsets,
            )
        lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
            linear_cache_indices,
            self.lxu_cache_state,
            self.total_hash_size,
        )

        self.timestep_prefetch_size.decrement()

        assert (
            self.weight_initialized
        ), "weight needs to be initialized before forward function"

        # Note: CPU and CUDA ops use the same interface to facilitate JIT IR
        # generation for CUDA/CPU. For CPU op, we don't need weights_uvm and
        # weights_placements
        return torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(
            dev_weights=self.weights_dev,
            uvm_weights=self.weights_uvm,
            weights_placements=self.weights_placements,
            weights_offsets=self.weights_offsets,
            weights_tys=self.weights_tys,
            D_offsets=self.D_offsets,
            total_D=self.total_D,
            max_int2_D=self.max_int2_D,
            max_int4_D=self.max_int4_D,
            max_int8_D=self.max_int8_D,
            max_float16_D=self.max_float16_D,
            max_float32_D=self.max_float32_D,
            indices=indices,
            offsets=offsets,
            pooling_mode=int(self.pooling_mode),
            indice_weights=per_sample_weights,
            output_dtype=self.output_dtype,
            lxu_cache_weights=self.lxu_cache_weights,
            lxu_cache_locations=lxu_cache_locations,
            row_alignment=self.row_alignment,
            max_float8_D=self.max_float8_D,
            fp8_exponent_bits=self.fp8_exponent_bits,
            fp8_exponent_bias=self.fp8_exponent_bias,
        )

    @torch.jit.export
    def split_embedding_weights(
        self, split_scale_shifts: bool = True
    ) -> list[tuple[Tensor, Optional[Tensor]]]:
        """
        Returns a list of weights, split by table.

        Testing only, very slow.

        Each table is read in full from ``self.ssd_db``, followed by a
        device synchronize.

        Args:
            split_scale_shifts: If True, each row is truncated to
                ``unpadded_row_size_in_bytes(...)`` and, for INT8/INT4/INT2
                tables, split after the first ``scale_bias_size_in_bytes``
                bytes. If False, the padded rows are returned whole.

        Returns:
            One ``(weights, scale_shift)`` tuple per table. ``scale_shift``
            is ``None`` for FP8/FP16/FP32 tables and whenever
            ``split_scale_shifts`` is False.
        """
        splits: list[tuple[Tensor, Optional[Tensor]]] = []
        rows_cumsum = 0
        for _, row, dim, weight_ty in self.embedding_specs:
            weights = torch.empty(
                (
                    row,
                    rounded_row_size_in_bytes(
                        dim,
                        weight_ty,
                        self.row_alignment,
                        self.scale_bias_size_in_bytes,
                    ),
                ),
                dtype=torch.uint8,
            )
            # Tables are addressed by their cumulative row offset.
            self.ssd_db.get_cuda(
                torch.arange(rows_cumsum, rows_cumsum + row).to(torch.int64),
                weights,
                torch.as_tensor([row]),
            )
            rows_cumsum += row
            torch.cuda.synchronize(self.current_device)

            weights_shifts = weights.detach()

            if split_scale_shifts:
                # remove the padding at the end of each row.
                weights_shifts = weights_shifts[
                    :,
                    : unpadded_row_size_in_bytes(
                        dim, weight_ty, self.scale_bias_size_in_bytes
                    ),
                ]
                if (
                    weight_ty == SparseType.INT8
                    or weight_ty == SparseType.INT4
                    or weight_ty == SparseType.INT2
                ):
                    # The first scale_bias_size_in_bytes bytes of a row form
                    # the second tuple element; the rest is the first.
                    splits.append(
                        (
                            weights_shifts[:, self.scale_bias_size_in_bytes :],
                            weights_shifts[:, : self.scale_bias_size_in_bytes],
                        )
                    )
                else:
                    assert (
                        weight_ty == SparseType.FP8
                        or weight_ty == SparseType.FP16
                        or weight_ty == SparseType.FP32
                    )
                    splits.append(
                        (
                            weights_shifts,
                            None,
                        )
                    )
            else:
                splits.append((weights_shifts, None))

        torch.cuda.synchronize(self.current_device)
        return splits
|