fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,2042 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
# pyre-ignore-all-errors[56]
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
import uuid
|
|
14
|
+
from itertools import accumulate
|
|
15
|
+
from typing import Optional, Union
|
|
16
|
+
|
|
17
|
+
import fbgemm_gpu # noqa: F401
|
|
18
|
+
import torch # usort:skip
|
|
19
|
+
from torch import nn, Tensor # usort:skip
|
|
20
|
+
|
|
21
|
+
from fbgemm_gpu.config import FeatureGateName
|
|
22
|
+
from fbgemm_gpu.split_embedding_configs import sparse_type_to_int, SparseType
|
|
23
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
|
24
|
+
BoundsCheckMode,
|
|
25
|
+
CacheAlgorithm,
|
|
26
|
+
CacheState,
|
|
27
|
+
construct_cache_state,
|
|
28
|
+
DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
|
|
29
|
+
EmbeddingLocation,
|
|
30
|
+
EmbeddingSpecInfo,
|
|
31
|
+
get_bounds_check_version_for_platform,
|
|
32
|
+
get_new_embedding_location,
|
|
33
|
+
MAX_PREFETCH_DEPTH,
|
|
34
|
+
PoolingMode,
|
|
35
|
+
RecordCacheMetrics,
|
|
36
|
+
round_up,
|
|
37
|
+
SplitState,
|
|
38
|
+
tensor_to_device,
|
|
39
|
+
)
|
|
40
|
+
from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
|
|
41
|
+
|
|
42
|
+
# Best-effort loading of the inference operator libraries.  A failure here is
# deliberately non-fatal (the original code swallowed it silently); the
# exception is now recorded at DEBUG level so that a later "operator not
# found" error can be traced back to the load failure.  A named logger is used
# instead of the module-level ``logging.debug`` helper because the latter
# implicitly calls ``logging.basicConfig()`` when the root logger has no
# handlers, which would be an import-time side effect.
try:
    load_torch_module(
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_inference_gpu",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_inference",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_inference",
    )
except Exception:
    logging.getLogger(__name__).debug(
        "Ignoring failure to load GPU inference embedding ops", exc_info=True
    )

try:
    load_torch_module_bc(
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_inference_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_inference",
    )
except Exception:
    logging.getLogger(__name__).debug(
        "Ignoring failure to load CPU inference embedding ops", exc_info=True
    )
|
|
58
|
+
|
|
59
|
+
import fbgemm_gpu # noqa
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def rounded_row_size_in_bytes(
    dim: int,
    weight_ty: SparseType,
    row_alignment: int,
    scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
) -> int:
    """
    Return the storage size in bytes of one embedding row after alignment.

    Args:
        dim (int): Embedding dimension (number of elements per row)

        weight_ty (SparseType): Storage type of the row

        row_alignment (int): Alignment passed to `round_up` (callers in this
            file use 1 on CPU and 16 otherwise by default)

        scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES):
            Extra bytes added per row for the integer-quantized types

    Returns:
        The unpadded row size passed through `round_up` with `row_alignment`
    """
    # The alignment is whatever the caller supplies; it is not fixed to 16.
    return round_up(
        unpadded_row_size_in_bytes(dim, weight_ty, scale_bias_size_in_bytes),
        row_alignment,
    )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def unpadded_row_size_in_bytes(
    dim: int,
    weight_ty: SparseType,
    scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
) -> int:
    """
    Return the storage size in bytes of one embedding row, without padding.

    Args:
        dim (int): Embedding dimension (number of elements per row)

        weight_ty (SparseType): Storage type of the row

        scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES):
            Extra bytes added per row for INT8/INT4/INT2

    Returns:
        The row size in bytes

    Raises:
        KeyError: If `weight_ty` is not one of FP32/FP16/FP8/INT8/INT4/INT2
    """
    # The table is keyed by the enum's `.value` and looked up the same way.
    # Float types store only the elements; the integer types additionally
    # carry `scale_bias_size_in_bytes` per row.  INT4 and INT2 pack two and
    # four elements per byte respectively (floor division).
    bytes_by_type_value = {
        SparseType.FP32.value: dim * 4,
        SparseType.FP16.value: dim * 2,
        SparseType.FP8.value: dim,
        SparseType.INT8.value: dim + scale_bias_size_in_bytes,
        SparseType.INT4.value: dim // 2 + scale_bias_size_in_bytes,
        SparseType.INT2.value: dim // 4 + scale_bias_size_in_bytes,
    }
    return bytes_by_type_value[weight_ty.value]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def align_to_cacheline(a: int) -> int:
    """
    Align a table size to the cache line boundary used by this module.

    Args:
        a (int): The size to align (callers in this file pass a byte count)

    Returns:
        `a` passed through `round_up` with an alignment of 128
    """
    cacheline = 128
    return round_up(a, cacheline)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def nbit_construct_split_state(
    embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]],
    cacheable: bool,
    row_alignment: int,
    scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
    cacheline_alignment: bool = True,
) -> SplitState:
    """
    Compute per-table placements and byte offsets for the weight buffers.

    Each table is assigned to one of three running buffers (host, device,
    UVM).  Its offset is the size of that buffer before the table is added.

    Args:
        embedding_specs: One tuple per table of
            (name, rows, dim, weight type, location)

        cacheable (bool): If False, `MANAGED_CACHING` tables are recorded as
            `MANAGED`

        row_alignment (int): Row alignment forwarded to
            `rounded_row_size_in_bytes`

        scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES):
            Forwarded to `rounded_row_size_in_bytes`

        cacheline_alignment (bool = True): If True, each table size is passed
            through `align_to_cacheline`

    Returns:
        A `SplitState` with the three buffer sizes, the placements, and the
        offsets (one entry per table, in input order)
    """
    placements = torch.jit.annotate(list[EmbeddingLocation], [])
    offsets = torch.jit.annotate(list[int], [])
    host_bytes = 0
    device_bytes = 0
    uvm_bytes = 0

    for _, row_count, dim, weight_ty, location in embedding_specs:
        row_bytes = rounded_row_size_in_bytes(
            dim, weight_ty, row_alignment, scale_bias_size_in_bytes
        )
        table_bytes = row_count * row_bytes
        if cacheline_alignment:
            table_bytes = align_to_cacheline(table_bytes)

        if location == EmbeddingLocation.HOST:
            placements.append(EmbeddingLocation.HOST)
            offsets.append(host_bytes)
            host_bytes += table_bytes
        elif location == EmbeddingLocation.DEVICE or location == EmbeddingLocation.MTIA:
            # DEVICE and MTIA share the device buffer but keep their own
            # placement tag.
            placements.append(location)
            offsets.append(device_bytes)
            device_bytes += table_bytes
        else:
            # Everything else lands in the UVM buffer; the caching tag is
            # kept only when caching is allowed.
            uvm_placement = (
                EmbeddingLocation.MANAGED_CACHING
                if cacheable and location == EmbeddingLocation.MANAGED_CACHING
                else EmbeddingLocation.MANAGED
            )
            placements.append(uvm_placement)
            offsets.append(uvm_bytes)
            uvm_bytes += table_bytes

    assert len(placements) == len(offsets)
    return SplitState(
        dev_size=device_bytes,
        host_size=host_bytes,
        uvm_size=uvm_bytes,
        placements=placements,
        offsets=offsets,
    )
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def random_quant_scaled_tensor(
    shape: torch.Size,
    device: torch.device,
    output_tensor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Generate a random `torch.uint8` tensor.

    Args:
        shape (torch.Size): Shape of the tensor to generate

        device (torch.device): Device on which to generate the tensor

        output_tensor (Optional[torch.Tensor] = None): If given, the random
            values are written into this tensor via `out=`

    Returns:
        The tensor produced by `torch.randint(0, 255, ...)`.  The upper bound
        of `torch.randint` is exclusive, so the value 255 is never produced.
    """
    # Two explicit calls are kept (instead of building keyword arguments
    # dynamically) so that each call site has a fixed signature.
    if output_tensor is None:
        return torch.randint(
            0,
            255,
            size=shape,
            dtype=torch.uint8,
            device=device,
        )
    else:
        return torch.randint(
            0,
            255,
            size=shape,
            out=output_tensor,
            dtype=torch.uint8,
            device=device,
        )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@torch.fx.wrap
def inputs_to_device(
    indices: torch.Tensor,
    offsets: torch.Tensor,
    per_sample_weights: Optional[torch.Tensor],
    bounds_check_warning: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Move the lookup inputs onto the device of `bounds_check_warning`.

    Args:
        indices (torch.Tensor): Lookup indices

        offsets (torch.Tensor): Lookup offsets

        per_sample_weights (Optional[torch.Tensor]): Optional per-sample
            weights; `None` is passed through unchanged

        bounds_check_warning (torch.Tensor): Tensor whose device is used as
            the target device

    Returns:
        The tuple `(indices, offsets, per_sample_weights)`.  Tensors already
        on the target device are returned as-is.  When the target device type
        is "meta", all inputs are returned unchanged.
    """
    target = bounds_check_warning.device
    if target.type == "meta":
        return indices, offsets, per_sample_weights

    # Copies are requested as non-blocking for every target except CPU.
    async_copy = target.type != "cpu"

    if indices.device != target:
        indices = indices.to(target, non_blocking=async_copy)
    if offsets.device != target:
        offsets = offsets.to(target, non_blocking=async_copy)
    if per_sample_weights is not None:
        if per_sample_weights.device != target:
            per_sample_weights = per_sample_weights.to(
                target, non_blocking=async_copy
            )
    return indices, offsets, per_sample_weights
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
|
|
188
|
+
class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
189
|
+
"""
|
|
190
|
+
Table-batched version of nn.EmbeddingBag(sparse=False)
|
|
191
|
+
Inference version, with support for FP32/FP16/FP8/INT8/INT4/INT2 weights
|
|
192
|
+
|
|
193
|
+
Args:
|
|
194
|
+
embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]):
|
|
195
|
+
A list of embedding specifications. Each spec describes a
|
|
196
|
+
specification of a physical embedding table. Each one is a tuple of
|
|
197
|
+
number of embedding rows, embedding dimension (must be a multiple of
|
|
198
|
+
4), table placement (`EmbeddingLocation`), and compute device
|
|
199
|
+
(`ComputeDevice`).
|
|
200
|
+
|
|
201
|
+
Available `EmbeddingLocation` options are
|
|
202
|
+
|
|
203
|
+
(1) `DEVICE` = placing an embedding table in the GPU global memory
|
|
204
|
+
(HBM)
|
|
205
|
+
|
|
206
|
+
(2) `MANAGED` = placing an embedding in the unified virtual memory
|
|
207
|
+
(accessible from both GPU and CPU)
|
|
208
|
+
|
|
209
|
+
(3) `MANAGED_CACHING` = placing an embedding table in the unified
|
|
210
|
+
virtual memory and using the GPU global memory (HBM) as a cache
|
|
211
|
+
|
|
212
|
+
(4) `HOST` = placing an embedding table in the CPU memory (DRAM)
|
|
213
|
+
|
|
214
|
+
(5) `MTIA` = placing an embedding table in the MTIA memory
|
|
215
|
+
|
|
216
|
+
Available `ComputeDevice` options are
|
|
217
|
+
|
|
218
|
+
(1) `CPU` = performing table lookup on CPU
|
|
219
|
+
|
|
220
|
+
(2) `CUDA` = performing table lookup on GPU
|
|
221
|
+
|
|
222
|
+
(3) `MTIA` = performing table lookup on MTIA
|
|
223
|
+
|
|
224
|
+
feature_table_map (Optional[List[int]] = None): An optional list that
|
|
225
|
+
specifies feature-table mapping. feature_table_map[i] indicates the
|
|
226
|
+
physical embedding table that feature i maps to.
|
|
227
|
+
|
|
228
|
+
index_remapping (Optional[List[Tensor]] = None): Index remapping for pruning
|
|
229
|
+
|
|
230
|
+
pooling_mode (PoolingMode = PoolingMode.SUM): Pooling mode. Available
|
|
231
|
+
`PoolingMode` options are
|
|
232
|
+
|
|
233
|
+
(1) `SUM` = Sum pooling
|
|
234
|
+
|
|
235
|
+
(2) `MEAN` = Mean pooling
|
|
236
|
+
|
|
237
|
+
(3) `NONE` = No pooling (sequence embedding)
|
|
238
|
+
|
|
239
|
+
device (Optional[Union[str, int, torch.device]] = None): The current
|
|
240
|
+
device to place tensors on
|
|
241
|
+
|
|
242
|
+
bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING): Input
|
|
243
|
+
checking mode. Available `BoundsCheckMode` options are
|
|
244
|
+
|
|
245
|
+
(1) `NONE` = skip bounds check
|
|
246
|
+
|
|
247
|
+
(2) `FATAL` = throw an error when encountering an invalid
|
|
248
|
+
index/offset
|
|
249
|
+
|
|
250
|
+
(3) `WARNING` = print a warning message when encountering an
|
|
251
|
+
invalid index/offset and fix it (setting an invalid index to
|
|
252
|
+
zero and adjusting an invalid offset to be within the bound)
|
|
253
|
+
|
|
254
|
+
(4) `IGNORE` = silently fix an invalid index/offset (setting an
|
|
255
|
+
invalid index to zero and adjusting an invalid offset to be
|
|
256
|
+
within the bound)
|
|
257
|
+
|
|
258
|
+
weight_lists (Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None):
|
|
259
|
+
[T]
|
|
260
|
+
|
|
261
|
+
pruning_hash_load_factor (float = 0.5):
|
|
262
|
+
Load factor for pruning hash
|
|
263
|
+
|
|
264
|
+
use_array_for_index_remapping (bool = True):
|
|
265
|
+
If True, use array for index remapping. Otherwise, use hash map.
|
|
266
|
+
|
|
267
|
+
output_dtype (SparseType = SparseType.FP16): The data type of an output
|
|
268
|
+
tensor.
|
|
269
|
+
|
|
270
|
+
cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU): The cache
|
|
271
|
+
algorithm (used when `EmbeddingLocation` is set to
|
|
272
|
+
`MANAGED_CACHING`). Options are
|
|
273
|
+
|
|
274
|
+
(1) `LRU` = least recently used
|
|
275
|
+
|
|
276
|
+
(2) `LFU` = least frequently used
|
|
277
|
+
|
|
278
|
+
cache_load_factor (float = 0.2): A factor used for determining the
|
|
279
|
+
cache capacity when `EmbeddingLocation.MANAGED_CACHING` is used.
|
|
280
|
+
The cache capacity is `cache_load_factor` * the total number of
|
|
281
|
+
rows in all embedding tables
|
|
282
|
+
|
|
283
|
+
cache_sets (int = 0): The number of cache sets (used when
|
|
284
|
+
`EmbeddingLocation` is set to `MANAGED_CACHING`)
|
|
285
|
+
|
|
286
|
+
cache_reserved_memory (float = 0.0): The amount of memory reserved in
|
|
287
|
+
HBM for non-cache purpose (used when `EmbeddingLocation` is set to
|
|
288
|
+
`MANAGED_CACHING`).
|
|
289
|
+
|
|
290
|
+
enforce_hbm (bool = False): If True, place all weights/momentums in HBM
|
|
291
|
+
when using `EmbeddingLocation.MANAGED_CACHING`
|
|
292
|
+
|
|
293
|
+
record_cache_metrics (Optional[RecordCacheMetrics] = None): Record
|
|
294
|
+
a number of hits, a number of requests, etc if
|
|
295
|
+
`RecordCacheMetrics.record_cache_miss_counter` is True and record
|
|
296
|
+
the similar metrics table-wise if
|
|
297
|
+
`RecordCacheMetrics.record_tablewise_cache_miss` is True
|
|
298
|
+
|
|
299
|
+
gather_uvm_cache_stats (Optional[bool] = False): If True, collect the
|
|
300
|
+
cache statistics when `EmbeddingLocation` is set to
|
|
301
|
+
`MANAGED_CACHING`
|
|
302
|
+
|
|
303
|
+
row_alignment (Optional[int] = None): Row alignment
|
|
304
|
+
|
|
305
|
+
fp8_exponent_bits (Optional[int] = None): Exponent bits when using FP8
|
|
306
|
+
|
|
307
|
+
fp8_exponent_bias (Optional[int] = None): Exponent bias when using FP8
|
|
308
|
+
|
|
309
|
+
cache_assoc (int = 32): Number of ways for cache
|
|
310
|
+
|
|
311
|
+
scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES): Size
|
|
312
|
+
of scale and bias in bytes
|
|
313
|
+
|
|
314
|
+
cacheline_alignment (bool = True): If True, align each table to 128b
|
|
315
|
+
cache line boundary
|
|
316
|
+
|
|
317
|
+
uvm_host_mapped (bool = False): If True, allocate every UVM tensor
|
|
318
|
+
using `malloc` + `cudaHostRegister`. Otherwise use
|
|
319
|
+
`cudaMallocManaged`
|
|
320
|
+
|
|
321
|
+
reverse_qparam (bool = False): If True, load `qparams` at end of each
|
|
322
|
+
row. Otherwise, load `qparams` at beginning of each row.
|
|
323
|
+
|
|
324
|
+
feature_names_per_table (Optional[List[List[str]]] = None): An optional
|
|
325
|
+
list that specifies feature names per table. `feature_names_per_table[t]`
|
|
326
|
+
indicates the feature names of table `t`.
|
|
327
|
+
|
|
328
|
+
indices_dtype (torch.dtype = torch.int32): The expected dtype of the
|
|
329
|
+
indices tensor that will be passed to the `forward()` call. This
|
|
330
|
+
information will be used to construct the remap_indices array/hash.
|
|
331
|
+
Options are `torch.int32` and `torch.int64`.
|
|
332
|
+
"""
|
|
333
|
+
|
|
334
|
+
embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]]
|
|
335
|
+
record_cache_metrics: RecordCacheMetrics
|
|
336
|
+
# pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
|
|
337
|
+
cache_miss_counter: torch.Tensor
|
|
338
|
+
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
|
|
339
|
+
uvm_cache_stats: torch.Tensor
|
|
340
|
+
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
|
|
341
|
+
local_uvm_cache_stats: torch.Tensor
|
|
342
|
+
# pyre-fixme[13]: Attribute `weights_offsets` is never initialized.
|
|
343
|
+
weights_offsets: torch.Tensor
|
|
344
|
+
# pyre-fixme[13]: Attribute `weights_placements` is never initialized.
|
|
345
|
+
weights_placements: torch.Tensor
|
|
346
|
+
|
|
347
|
+
def __init__(  # noqa C901
    self,
    embedding_specs: list[
        tuple[str, int, int, SparseType, EmbeddingLocation]
    ],  # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
    feature_table_map: Optional[list[int]] = None,  # [T]
    index_remapping: Optional[list[Tensor]] = None,
    pooling_mode: PoolingMode = PoolingMode.SUM,
    device: Optional[Union[str, int, torch.device]] = None,
    bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
    weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
    pruning_hash_load_factor: float = 0.5,
    use_array_for_index_remapping: bool = True,
    output_dtype: SparseType = SparseType.FP16,
    cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
    cache_load_factor: float = 0.2,
    cache_sets: int = 0,
    cache_reserved_memory: float = 0.0,
    enforce_hbm: bool = False,  # place all weights/momentums in HBM when using cache
    record_cache_metrics: Optional[RecordCacheMetrics] = None,
    gather_uvm_cache_stats: Optional[bool] = False,
    row_alignment: Optional[int] = None,
    fp8_exponent_bits: Optional[int] = None,
    fp8_exponent_bias: Optional[int] = None,
    cache_assoc: int = 32,
    scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
    cacheline_alignment: bool = True,
    uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
    reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparam at beginning of each row.
    feature_names_per_table: Optional[list[list[str]]] = None,
    indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
) -> None:  # noqa C901
    """
    Validate the table specifications and build the module state.

    In order, this resolves the target device, records the configuration on
    `self`, validates table placements and dimensions, registers the
    per-feature buffers, computes the physical weight placements/offsets,
    optionally assigns `weight_lists` and `index_remapping`, applies the
    cache state, and finally resolves the FP8 parameters.  See the class
    docstring for the description of each argument.
    """
    super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()
    # Per-instance id, used by `self.log` to prefix messages.
    self.uuid = str(uuid.uuid4())
    self.log(
        f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
    )

    # 64 for AMD
    # Only the default value (32) is overridden on HIP builds; any other
    # explicitly passed associativity is kept.
    if cache_assoc == 32 and torch.version.hip is not None:
        cache_assoc = 64

    # Resolve the target device: default to the current CUDA device index,
    # accept a `torch.device` as-is, and convert str/int otherwise.
    if device is None:
        self.current_device: torch.device = torch.device(
            torch.cuda.current_device()
        )
    elif isinstance(device, torch.device):
        self.current_device = device
    else:
        self.current_device = torch.device(device)
    self.use_cpu: bool = self.current_device.type == "cpu"

    self.scale_bias_size_in_bytes = scale_bias_size_in_bytes
    self.pooling_mode = pooling_mode
    self.bounds_check_mode_int: int = bounds_check_mode.value
    self.embedding_specs = embedding_specs
    self.output_dtype: int = output_dtype.as_int()
    self.uvm_host_mapped = uvm_host_mapped
    self.feature_names_per_table = feature_names_per_table
    self.indices_dtype = indices_dtype
    # (feature_names, rows, dims, weights_tys, locations) = zip(*embedding_specs)
    # Pyre workaround
    self.feature_names: list[str] = [e[0] for e in embedding_specs]
    self.cache_load_factor: float = cache_load_factor
    self.cache_sets: int = cache_sets
    self.cache_reserved_memory: float = cache_reserved_memory
    rows: list[int] = [e[1] for e in embedding_specs]
    dims: list[int] = [e[2] for e in embedding_specs]
    weights_tys: list[SparseType] = [e[3] for e in embedding_specs]
    locations: list[EmbeddingLocation] = [e[4] for e in embedding_specs]
    # if target device is meta then we set use_cpu based on the embedding location
    # information in embedding_specs.
    if self.current_device.type == "meta":
        self.use_cpu = all(loc == EmbeddingLocation.HOST for loc in locations)

    # Default row alignment: 1 on CPU, 16 otherwise.
    if row_alignment is None:
        self.row_alignment: int = 1 if self.use_cpu else 16
    else:
        self.row_alignment = row_alignment

    # Default to recording neither the global nor the table-wise counters.
    if record_cache_metrics is not None:
        self.record_cache_metrics = record_cache_metrics
    else:
        self.record_cache_metrics = RecordCacheMetrics(False, False)

    self.gather_uvm_cache_stats = gather_uvm_cache_stats
    # Define the size of uvm cache stats as class variable
    # to make it work with torch jit script.
    self.uvm_cache_stats_size = 6
    # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
    # 4: N_conflict_unique_misses, 5: N_conflict_misses

    # mixed D is not supported by no bag kernels
    mixed_D = not all(d == dims[0] for d in dims)
    if mixed_D:
        assert (
            self.pooling_mode != PoolingMode.NONE
        ), "Mixed dimension tables are only supported for pooling tables."

    # Placement must agree with the compute device: all HOST on CPU, no
    # HOST otherwise.
    assert not self.use_cpu or all(
        loc == EmbeddingLocation.HOST for loc in locations
    ), "CPU device requires EmbeddingLocation.HOST for location!"
    assert self.use_cpu or all(
        loc != EmbeddingLocation.HOST for loc in locations
    ), "EmbeddingLocation.HOST doesn't work for CUDA device!"

    # T_ is the number of physical tables; T is the number of features.
    T_ = len(self.embedding_specs)
    assert T_ > 0

    # Without an explicit map, feature i maps to table i.
    self.feature_table_map: list[int] = (
        feature_table_map if feature_table_map is not None else list(range(T_))
    )
    T = len(self.feature_table_map)
    assert T_ <= T

    table_has_feature = [False] * T_
    for t in self.feature_table_map:
        table_has_feature[t] = True
    assert all(table_has_feature), "Each table must have at least one feature!"
    # D_offsets is the exclusive prefix sum of the per-feature dimensions,
    # so it has T + 1 entries and its last entry is the total dimension.
    D_offsets = [dims[t] for t in self.feature_table_map]
    D_offsets = [0] + list(accumulate(D_offsets))
    self.total_D: int = D_offsets[-1]
    for dim, weight_ty in zip(dims, weights_tys):
        if not weight_ty.is_float():
            # `8 / bit_rate()` is true division (a float); presumably the
            # number of elements per byte -- confirm against SparseType.
            assert (
                dim % (8 / weight_ty.bit_rate()) == 0
            ), f"For quantized types we need to at least pack at byte granularity, dim: {dim}, weight_ty: {weight_ty}"

    # Largest dimension among the tables stored as `ty`; 0 if there is none.
    def max_ty_D(ty: SparseType) -> int:
        return max(
            [
                dim
                for dim, weight_ty in zip(dims, weights_tys)
                if weight_ty == ty or weight_ty.value == ty.value
            ],
            default=0,
        )

    self.max_int2_D: int = max_ty_D(SparseType.INT2)
    self.max_int4_D: int = max_ty_D(SparseType.INT4)
    self.max_int8_D: int = max_ty_D(SparseType.INT8)
    self.max_float8_D: int = max_ty_D(SparseType.FP8)
    self.max_float16_D: int = max_ty_D(SparseType.FP16)
    self.max_float32_D: int = max_ty_D(SparseType.FP32)

    self.register_buffer(
        "D_offsets",
        torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
    )
    assert self.D_offsets.numel() == T + 1

    # Row count of the table behind each feature (one entry per feature).
    self.register_buffer(
        "rows_per_table",
        torch.tensor(
            [rows[t] for t in self.feature_table_map],
            device=self.current_device,
            dtype=torch.int64,
        ),
    )
    self.register_buffer(
        "bounds_check_warning",
        torch.tensor([0], device=self.current_device, dtype=torch.int64),
    )

    # Integer-encoded weight type of the table behind each feature.
    weights_tys_int = [weights_tys[t].as_int() for t in self.feature_table_map]
    self.register_buffer(
        "weights_tys",
        torch.tensor(
            weights_tys_int, device=self.current_device, dtype=torch.uint8
        ),
    )
    self.weight_initialized: bool = False

    # Empty uint8 placeholders for the three weight buffers.
    self.weights_dev: torch.Tensor = torch.zeros(
        0,
        device=self.current_device,
        dtype=torch.uint8,
    )

    self.weights_host: torch.Tensor = torch.zeros(
        0, device=self.current_device, dtype=torch.uint8
    )

    self.weights_uvm: torch.Tensor = torch.empty(
        0, device=self.current_device, dtype=torch.uint8
    )

    # Row sizes of the MANAGED_CACHING tables.  Note that the alignment here
    # is the literal 16, not `self.row_alignment`.
    cached_dims = [
        rounded_row_size_in_bytes(
            embedding_spec[2], embedding_spec[3], 16, self.scale_bias_size_in_bytes
        )
        for embedding_spec in self.embedding_specs
        if embedding_spec[4] == EmbeddingLocation.MANAGED_CACHING
    ]
    self.max_D_cache: int = max(cached_dims) if len(cached_dims) > 0 else 0

    self.initialize_physical_weights_placements_and_offsets(cacheline_alignment)
    self.enforce_hbm: bool = enforce_hbm

    self.reverse_qparam = reverse_qparam
    # Assign weights after weights and weights_offsets are initialized.
    if weight_lists:
        self._apply_split(
            self.dev_size,
            self.host_size,
            self.uvm_size,
            self.weights_physical_placements,
            self.weights_physical_offsets,
            self.enforce_hbm,
        )
        self.assign_embedding_weights(weight_lists)

    # Handle index remapping for embedding pruning.
    # The offset buffers and `original_rows_per_table` are int64; the
    # remapping array and hash table use `self.indices_dtype`.  All start
    # empty and are only populated through `set_index_remappings` below.
    self.register_buffer(
        "index_remappings_array_offsets",
        torch.empty(0, device=self.current_device, dtype=torch.int64),
    )
    self.register_buffer(
        "index_remappings_array",
        torch.empty(0, device=self.current_device, dtype=self.indices_dtype),
    )
    self.register_buffer(
        "index_remapping_hash_table_offsets",
        torch.empty(0, device=self.current_device, dtype=torch.int64),
    )
    self.register_buffer(
        "index_remapping_hash_table",
        torch.empty(0, device=self.current_device, dtype=self.indices_dtype),
    )
    self.register_buffer(
        "original_rows_per_table",
        torch.empty(0, device=self.current_device, dtype=torch.int64),
    )
    # pyre-fixme[4]: Attribute must be annotated.
    self.index_remapping_hash_table_cpu = None

    if index_remapping:
        self.set_index_remappings(
            index_remapping, pruning_hash_load_factor, use_array_for_index_remapping
        )

    # Currently only support cache_precision == embedding_precision.
    # Both are represented as uint8_t
    cache_state = construct_cache_state(rows, locations, self.feature_table_map)

    if self.record_cache_metrics.record_tablewise_cache_miss:
        num_tables = len(cache_state.cache_hash_size_cumsum) - 1
        self.register_buffer(
            "table_wise_cache_miss",
            torch.zeros(
                num_tables,
                device=self.current_device,
                dtype=torch.int64,
            ),
        )
    else:
        # NOTE: make TorchScript work!
        # The buffer is registered (empty) even when table-wise metrics are
        # disabled.
        self.register_buffer(
            "table_wise_cache_miss",
            torch.zeros(
                0,
                device=self.current_device,
                dtype=torch.int64,
            ),
        )

    # Uses the possibly HIP-adjusted associativity computed at the top.
    self.cache_assoc = cache_assoc
    self._apply_cache_state(
        cache_state,
        cache_algorithm,
        cache_load_factor,
        cache_sets,
        cache_reserved_memory,
    )

    # FP8 parameters: fall back to the SparseType defaults when not given;
    # use -1 for both when no table is stored as FP8.
    if self.max_float8_D > 0:
        default_config = SparseType.FP8.default_config()
        self.fp8_exponent_bits: int = (
            default_config.get("exponent_bits")
            if fp8_exponent_bits is None
            else fp8_exponent_bits
        )
        self.fp8_exponent_bias: int = (
            default_config.get("exponent_bias")
            if fp8_exponent_bias is None
            else fp8_exponent_bias
        )
    else:
        self.fp8_exponent_bits = -1
        self.fp8_exponent_bias = -1

    self.bounds_check_version: int = get_bounds_check_version_for_platform()
|
|
640
|
+
|
|
641
|
+
@torch.jit.ignore
def log(self, msg: str) -> None:
    """
    Emit an INFO-level log line prefixed with this instance's TBE id, so
    that output from several TBE instances in one process can be told apart.

    Args:
        msg (str): The message to log

    Returns:
        None
    """
    tagged_msg = f"[TBE={self.uuid}] {msg}"
    logging.info(tagged_msg)
|
|
654
|
+
|
|
655
|
+
def get_cache_miss_counter(self) -> Tensor:
    """
    Return the cache miss counter tensor.

    Counter layout:
        [0]: number of forwards that had at least one cache miss
        [1]: number of unique (deduplicated) cache misses
        [2]: number of unique (deduplicated) accesses
        [3]: number of accesses without deduplication

    Miss ratios can be derived from it as:
        counter[1] / counter[2]: missed entries over unique requests
        counter[1] / counter[3]: missed entries over total accesses

    Raises:
        AssertionError: if `record_cache_metrics.record_cache_miss_counter`
            is not enabled
    """
    counter_enabled = self.record_cache_metrics.record_cache_miss_counter
    assert (
        counter_enabled
    ), "record_cache_miss_counter should be true to access counter values"
    return self.cache_miss_counter
|
|
669
|
+
|
|
670
|
+
@torch.jit.export
def get_table_wise_cache_miss(self) -> Tensor:
    """
    Return the per-table cache miss counts.

    `table_wise_cache_miss` holds the cache miss count for each table in
    this embedding table object.

    Raises:
        AssertionError: if `record_cache_metrics.record_tablewise_cache_miss`
            is not enabled
    """
    tablewise_enabled = self.record_cache_metrics.record_tablewise_cache_miss
    assert (
        tablewise_enabled
    ), "record_tablewise_cache_miss should be true to access counter values"
    return self.table_wise_cache_miss
|
|
677
|
+
|
|
678
|
+
@torch.jit.export
def get_feature_num_per_table(self) -> list[int]:
    """
    Return the number of feature names recorded for each table.

    Returns:
        One entry per item of `feature_names_per_table`, or an empty list
        when `feature_names_per_table` is None.
    """
    # Bind to a local first so the None check narrows the Optional.
    names_per_table = self.feature_names_per_table
    if names_per_table is None:
        return []
    return [len(names) for names in names_per_table]
|
|
683
|
+
|
|
684
|
+
def reset_cache_miss_counter(self) -> None:
    """
    Replace `cache_miss_counter` with a fresh 4-element int64 tensor of
    zeros on `current_device`.

    Raises:
        AssertionError: if `record_cache_metrics.record_cache_miss_counter`
            is not enabled
    """
    counter_enabled = self.record_cache_metrics.record_cache_miss_counter
    assert (
        counter_enabled
    ), "record_cache_miss_counter should be true to access counter values"
    # A new tensor is bound rather than zeroing in place, so earlier
    # references to the old counter keep their values.
    cleared = torch.tensor([0] * 4, device=self.current_device, dtype=torch.int64)
    self.cache_miss_counter = cleared
|
|
691
|
+
|
|
692
|
+
def reset_uvm_cache_stats(self) -> None:
    """
    Zero both the accumulated and the local UVM cache statistics in place.

    Raises:
        AssertionError: if `gather_uvm_cache_stats` is not enabled
    """
    stats_enabled = self.gather_uvm_cache_stats
    assert (
        stats_enabled
    ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
    for stats in (self.uvm_cache_stats, self.local_uvm_cache_stats):
        stats.zero_()
|
|
698
|
+
|
|
699
|
+
def print_cache_miss_counter(self) -> None:
|
|
700
|
+
assert (
|
|
701
|
+
self.record_cache_metrics.record_cache_miss_counter
|
|
702
|
+
), "record_cache_miss_counter should be true to access counter values"
|
|
703
|
+
self.log(
|
|
704
|
+
f"\n"
|
|
705
|
+
f"Miss counter value [0] - # of miss occured iters : {self.cache_miss_counter[0]}, \n"
|
|
706
|
+
f"Miss counter value [1] - # of unique misses : {self.cache_miss_counter[1]}, \n"
|
|
707
|
+
f"Miss counter value [2] - # of unique requested indices : {self.cache_miss_counter[2]}, \n"
|
|
708
|
+
f"Miss counter value [3] - # of total requested indices : {self.cache_miss_counter[3]}, "
|
|
709
|
+
)
|
|
710
|
+
self.log(
|
|
711
|
+
f"unique_miss_rate using counter : {self.cache_miss_counter[1] / self.cache_miss_counter[2]}, \n"
|
|
712
|
+
)
|
|
713
|
+
self.log(
|
|
714
|
+
f"total_miss_rate using counter : {self.cache_miss_counter[1] / self.cache_miss_counter[3]}, \n"
|
|
715
|
+
)
|
|
716
|
+
|
|
717
|
+
def get_uvm_cache_stats(self) -> Tensor:
    """
    Return the accumulated UVM cache statistics tensor.

    Raises:
        AssertionError: if `gather_uvm_cache_stats` is not enabled
    """
    stats_enabled = self.gather_uvm_cache_stats
    assert (
        stats_enabled
    ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
    return self.uvm_cache_stats
|
|
722
|
+
|
|
723
|
+
def print_uvm_cache_stats(self) -> None:
    """
    Log the first six entries of `uvm_cache_stats` through `self.log`.
    When the requested-indices entry is non-zero, also log the ratios of
    unique indices and of unique misses to requested indices.

    Raises:
        AssertionError: if `gather_uvm_cache_stats` is not enabled
    """
    stats_enabled = self.gather_uvm_cache_stats
    assert (
        stats_enabled
    ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."

    stats = self.uvm_cache_stats.tolist()
    n_called = stats[0]
    n_requested = stats[1]
    n_unique = stats[2]
    n_unique_misses = stats[3]
    n_conflict_unique_misses = stats[4]
    n_conflict_misses = stats[5]

    self.log(
        f"N_called: {n_called}\n"
        f"N_requested_indices: {n_requested}\n"
        f"N_unique_indices: {n_unique}\n"
        f"N_unique_misses: {n_unique_misses}\n"
        f"N_conflict_unique_misses: {n_conflict_unique_misses}\n"
        f"N_conflict_misses: {n_conflict_misses}\n"
    )

    # No ratios without any requested index (avoids dividing by zero).
    if not n_requested:
        return
    self.log(
        f"unique indices / requested indices: {n_unique / n_requested}\n"
        f"unique misses / requested indices: {n_unique_misses / n_requested}\n"
    )
|
|
741
|
+
|
|
742
|
+
@torch.jit.export
def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
    """
    Advance the timestep / prefetch counters and populate the LXU cache
    for the given lookup.

    Returns immediately after bumping the counters when `lxu_cache_weights`
    is empty. Otherwise the indices are linearized, the enabled cache miss
    metrics are updated, and population is dispatched on `cache_assoc`.

    Args:
        indices (Tensor): Lookup indices
        offsets (Tensor): Lookup offsets

    Raises:
        ValueError: if `cache_assoc` is not one of 1, 32 or 64
    """
    self.timestep_counter.increment()
    self.timestep_prefetch_size.increment()
    # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
    #  a function.
    if self.lxu_cache_weights.numel() == 0:
        return

    linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
        self.cache_hash_size_cumsum,
        indices,
        offsets,
    )

    track_total_misses = self.record_cache_metrics.record_cache_miss_counter
    track_tablewise_misses = self.record_cache_metrics.record_tablewise_cache_miss
    if track_total_misses or track_tablewise_misses:
        # Look up the cache state before population, so that the metrics
        # see which indices are currently missing.
        if self.cache_assoc in [32, 64]:
            lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
                linear_cache_indices,
                self.lxu_cache_state,
                self.total_cache_hash_size,
            )
        else:
            lxu_cache_locations = torch.ops.fbgemm.direct_mapped_lxu_cache_lookup(
                linear_cache_indices,
                self.lxu_cache_state,
                self.total_cache_hash_size,
            )
        if track_total_misses:
            self._update_cache_miss_counter(
                lxu_cache_locations, linear_cache_indices
            )
        if track_tablewise_misses:
            self._update_tablewise_cache_miss(
                lxu_cache_locations, linear_cache_indices, offsets
            )

    if self.cache_assoc in [32, 64]:
        # 64 for AMD
        self.prefetch_32way(linear_cache_indices)
    elif self.cache_assoc == 1:
        self.prefetch_1way(linear_cache_indices)
    else:
        raise ValueError(f"{self.cache_assoc} not in [1, 32, 64]")
|
|
790
|
+
|
|
791
|
+
def prefetch_32way(self, linear_cache_indices: Tensor) -> None:
    """
    Populate the set-associative cache for the given linearized indices
    with the op matching `cache_algorithm` (LRU or LFU), then queue the
    resulting cache locations on `lxu_cache_locations_list`.

    Args:
        linear_cache_indices (Tensor): Linearized cache indices to prefetch

    Raises:
        AssertionError: if the queue of pending cache locations has already
            reached `max_prefetch_depth`
    """
    algorithm = self.cache_algorithm
    if algorithm == CacheAlgorithm.LRU:
        torch.ops.fbgemm.lru_cache_populate_byte(
            self.weights_uvm,
            self.cache_hash_size_cumsum,
            self.total_cache_hash_size,
            self.cache_index_table_map,
            self.weights_offsets,
            self.weights_tys,
            self.D_offsets,
            linear_cache_indices,
            self.lxu_cache_state,
            self.lxu_cache_weights,
            self.timestep_counter.get(),
            self.lxu_state,
            16,  # row_alignment; using default value.
            self.gather_uvm_cache_stats,
            self.local_uvm_cache_stats,
        )
    elif algorithm == CacheAlgorithm.LFU:
        torch.ops.fbgemm.lfu_cache_populate_byte(
            self.weights_uvm,
            self.cache_hash_size_cumsum,
            self.total_cache_hash_size,
            self.cache_index_table_map,
            self.weights_offsets,
            self.weights_tys,
            self.D_offsets,
            linear_cache_indices,
            self.lxu_cache_state,
            self.lxu_cache_weights,
            self.lxu_state,
        )

    pending = self.lxu_cache_locations_list.size()
    assert (
        pending < self.max_prefetch_depth
    ), f"self.lxu_cache_locations_list has grown to size: {pending}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"

    # Look up after population so the queued locations reflect the rows
    # that were just brought into the cache.
    populated_locations = torch.ops.fbgemm.lxu_cache_lookup(
        linear_cache_indices,
        self.lxu_cache_state,
        self.total_cache_hash_size,
        self.gather_uvm_cache_stats,
        self.local_uvm_cache_stats,
    )
    self.lxu_cache_locations_list.push(populated_locations)

    if self.gather_uvm_cache_stats:
        self._accumulate_uvm_cache_stats()
|
|
839
|
+
|
|
840
|
+
def prefetch_1way(self, linear_cache_indices: Tensor) -> None:
    """
    Populate the direct-mapped cache for the given linearized indices and
    queue the resulting cache locations on `lxu_cache_locations_list`.

    Args:
        linear_cache_indices (Tensor): Linearized cache indices to prefetch

    Raises:
        ValueError: if `cache_algorithm` is not LRU
        AssertionError: if the queue of pending cache locations has already
            reached `max_prefetch_depth`
    """
    if self.cache_algorithm != CacheAlgorithm.LRU:
        raise ValueError("Direct Mapped for LRU only")

    torch.ops.fbgemm.direct_mapped_lru_cache_populate_byte(
        self.weights_uvm,
        self.cache_hash_size_cumsum,
        self.total_cache_hash_size,
        self.cache_index_table_map,
        self.weights_offsets,
        self.weights_tys,
        self.D_offsets,
        linear_cache_indices,
        self.lxu_cache_state,
        self.lxu_cache_weights,
        self.timestep_counter.get(),
        self.lxu_state,
        self.lxu_cache_miss_timestamp,
        16,  # row_alignment; using default value.
        self.gather_uvm_cache_stats,
        self.local_uvm_cache_stats,
    )

    pending = self.lxu_cache_locations_list.size()
    assert (
        pending < self.max_prefetch_depth
    ), f"self.lxu_cache_locations_list has grown to size: {pending}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"

    # Look up after population so the queued locations reflect the rows
    # that were just brought into the cache.
    populated_locations = torch.ops.fbgemm.direct_mapped_lxu_cache_lookup(
        linear_cache_indices,
        self.lxu_cache_state,
        self.total_cache_hash_size,
        self.gather_uvm_cache_stats,
        self.local_uvm_cache_stats,
    )
    self.lxu_cache_locations_list.push(populated_locations)

    if self.gather_uvm_cache_stats:
        self._accumulate_uvm_cache_stats()
|
|
877
|
+
|
|
878
|
+
def _accumulate_uvm_cache_stats(self) -> None:
|
|
879
|
+
# Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
|
|
880
|
+
# We may wanna do this accumulation atomically, but as it's only for monitoring,
|
|
881
|
+
# slightly inaccurate result may be acceptable.
|
|
882
|
+
self.uvm_cache_stats = torch.add(
|
|
883
|
+
self.uvm_cache_stats, self.local_uvm_cache_stats
|
|
884
|
+
)
|
|
885
|
+
self.local_uvm_cache_stats.zero_()
|
|
886
|
+
|
|
887
|
+
def _update_cache_miss_counter(
|
|
888
|
+
self,
|
|
889
|
+
lxu_cache_locations: Tensor,
|
|
890
|
+
linear_cache_indices: Tensor,
|
|
891
|
+
) -> None:
|
|
892
|
+
CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32)
|
|
893
|
+
CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32)
|
|
894
|
+
|
|
895
|
+
cache_missed_locations = torch.where(
|
|
896
|
+
lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT
|
|
897
|
+
)
|
|
898
|
+
unique_ids_list = torch.unique(cache_missed_locations)
|
|
899
|
+
unique_ids_count_list = torch.where(unique_ids_list == CACHE_HIT, 0, 1)
|
|
900
|
+
|
|
901
|
+
miss_count = torch.sum(unique_ids_count_list)
|
|
902
|
+
|
|
903
|
+
self.cache_miss_counter[0] += (miss_count > 0).to(torch.int64)
|
|
904
|
+
|
|
905
|
+
self.cache_miss_counter[1] += miss_count
|
|
906
|
+
|
|
907
|
+
# Number of unique requests
|
|
908
|
+
assert (
|
|
909
|
+
len(linear_cache_indices.size()) == 1
|
|
910
|
+
), f"linear_cache_indices should be 1-D was {len(linear_cache_indices.size())}-D"
|
|
911
|
+
|
|
912
|
+
assert (
|
|
913
|
+
self.cache_miss_counter.size()[0] == 4
|
|
914
|
+
), f"self.cache_miss_counter should be 4-D was {self.cache_miss_counter.size()[0]}-D"
|
|
915
|
+
|
|
916
|
+
self.cache_miss_counter[2] += torch.unique(linear_cache_indices).size()[0]
|
|
917
|
+
|
|
918
|
+
# Number of total requests
|
|
919
|
+
self.cache_miss_counter[3] += linear_cache_indices.size()[0]
|
|
920
|
+
|
|
921
|
+
def _update_tablewise_cache_miss(
|
|
922
|
+
self,
|
|
923
|
+
lxu_cache_locations: Tensor,
|
|
924
|
+
linear_cache_indices: Tensor,
|
|
925
|
+
offsets: Tensor,
|
|
926
|
+
) -> None:
|
|
927
|
+
CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32)
|
|
928
|
+
CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32)
|
|
929
|
+
|
|
930
|
+
# pyre-fixme[6]: For 1st argument expected
|
|
931
|
+
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Module, Tensor]`.
|
|
932
|
+
num_tables = len(self.cache_hash_size_cumsum) - 1
|
|
933
|
+
num_offsets_per_table = (len(offsets) - 1) // num_tables
|
|
934
|
+
cache_missed_locations = torch.where(
|
|
935
|
+
lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT
|
|
936
|
+
)
|
|
937
|
+
|
|
938
|
+
for i in range(num_tables):
|
|
939
|
+
start = offsets[i * num_offsets_per_table]
|
|
940
|
+
end = offsets[(i + 1) * num_offsets_per_table]
|
|
941
|
+
|
|
942
|
+
current_cache_missed_locations = cache_missed_locations[start:end]
|
|
943
|
+
unique_ids_list = torch.unique(current_cache_missed_locations)
|
|
944
|
+
unique_ids_count_list = torch.where(unique_ids_list == CACHE_HIT, 0, 1)
|
|
945
|
+
|
|
946
|
+
miss_count = torch.sum(unique_ids_count_list)
|
|
947
|
+
|
|
948
|
+
self.table_wise_cache_miss[i] += miss_count
|
|
949
|
+
|
|
950
|
+
def _forward_impl(
    self,
    indices: Tensor,
    offsets: Tensor,
    per_sample_weights: Optional[Tensor] = None,
) -> Tensor:
    """
    Run the embedding lookup.

    Steps, in order: move inputs to the device, bounds-check against the
    original (pre-pruning) rows when pruning is enabled, remap indices,
    prefetch into the cache if no prefetch is outstanding, bounds-check
    against the pruned rows, then call the lookup op.

    Args:
        indices (Tensor): Lookup indices
        offsets (Tensor): Lookup offsets
        per_sample_weights (Optional[Tensor]): Optional per-sample weights

    Returns:
        The output of `int_nbit_split_embedding_codegen_lookup_function`

    Raises:
        AssertionError: if the weights have not been initialized
    """
    assert (
        self.weight_initialized
    ), "weight needs to be initialized before forward function"

    indices, offsets, per_sample_weights = inputs_to_device(
        indices, offsets, per_sample_weights, self.bounds_check_warning
    )

    bounds_check_enabled = self.bounds_check_mode_int != BoundsCheckMode.NONE.value
    pruning_enabled = (
        self.index_remapping_hash_table_cpu is not None
        or self.index_remapping_hash_table.numel() > 0
        or self.index_remappings_array.numel() > 0
    )

    # First bound check: check if the indices/offsets are within the boundary
    # of the original embedding rows before pruning.
    # Note that this is only applied when we enable pruning (if the perf becomes
    # an issue, we can fuse it inside the remapping kernel).
    if pruning_enabled and bounds_check_enabled:
        torch.ops.fbgemm.bounds_check_indices(
            self.original_rows_per_table,
            indices,
            offsets,
            self.bounds_check_mode_int,
            self.bounds_check_warning,
            per_sample_weights,
            bounds_check_version=self.bounds_check_version,
        )

    # Index remapping changes input indices, and some of them become -1 (pruned rows).
    # Hence, remapping should be done before prefetch and emb lookup
    # so that these operations are with the remapped indices.
    # The `is not None` test stays inside the `if` so the Optional attribute
    # is narrowed for the `.lookup` call.
    if self.index_remapping_hash_table_cpu is not None:
        indices = self.index_remapping_hash_table_cpu.lookup(indices, offsets)
    elif self.index_remapping_hash_table.numel() > 0:
        # Convert from raw indices to pruned indices
        indices = torch.ops.fbgemm.pruned_hashmap_lookup(
            indices,
            offsets,
            self.index_remapping_hash_table,
            self.index_remapping_hash_table_offsets,
        )
    elif self.index_remappings_array.numel() > 0:
        indices = torch.ops.fbgemm.pruned_array_lookup(
            indices,
            offsets,
            self.index_remappings_array,
            self.index_remappings_array_offsets,
        )

    # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
    #  a function.
    if self.lxu_cache_weights.numel() > 0:
        # Prefetch here only when no earlier prefetch() call is pending.
        if self.timestep_prefetch_size.get() <= 0:
            self.prefetch(indices, offsets)
        self.timestep_prefetch_size.decrement()

    lxu_cache_locations = self.lxu_cache_locations_list.pop()

    # Second bound check: check if the indices/offsets are within the boundary
    # of the pruned embedding rows after pruning.
    # Note: we cast to int as a TorchScript workaround.
    if bounds_check_enabled:
        torch.ops.fbgemm.bounds_check_indices(
            self.rows_per_table,
            indices,
            offsets,
            self.bounds_check_mode_int,
            self.bounds_check_warning,
            per_sample_weights,
            bounds_check_version=self.bounds_check_version,
        )

    # Note: CPU and CUDA ops use the same interface to facilitate JIT IR
    # generation for CUDA/CPU. For CPU op, we don't need weights_uvm and
    # weights_placements
    dev_weights = self.weights_host if self.host_size > 0 else self.weights_dev
    return torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(
        dev_weights=dev_weights,
        uvm_weights=self.weights_uvm,
        weights_placements=self.weights_placements,
        weights_offsets=self.weights_offsets,
        weights_tys=self.weights_tys,
        D_offsets=self.D_offsets,
        total_D=self.total_D,
        max_int2_D=self.max_int2_D,
        max_int4_D=self.max_int4_D,
        max_int8_D=self.max_int8_D,
        max_float16_D=self.max_float16_D,
        max_float32_D=self.max_float32_D,
        indices=indices,
        offsets=offsets,
        pooling_mode=int(self.pooling_mode),
        indice_weights=per_sample_weights,
        output_dtype=self.output_dtype,
        lxu_cache_weights=self.lxu_cache_weights,
        lxu_cache_locations=lxu_cache_locations,
        row_alignment=self.row_alignment,
        max_float8_D=self.max_float8_D,
        fp8_exponent_bits=self.fp8_exponent_bits,
        fp8_exponent_bias=self.fp8_exponent_bias,
    )
|
|
1054
|
+
|
|
1055
|
+
def forward(
    self,
    indices: Tensor,
    offsets: Tensor,
    per_sample_weights: Optional[Tensor] = None,
) -> Tensor:
    """
    Module entry point; delegates to `_forward_impl` with the same
    arguments and returns its result.

    Args:
        indices (Tensor): Lookup indices
        offsets (Tensor): Lookup offsets
        per_sample_weights (Optional[Tensor]): Optional per-sample weights
    """
    output = self._forward_impl(
        indices=indices,
        offsets=offsets,
        per_sample_weights=per_sample_weights,
    )
    return output
|
|
1064
|
+
|
|
1065
|
+
def initialize_logical_weights_placements_and_offsets(
    self,
) -> None:
    """
    Build the per-feature `weights_offsets` (int64) and
    `weights_placements` (int32) tensors on `current_device` by looking up
    each entry of `feature_table_map` in the per-table physical lists.

    Raises:
        AssertionError: if the physical offsets, physical placements and
            embedding specs do not all have the same length
    """
    num_physical_tables = len(self.weights_physical_offsets)
    assert num_physical_tables == len(self.embedding_specs)
    assert num_physical_tables == len(
        self.weights_physical_placements
    )

    logical_offsets: list[int] = []
    logical_placements: list[int] = []
    for table_idx in self.feature_table_map:
        logical_offsets.append(self.weights_physical_offsets[table_idx])
        logical_placements.append(self.weights_physical_placements[table_idx])

    self.weights_offsets = torch.tensor(
        logical_offsets, device=self.current_device, dtype=torch.int64
    )
    self.weights_placements = torch.tensor(
        logical_placements, device=self.current_device, dtype=torch.int32
    )
|
|
1082
|
+
|
|
1083
|
+
def initialize_physical_weights_placements_and_offsets(
    self,
    cacheline_alignment: bool = True,
) -> None:
    """
    Initialize the physical weights placements and offsets and the
    host/dev/uvm sizes from the split state computed for `embedding_specs`.

    Args:
        cacheline_alignment (bool): Forwarded to
            `nbit_construct_split_state`
    """
    split_state: SplitState = nbit_construct_split_state(
        self.embedding_specs,
        cacheable=True,
        row_alignment=self.row_alignment,
        scale_bias_size_in_bytes=self.scale_bias_size_in_bytes,
        cacheline_alignment=cacheline_alignment,
    )

    self.host_size = split_state.host_size
    self.dev_size = split_state.dev_size
    self.uvm_size = split_state.uvm_size

    self.weights_physical_offsets = split_state.offsets
    # Placements are stored as plain values rather than enum members.
    self.weights_physical_placements = [
        placement.value for placement in split_state.placements
    ]
|
|
1101
|
+
|
|
1102
|
+
@torch.jit.export
def reset_weights_placements_and_offsets(
    self, device: torch.device, location: int
) -> None:
    """
    Point the module at a new device/location and recompute the physical
    and logical weights placements and offsets, without initializing the
    large dev weights tensor.

    Args:
        device (torch.device): New current device
        location (int): Integer value of the target `EmbeddingLocation`

    Raises:
        KeyError: if `location` is not the value of a known location
    """
    # Overwrite location in embedding_specs with new location
    # Use map since can't script enum call (ie. EmbeddingLocation(value))
    location_by_value = {
        EmbeddingLocation.DEVICE.value: EmbeddingLocation.DEVICE,
        EmbeddingLocation.MANAGED.value: EmbeddingLocation.MANAGED,
        EmbeddingLocation.MANAGED_CACHING.value: EmbeddingLocation.MANAGED_CACHING,
        EmbeddingLocation.HOST.value: EmbeddingLocation.HOST,
        EmbeddingLocation.MTIA.value: EmbeddingLocation.MTIA,
    }
    target_location = location_by_value[location]
    targets_mtia = target_location == EmbeddingLocation.MTIA

    if targets_mtia:
        self.scale_bias_size_in_bytes = 8

    # Reset device/location denoted in embedding specs
    self.reset_embedding_spec_location(device, target_location)

    # Initialize all physical/logical weights placements and offsets without
    # initializing large dev weights tensor. Cacheline alignment is used
    # for every location except MTIA.
    self.initialize_physical_weights_placements_and_offsets(
        cacheline_alignment=not targets_mtia
    )
    self.initialize_logical_weights_placements_and_offsets()
|
|
1125
|
+
|
|
1126
|
+
def reset_embedding_spec_location(
    self, device: torch.device, target_location: EmbeddingLocation
) -> None:
    """
    Set the current device, choose the row alignment for the target
    location (1 for HOST and MTIA, 16 otherwise) and rewrite the location
    field (the fifth entry) of every embedding spec.

    Args:
        device (torch.device): New current device
        target_location (EmbeddingLocation): Location to store in each spec
    """
    self.current_device = device

    if (
        target_location == EmbeddingLocation.HOST
        or target_location == EmbeddingLocation.MTIA
    ):
        self.row_alignment = 1
    else:
        self.row_alignment = 16

    self.embedding_specs = [
        (old_spec[0], old_spec[1], old_spec[2], old_spec[3], target_location)
        for old_spec in self.embedding_specs
    ]
|
|
1140
|
+
|
|
1141
|
+
@torch.jit.export
def recompute_module_buffers(self) -> None:
    """
    Compute module buffers that are on the meta device and are not
    materialized in reset_weights_placements_and_offsets(). Currently those
    buffers are `weights_tys`, `rows_per_table`, `D_offsets` and
    `bounds_check_warning`. Pruning related or uvm related buffers are not
    computed right now; they are replaced by uninitialized tensors of the
    same shape on the current device.

    Does nothing when the current device is the meta device or when
    `weights_tys` already lives on the current device.
    """
    target_device = self.current_device
    if target_device.type == "meta":
        return
    if self.weights_tys.device == target_device:
        return

    # Per-table values, indexed by table id. Spec fields 1, 2 and 3 are
    # read here as rows, dims and sparse type respectively.
    table_tys = [sparse_type_to_int(spec[3]) for spec in self.embedding_specs]
    table_rows = [spec[1] for spec in self.embedding_specs]
    table_dims = [spec[2] for spec in self.embedding_specs]

    self.weights_tys = torch.tensor(
        [table_tys[t] for t in self.feature_table_map],
        device=target_device,
        dtype=torch.uint8,
    )
    self.rows_per_table = torch.tensor(
        [table_rows[t] for t in self.feature_table_map],
        device=target_device,
        dtype=torch.int64,
    )

    # Running sum of the feature dims, starting at 0.
    running_total = 0
    D_offsets_list = [0]
    for t in self.feature_table_map:
        running_total += table_dims[t]
        D_offsets_list.append(running_total)
    self.D_offsets = torch.tensor(
        D_offsets_list, device=target_device, dtype=torch.int32
    )

    self.bounds_check_warning = torch.tensor(
        [0], device=target_device, dtype=torch.int64
    )

    # For pruning related or uvm related buffers, we just set them as empty tensors.
    self.index_remapping_hash_table = torch.empty_like(
        self.index_remapping_hash_table, device=target_device
    )
    self.index_remapping_hash_table_offsets = torch.empty_like(
        self.index_remapping_hash_table_offsets, device=target_device
    )
    self.index_remappings_array = torch.empty_like(
        self.index_remappings_array, device=target_device
    )
    self.index_remappings_array_offsets = torch.empty_like(
        self.index_remappings_array_offsets, device=target_device
    )
    # pyre-fixme[16]: `IntNBitTableBatchedEmbeddingBagsCodegen` has no attribute
    #  `lxu_cache_weights`.
    self.lxu_cache_weights = torch.empty_like(
        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        self.lxu_cache_weights,
        device=target_device,
    )
    self.original_rows_per_table = torch.empty_like(
        self.original_rows_per_table, device=target_device
    )
    self.table_wise_cache_miss = torch.empty_like(
        self.table_wise_cache_miss, device=target_device
    )
    self.weights_uvm = torch.empty_like(
        self.weights_uvm, device=target_device
    )
|
|
1208
|
+
|
|
1209
|
+
def _apply_split(
|
|
1210
|
+
self,
|
|
1211
|
+
dev_size: int,
|
|
1212
|
+
host_size: int,
|
|
1213
|
+
uvm_size: int,
|
|
1214
|
+
placements: list[int],
|
|
1215
|
+
offsets: list[int],
|
|
1216
|
+
enforce_hbm: bool,
|
|
1217
|
+
) -> None:
|
|
1218
|
+
assert not self.weight_initialized, "Weights have already been initialized."
|
|
1219
|
+
self.weight_initialized = True
|
|
1220
|
+
self.weights_physical_placements = placements
|
|
1221
|
+
self.weights_physical_offsets = offsets
|
|
1222
|
+
|
|
1223
|
+
self.host_size = host_size
|
|
1224
|
+
self.dev_size = dev_size
|
|
1225
|
+
self.uvm_size = uvm_size
|
|
1226
|
+
|
|
1227
|
+
self.initialize_logical_weights_placements_and_offsets()
|
|
1228
|
+
|
|
1229
|
+
if dev_size > 0:
|
|
1230
|
+
self.weights_dev = torch.zeros(
|
|
1231
|
+
dev_size,
|
|
1232
|
+
device=self.current_device,
|
|
1233
|
+
dtype=torch.uint8,
|
|
1234
|
+
)
|
|
1235
|
+
|
|
1236
|
+
if host_size > 0:
|
|
1237
|
+
self.weights_host = torch.zeros(
|
|
1238
|
+
host_size, device=self.current_device, dtype=torch.uint8
|
|
1239
|
+
)
|
|
1240
|
+
|
|
1241
|
+
if uvm_size > 0:
|
|
1242
|
+
assert not self.use_cpu
|
|
1243
|
+
if enforce_hbm:
|
|
1244
|
+
if not torch.jit.is_scripting():
|
|
1245
|
+
self.log("Enforce hbm for the cache location")
|
|
1246
|
+
self.weights_uvm = torch.zeros(
|
|
1247
|
+
uvm_size,
|
|
1248
|
+
device=self.current_device,
|
|
1249
|
+
dtype=torch.uint8,
|
|
1250
|
+
)
|
|
1251
|
+
else:
|
|
1252
|
+
self.weights_uvm = torch.zeros(
|
|
1253
|
+
uvm_size,
|
|
1254
|
+
out=torch.ops.fbgemm.new_unified_tensor(
|
|
1255
|
+
torch.zeros(1, device=self.D_offsets.device, dtype=torch.uint8),
|
|
1256
|
+
[uvm_size],
|
|
1257
|
+
self.uvm_host_mapped,
|
|
1258
|
+
),
|
|
1259
|
+
)
|
|
1260
|
+
|
|
1261
|
+
def _apply_cache_state(
    self,
    cache_state: CacheState,
    cache_algorithm: CacheAlgorithm,
    cache_load_factor: float,
    cache_sets: int,
    cache_reserved_memory: float,
) -> None:
    """
    Create the counters, queues and buffers used by the on-device cache.

    When ``cache_state.total_cache_hash_size`` is 0 or the module runs in CPU
    mode, only small placeholder buffers are registered (so that the same
    attribute names exist) and the method returns early. Otherwise the number
    of cache sets is resolved and the real cache buffers are registered.

    Args:
        cache_state: provides ``total_cache_hash_size``,
            ``cache_hash_size_cumsum`` and ``cache_index_table_map``
        cache_algorithm: stored on the module; must be LRU or LFU
        cache_load_factor: must be > 0 when a cache is built; used to derive
            the number of sets when ``cache_sets`` is not positive
        cache_sets: requested number of cache sets; a value <= 0 means
            "derive it from ``cache_load_factor`` and the free memory"
        cache_reserved_memory: amount subtracted (after ``int()``) from the
            memory budget when deriving ``cache_sets``

    Raises:
        AssertionError: on an unsupported ``cache_assoc``, a non-positive
            ``cache_load_factor``, no free memory, or too many sets for LFU.
        ValueError: if ``cache_algorithm`` is neither LRU nor LFU.
    """
    assert self.cache_assoc in [
        1,
        32,
        64,
    ], "Only 1-way or 32-way(64-way for AMD) implmeneted for now"

    self.cache_algorithm = cache_algorithm
    # pyre-ignore[16]
    self.timestep_counter = torch.classes.fbgemm.AtomicCounter()
    # pyre-ignore[16]
    self.timestep_prefetch_size = torch.classes.fbgemm.AtomicCounter()

    self.max_prefetch_depth = MAX_PREFETCH_DEPTH

    # Seed tensor handed to the TensorQueue below; it has zero elements, so
    # fill_(-1) only fixes dtype/device.
    if self.current_device.type == "meta":
        # To resolve "Cannot copy out of meta tensor; no data!" error
        lxu_cache_locations_empty = torch.empty(0, dtype=torch.int32).fill_(-1)
    else:
        lxu_cache_locations_empty = torch.empty(
            0, device=self.current_device, dtype=torch.int32
        ).fill_(-1)
    # pyre-ignore[16]
    self.lxu_cache_locations_list = torch.classes.fbgemm.TensorQueue(
        lxu_cache_locations_empty
    )

    # NOTE: no cache for CPU mode!
    if cache_state.total_cache_hash_size == 0 or self.use_cpu:
        # Placeholder buffers only: an empty cache plus 1-element (or
        # stats-sized) tensors for every cache-related attribute.
        self.register_buffer(
            "lxu_cache_weights",
            torch.zeros(0, 0, device=self.current_device, dtype=torch.uint8),
        )
        # NOTE: make TorchScript work!
        self.register_buffer(
            "cache_hash_size_cumsum",
            torch.zeros(1, dtype=torch.int64, device=self.current_device),
            persistent=False,
        )
        self.total_cache_hash_size = 0
        # NOTE(review): registered as int64 here but as int32 in the cached
        # path below - confirm the difference is intentional.
        self.register_buffer(
            "cache_index_table_map",
            torch.zeros(1, dtype=torch.int64, device=self.current_device),
            persistent=False,
        )
        self.register_buffer(
            "lxu_cache_state",
            torch.zeros(1, dtype=torch.int64, device=self.current_device),
            persistent=False,
        )
        self.register_buffer(
            "lxu_state",
            torch.zeros(1, dtype=torch.int64, device=self.current_device),
            persistent=False,
        )
        self.register_buffer(
            "lxu_cache_miss_timestamp",
            torch.zeros(1, dtype=torch.int64, device=self.current_device),
            persistent=False,
        )
        self.register_buffer(
            "cache_miss_counter",
            torch.tensor(
                [0, 0, 0, 0], dtype=torch.int64, device=self.current_device
            ),
            persistent=False,
        )
        self.register_buffer(
            "uvm_cache_stats",
            torch.zeros(
                size=(self.uvm_cache_stats_size,),
                device=self.current_device,
                dtype=torch.int64,
            ),
            persistent=False,
        )
        self.register_buffer(
            "local_uvm_cache_stats",
            torch.zeros(
                size=(self.uvm_cache_stats_size,),
                device=self.current_device,
                dtype=torch.int32,
            ),
            persistent=False,
        )
        return

    assert cache_load_factor > 0
    if cache_sets <= 0:
        # Derive the number of sets: budget = total device memory minus what
        # torch has already reserved minus the caller's reservation.
        total_memory = torch.cuda.get_device_properties(
            self.current_device
        ).total_memory
        free_memory = (
            total_memory
            - torch.cuda.memory_reserved(self.current_device)
            - int(cache_reserved_memory)
        )
        assert free_memory > 0
        # Ceil-divide the requested number of cached rows by the associativity.
        cache_sets = (
            int(cache_state.total_cache_hash_size * cache_load_factor)
            + self.cache_assoc
            - 1
        ) // self.cache_assoc
        # Note that element_size has been included in max_D_cache (in Bytes)
        cache_size = cache_sets * self.cache_assoc * self.max_D_cache
        if cache_size > free_memory:
            # Shrink to the largest set count that fits in the budget.
            cache_sets = (
                int(1.0 * free_memory / self.max_D_cache) + self.cache_assoc - 1
            ) // self.cache_assoc
        cache_sets = 1 if cache_sets == 0 else cache_sets
    # Effective load factor for the resolved number of sets (used for logging).
    cache_load_factor = (
        1.0 * cache_sets * self.cache_assoc / int(cache_state.total_cache_hash_size)
    )
    assert cache_sets > 0
    if cache_algorithm == CacheAlgorithm.LFU:
        assert cache_sets < 2**24 - 1
    cache_size = cache_sets * self.cache_assoc * self.max_D_cache
    self.log(
        f"Using on-device cache with admission algorithm "
        f"{cache_algorithm}, {cache_sets} sets, "
        f"cache_load_factor: {cache_load_factor : .3f}, "
        f"{cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB"
    )

    self.total_cache_hash_size = cache_state.total_cache_hash_size
    self.register_buffer(
        "cache_hash_size_cumsum",
        torch.tensor(
            cache_state.cache_hash_size_cumsum,
            device=self.current_device,
            dtype=torch.int64,
        ),
    )
    self.register_buffer(
        "cache_index_table_map",
        torch.tensor(
            cache_state.cache_index_table_map,
            device=self.current_device,
            dtype=torch.int32,
        ),
    )
    # One int64 entry per (set, way), initialised to -1.
    self.register_buffer(
        "lxu_cache_state",
        torch.zeros(
            cache_sets,
            self.cache_assoc,
            device=self.current_device,
            dtype=torch.int64,
        ).fill_(-1),
    )
    # One row of max_D_cache bytes per cache slot.
    self.register_buffer(
        "lxu_cache_weights",
        torch.zeros(
            cache_sets * self.cache_assoc,
            self.max_D_cache,
            device=self.current_device,
            dtype=torch.uint8,
        ),
    )
    # LFU keeps one entry per hashed row (+1); otherwise one per cache slot.
    self.register_buffer(
        "lxu_state",
        torch.zeros(
            size=(
                (self.total_cache_hash_size + 1,)
                if cache_algorithm == CacheAlgorithm.LFU
                else (cache_sets, self.cache_assoc)
            ),
            device=self.current_device,
            dtype=torch.int64,
        ),
    )
    if self.cache_assoc == 1:
        self.register_buffer(
            "lxu_cache_miss_timestamp",
            torch.zeros(
                cache_sets,
                self.cache_assoc,
                device=self.current_device,
                dtype=torch.int64,
            ),
        )
    else:
        # make TorchScript work
        self.register_buffer(
            "lxu_cache_miss_timestamp",
            torch.zeros(1, device=self.current_device, dtype=torch.int64),
            persistent=False,
        )
    self.register_buffer(
        "cache_miss_counter",
        torch.tensor([0, 0, 0, 0], device=self.current_device, dtype=torch.int64),
    )
    self.register_buffer(
        "uvm_cache_stats",
        torch.zeros(
            size=(self.uvm_cache_stats_size,),
            device=self.current_device,
            dtype=torch.int64,
        ),
        persistent=False,
    )
    self.register_buffer(
        "local_uvm_cache_stats",
        torch.zeros(
            size=(self.uvm_cache_stats_size,),
            device=self.current_device,
            dtype=torch.int32,
        ),
        persistent=False,
    )
    # NOTE(review): the algorithm is validated only after every buffer has
    # been registered, so an invalid value still leaves the buffers in place.
    if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU):
        raise ValueError(
            f"cache_algorithm must be {CacheAlgorithm.LRU} "
            f"or {CacheAlgorithm.LFU}"
        )

    if self.gather_uvm_cache_stats:
        self.reset_uvm_cache_stats()
|
|
1486
|
+
|
|
1487
|
+
def reset_cache_states(self) -> None:
    """
    Clear the cache bookkeeping; does nothing when the cache is empty.

    When ``lxu_cache_weights`` holds at least one element, ``lxu_cache_state``
    is filled with -1, ``lxu_state`` with 0, and the timestep counter is reset.
    """
    # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
    #  a function.
    cache_is_allocated = self.lxu_cache_weights.numel() != 0
    if cache_is_allocated:
        self.lxu_cache_state.fill_(-1)
        self.lxu_state.fill_(0)
        self.timestep_counter.reset()
|
|
1495
|
+
|
|
1496
|
+
def move_to_device_with_cache(
    self, device: torch.device, cache_load_factor: float
) -> None:
    """
    Moves the TBE to the specified device, and updates the cache state accordingly.

    The method snapshots the current per-table weights and the array-based
    pruning tensors, re-derives placements/offsets and module buffers for the
    new device, re-allocates the weight storage, copies the snapshot back in,
    and finally rebuilds the cache state for ``cache_load_factor``.

    Args:
        device: target device
        cache_load_factor: load factor used both to pick the new embedding
            location and to rebuild the cache state

    Nothing is done when both the device and the load factor are unchanged.
    When the current device is "meta" there is no data to carry over, so the
    snapshot/restore steps are skipped.
    """
    # NOTE(review): self.cache_load_factor is read here but is not assigned
    # anywhere in this method - confirm it is kept up to date elsewhere.
    if (
        self.current_device == device
        and self.cache_load_factor == cache_load_factor
    ):
        return

    location = get_new_embedding_location(device, cache_load_factor)
    # NOTE(review): use_cpu is only ever cleared here, never set when the
    # target is "cpu" - presumably handled elsewhere; verify.
    if device.type != "cpu":
        self.use_cpu = False

    weights = self.split_embedding_weights()
    is_meta = self.current_device.type == "meta"
    # Declared up front; only bound (and only read) when not is_meta.
    index_remapping_array: torch.Tensor
    index_remappings_array_offsets: torch.Tensor
    original_rows_per_table: torch.Tensor
    if not is_meta:
        # Record weights and pruning tensors for setting
        # weights and pruning tensors for TBE on new device
        if device.type == "cpu":
            # Only for a CPU target are the snapshot tensors moved eagerly.
            for i, weight in enumerate(weights):
                weights[i] = (
                    weight[0].to(device),
                    # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `to`.
                    weight[1].to(device) if weight[1] is not None else None,
                )
        (
            index_remapping_array,
            index_remappings_array_offsets,
            original_rows_per_table,
        ) = (
            self.index_remappings_array.to(device),
            self.index_remappings_array_offsets.to(device),
            self.original_rows_per_table.to(device),
        )

    self.reset_weights_placements_and_offsets(device, location.value)
    self.recompute_module_buffers()
    # Clear the flag so that initialize_weights() allocates storage again.
    self.weight_initialized = False
    self.initialize_weights()

    # Ensure all weights are on the same device
    if device.type != "cpu":
        self.weights_host = torch.zeros(0, device=device, dtype=torch.uint8)

    if location != EmbeddingLocation.DEVICE:
        self.weights_dev = torch.zeros(0, device=device, dtype=torch.uint8)

    # Re-register any buffer that is still a meta tensor on the new device.
    for name, buf in self.named_buffers():
        if buf.is_meta:
            self.register_buffer(name, tensor_to_device(buf, device))

    self.current_device = device

    if not is_meta:
        # Restore the snapshot taken above into the freshly allocated storage.
        self.assign_embedding_weights(weights)
        self.index_remappings_array = index_remapping_array
        self.index_remappings_array_offsets = index_remappings_array_offsets
        self.original_rows_per_table = original_rows_per_table

    # NOTE(review): the annotation says float, yet None is tolerated here -
    # presumably some callers pass None; confirm.
    if cache_load_factor is not None:
        self.update_cache_load_factor(cache_load_factor)
|
|
1563
|
+
|
|
1564
|
+
def update_cache_load_factor(self, cache_load_factor: float = 0.2) -> None:
    """
    Rebuild the cache state of an already-initialized TBE for a new
    cache_load_factor.

    The embedding location stored in each embedding spec is used as-is, i.e.
    the weights are assumed to already live in the right place.
    """
    rows = []
    locations = []
    for spec in self.embedding_specs:
        rows.append(spec[EmbeddingSpecInfo.rows])
        locations.append(spec[EmbeddingSpecInfo.embedding_location])
    # pyre-ignore[6]
    cache_state = construct_cache_state(rows, locations, self.feature_table_map)

    # Rounded row size in bytes (alignment 16) of every table that is cached.
    cached_dims = []
    for spec in self.embedding_specs:
        if spec[EmbeddingSpecInfo.embedding_location] == (
            EmbeddingLocation.MANAGED_CACHING
        ):
            cached_dims.append(
                rounded_row_size_in_bytes(
                    spec[EmbeddingSpecInfo.dims],  # pyre-ignore[6]
                    spec[EmbeddingSpecInfo.sparse_type],  # pyre-ignore[6]
                    16,
                    self.scale_bias_size_in_bytes,
                )
            )

    # Widest cached row; 0 when no table uses MANAGED_CACHING.
    self.max_D_cache: int = max(cached_dims, default=0)

    self._apply_cache_state(
        cache_state,
        self.cache_algorithm,
        cache_load_factor,
        self.cache_sets,
        self.cache_reserved_memory,
    )
|
|
1601
|
+
|
|
1602
|
+
@torch.jit.export
def split_embedding_weights_with_scale_bias(
    self, split_scale_bias_mode: int = 1
) -> list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
    """
    Returns a list of weights, split by table

    Each entry is a view into the module's weight buffer for that table (no
    copy is made), so writing into the returned tensors modifies the weights.

    split_scale_bias_mode:
        0: return one row;
        1: return weights + scale_bias;
        2: return weights, scale, bias.

    For mode 0 and for FP8/FP16/FP32 tables the second and third tuple
    elements are None. For INT8/INT4/INT2 tables, mode 1 returns
    (weights, scale_bias, None) and mode 2 returns (weights, scale, bias)
    with scale and bias reinterpreted as float16.

    Raises:
        AssertionError: if the weights have not been initialized.
        ValueError: for an unsupported weight type in mode 1 or 2.
    """
    assert self.weight_initialized
    splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
    for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
        # Pick the buffer that physically holds table t.
        placement = self.weights_physical_placements[t]
        if (
            placement == EmbeddingLocation.DEVICE.value
            or placement == EmbeddingLocation.MTIA.value
        ):
            weights = self.weights_dev
        elif placement == EmbeddingLocation.HOST.value:
            weights = self.weights_host
        else:
            weights = self.weights_uvm
        offset = self.weights_physical_offsets[t]
        # Slice this table's bytes out of the flat buffer and view them as
        # (rows, padded row size in bytes).
        weights_shifts = weights.detach()[
            offset : offset
            + rows
            * rounded_row_size_in_bytes(
                dim, weight_ty, self.row_alignment, self.scale_bias_size_in_bytes
            )
        ].view(
            rows,
            rounded_row_size_in_bytes(
                dim, weight_ty, self.row_alignment, self.scale_bias_size_in_bytes
            ),
        )

        if split_scale_bias_mode == 1 or split_scale_bias_mode == 2:
            # remove the padding at the end of each row.
            weights_shifts = weights_shifts[
                :,
                : unpadded_row_size_in_bytes(
                    dim, weight_ty, self.scale_bias_size_in_bytes
                ),
            ]
            if (
                weight_ty.value == SparseType.INT8.value
                or weight_ty.value == SparseType.INT4.value
                or weight_ty.value == SparseType.INT2.value
            ):
                if split_scale_bias_mode == 1:
                    if self.reverse_qparam:
                        # scale_bias occupies the last bytes of each row.
                        splits.append(
                            (
                                weights_shifts[
                                    :, 0 : (0 - self.scale_bias_size_in_bytes)
                                ],
                                weights_shifts[
                                    :, (0 - self.scale_bias_size_in_bytes) :
                                ],
                                None,
                            )
                        )
                    else:
                        # scale_bias occupies the first bytes of each row.
                        splits.append(
                            (
                                weights_shifts[:, self.scale_bias_size_in_bytes :],
                                weights_shifts[:, : self.scale_bias_size_in_bytes],
                                None,
                            )
                        )
                elif split_scale_bias_mode == 2:
                    # The scale_bias bytes are split in half: first half is
                    # scale, second half is bias, each viewed as float16.
                    if self.reverse_qparam:
                        # weights_shifts: [0:-4] is real weights; [-4:-2] is scale; [-2:] is bias
                        splits.append(
                            (
                                weights_shifts[
                                    :, 0 : (0 - self.scale_bias_size_in_bytes)
                                ],
                                weights_shifts[
                                    :,
                                    (0 - self.scale_bias_size_in_bytes) : (
                                        0 - self.scale_bias_size_in_bytes // 2
                                    ),
                                ].view(torch.float16),
                                weights_shifts[
                                    :, (0 - self.scale_bias_size_in_bytes // 2) :
                                ].view(torch.float16),
                            )
                        )
                    else:
                        # weights_shifts: [0:2] is scale; [2:4] is bias; [4:] is real weights
                        splits.append(
                            (
                                weights_shifts[:, self.scale_bias_size_in_bytes :],
                                weights_shifts[
                                    :, : self.scale_bias_size_in_bytes // 2
                                ].view(torch.float16),
                                weights_shifts[
                                    :,
                                    self.scale_bias_size_in_bytes
                                    // 2 : self.scale_bias_size_in_bytes,
                                ].view(torch.float16),
                            )
                        )
                else:
                    # Not reachable given the enclosing mode check.
                    raise ValueError("split_scale_bias_mode is not supported")

            elif (
                weight_ty.value == SparseType.FP8.value
                or weight_ty.value == SparseType.FP16.value
                or weight_ty.value == SparseType.FP32.value
            ):
                # Float tables: the whole unpadded row is returned as weights.
                splits.append(
                    (
                        weights_shifts,
                        None,
                        None,
                    )
                )
            else:
                raise ValueError("weight_ty is not supported")

        else:  # split_scale_bias_mode == 0:
            # Any mode other than 1 or 2 lands here and gets the full padded row.
            splits.append((weights_shifts, None, None))

    return splits
|
|
1730
|
+
|
|
1731
|
+
@torch.jit.export
def split_embedding_weights(
    self,
    split_scale_shifts: bool = True,
    # When true, return list of two tensors, the first with weights and
    # the second with scale_bias.
    # This should've been named as split_scale_bias.
    # Keep as is for backward compatibility.
) -> list[tuple[Tensor, Optional[Tensor]]]:
    """
    Return one (weights, scale_bias) pair per table.

    Delegates to split_embedding_weights_with_scale_bias with mode 1 when
    split_scale_shifts is true and mode 0 otherwise, then drops the third
    element of every returned tuple.
    """
    mode = 1 if split_scale_shifts else 0
    triples: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
        self.split_embedding_weights_with_scale_bias(split_scale_bias_mode=mode)
    )
    pairs: list[tuple[Tensor, Optional[Tensor]]] = []
    for triple in triples:
        pairs.append((triple[0], triple[1]))
    return pairs
|
|
1754
|
+
|
|
1755
|
+
@torch.jit.export
def initialize_weights(self) -> None:
    """
    Allocate the weight buffers once; later calls are no-ops.

    Forwards the sizes, physical placements/offsets and the enforce_hbm flag
    already stored on the module to _apply_split.
    """
    if self.weight_initialized:
        return
    self._apply_split(
        self.dev_size,
        self.host_size,
        self.uvm_size,
        self.weights_physical_placements,
        self.weights_physical_offsets,
        self.enforce_hbm,
    )
    self.weight_initialized = True
|
|
1767
|
+
|
|
1768
|
+
def fill_random_weights(self) -> None:
    """
    Initialize the weights if needed, then overwrite every table's weight
    tensor in place with random values (one table at a time).
    """
    self.initialize_weights()
    for table_split in self.split_embedding_weights():
        table_weight = table_split[0]
        random_quant_scaled_tensor(
            shape=table_weight.shape,
            device=self.current_device,
            output_tensor=table_weight,
        )
|
|
1780
|
+
|
|
1781
|
+
def assign_embedding_weights(
    self, q_weight_list: list[tuple[Tensor, Optional[Tensor]]]
) -> None:
    """
    Copy the given (weights, scale_shifts) pairs into the tensors returned by
    self.split_embedding_weights(), table by table.

    The input must have one entry per table, and an entry carries a
    scale_shifts tensor exactly when the destination table has one.
    """
    destinations = self.split_embedding_weights()
    assert len(q_weight_list) == len(destinations)

    for destination, source in zip(destinations, q_weight_list):
        destination[0].copy_(source[0])
        source_scale_shifts = source[1]
        if source_scale_shifts is None:
            assert destination[1] is None
            continue
        assert destination[1] is not None
        # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `copy_`.
        destination[1].copy_(source_scale_shifts)
|
|
1798
|
+
|
|
1799
|
+
@torch.jit.export
def set_index_remappings_array(
    self,
    index_remapping: list[Tensor],
) -> None:
    """
    Install array-based index remapping from a per-table list of mappings.

    An entry may be None, meaning the table has no remapping. Sets
    index_remappings_array_offsets (running total of mapping lengths, starting
    at 0), original_rows_per_table (per feature: the mapping length, or the
    table's row count when there is no mapping) and index_remappings_array
    (all present mappings concatenated, or an empty tensor).
    """
    table_rows: list[int] = [e[1] for e in self.embedding_specs]
    offsets = [0]
    rows_before_remap = torch.jit.annotate(list[int], [])
    present_mappings = []
    running_total = 0
    for t, mapping in enumerate(index_remapping):
        if mapping is not None:
            mapping_len = mapping.numel()
            running_total += mapping_len
            rows_before_remap.append(mapping_len)
            present_mappings.append(mapping)
        else:
            rows_before_remap.append(table_rows[t])
        offsets.append(running_total)

    self.index_remappings_array_offsets = torch.tensor(
        offsets,
        device=self.current_device,
        dtype=torch.int64,
    )

    # No mapping list entries at all: fall back to the spec row counts.
    if len(rows_before_remap) == 0:
        rows_before_remap = table_rows
    self.original_rows_per_table = torch.tensor(
        [rows_before_remap[t] for t in self.feature_table_map],
        device=self.current_device,
        dtype=torch.int64,
    )

    if len(present_mappings) > 0:
        self.index_remappings_array = torch.cat(present_mappings).to(
            dtype=self.indices_dtype, device=self.current_device
        )
    else:
        self.index_remappings_array = torch.empty(
            0, dtype=self.indices_dtype, device=self.current_device
        )
|
|
1842
|
+
|
|
1843
|
+
def set_index_remappings(
    self,
    index_remapping: list[Tensor],
    pruning_hash_load_factor: float = 0.5,
    use_array_for_index_remapping: bool = True,
) -> None:
    """
    Install the index remapping used for pruned tables.

    Args:
        index_remapping: one mapping tensor per table; an entry may be None
            for a table without remapping
        pruning_hash_load_factor: hash mode only; each remapped table gets a
            capacity of ``rows / pruning_hash_load_factor`` rounded up to a
            multiple of 32
        use_array_for_index_remapping: when True (default) the work is
            delegated to ``set_index_remappings_array``; when False a pruning
            hash map is built instead

    In hash mode, tables without a mapping are given an identity mapping
    ``0..rows-1``, ``original_rows_per_table`` is set from the mapping
    lengths, and the entries are inserted either into a ``PrunedMapCPU``
    (CPU mode) or into ``index_remapping_hash_table`` via
    ``pruned_hashmap_insert``.
    """
    rows: list[int] = [e[1] for e in self.embedding_specs]
    T = len(self.embedding_specs)
    # Hash mapping pruning
    if not use_array_for_index_remapping:
        capacities = [
            (
                round_up(int(row * 1.0 / pruning_hash_load_factor), 32)
                if index_remap is not None
                else 0
            )
            for (index_remap, row) in zip(index_remapping, rows)
        ]
        hash_table = torch.empty(
            (sum(capacities), 2),
            dtype=self.indices_dtype,
        )
        hash_table[:, :] = -1
        hash_table_offsets = torch.tensor([0] + list(accumulate(capacities))).long()

        # Identity mapping for tables that are not remapped. It is built as an
        # integer tensor: a float32 tensor cannot represent indices above
        # 2**24 exactly, and would also force torch.cat below to promote the
        # real integer mappings to float32 before the final .int().
        merged_index_remappings = [
            mapping if mapping is not None else torch.arange(row, dtype=torch.int64)
            for (mapping, row) in zip(index_remapping, rows)
        ]
        original_feature_rows = [
            mapping.numel() for mapping in merged_index_remappings
        ]
        if len(original_feature_rows) == 0:
            original_feature_rows = rows
        self.original_rows_per_table = torch.tensor(
            [original_feature_rows[t] for t in self.feature_table_map],
            device=self.current_device,
            dtype=torch.int64,
        )
        # Values to store, the original index of each value within its table,
        # and the per-table start offsets into those two flat tensors.
        dense_indices = torch.cat(merged_index_remappings, dim=0).int()
        indices = torch.cat(
            [torch.arange(row) for row in original_feature_rows], dim=0
        ).int()
        offsets = torch.tensor([0] + list(accumulate(original_feature_rows))).int()

        if self.use_cpu:
            self.index_remapping_hash_table_cpu = (
                # pyre-ignore[16]
                torch.classes.fbgemm.PrunedMapCPU()
            )
            self.index_remapping_hash_table_cpu.insert(
                indices, dense_indices, offsets, T
            )
        else:
            # pruned_hashmap_insert only has cpu implementation: Move dense_indices to CPU
            torch.ops.fbgemm.pruned_hashmap_insert(
                indices,
                dense_indices.cpu(),
                offsets,
                hash_table,
                hash_table_offsets,
            )
            self.index_remapping_hash_table = hash_table.to(
                dtype=self.indices_dtype, device=self.current_device
            )
            self.index_remapping_hash_table_offsets = hash_table_offsets.to(
                self.current_device
            )
            self.index_remapping_hash_table_cpu = None
    # Array mapping pruning
    else:
        self.set_index_remappings_array(index_remapping)
|
|
1915
|
+
|
|
1916
|
+
def _embedding_inplace_update_per_table(
|
|
1917
|
+
self,
|
|
1918
|
+
update_table_idx: int,
|
|
1919
|
+
update_row_indices: list[int],
|
|
1920
|
+
update_weights: Tensor,
|
|
1921
|
+
) -> None:
|
|
1922
|
+
row_size = len(update_row_indices)
|
|
1923
|
+
if row_size == 0:
|
|
1924
|
+
return
|
|
1925
|
+
# pyre-fixme[9]: update_row_indices has type `List[int]`; used as `Tensor`.
|
|
1926
|
+
update_row_indices = torch.tensor(
|
|
1927
|
+
update_row_indices,
|
|
1928
|
+
device=self.current_device,
|
|
1929
|
+
dtype=torch.int64,
|
|
1930
|
+
)
|
|
1931
|
+
table_values = self.split_embedding_weights(split_scale_shifts=False)[
|
|
1932
|
+
update_table_idx
|
|
1933
|
+
]
|
|
1934
|
+
table_values[0].scatter_(
|
|
1935
|
+
dim=0,
|
|
1936
|
+
# pyre-fixme[16]: `List` has no attribute `view`.
|
|
1937
|
+
index=update_row_indices.view(row_size, 1).expand_as(update_weights),
|
|
1938
|
+
src=update_weights,
|
|
1939
|
+
)
|
|
1940
|
+
|
|
1941
|
+
@torch.jit.export
def embedding_inplace_update(
    self,
    update_table_indices: list[int],
    update_row_indices: list[list[int]],
    update_weights: list[Tensor],
) -> None:
    """
    Apply a batch of per-table in-place row updates.

    The three lists are parallel: entry i names a table, the rows of that
    table to overwrite, and the replacement data. Each entry is forwarded to
    _embedding_inplace_update_per_table in order.
    """
    for i, table_idx in enumerate(update_table_indices):
        self._embedding_inplace_update_per_table(
            table_idx,
            update_row_indices[i],
            update_weights[i],
        )
|
|
1954
|
+
|
|
1955
|
+
def embedding_inplace_update_internal(
    self,
    update_table_indices: list[int],
    update_row_indices: list[int],
    update_weights: Tensor,
) -> None:
    """
    Apply a flat batch of row updates through the fbgemm in-place update op.

    Args:
        update_table_indices: table index of each updated row
        update_row_indices: row index of each updated row (same length as
            ``update_table_indices``); remapped through the pruning array when
            one is set
        update_weights: new row data, passed to the op together with
            cumulative byte offsets computed from each row's rounded size
            (presumably a flat byte buffer - TODO confirm)

    Only array-based pruning is supported; the hash-based remapping state
    must be empty. When a cache is present, the affected rows are first
    prefetched so their cache locations can be passed to the op.
    """
    assert len(update_table_indices) == len(update_row_indices)
    # Cumulative byte offsets, one per update plus a trailing total
    # (len(update_table_indices) + 1 entries).
    update_offsets = []
    update_offset = 0
    for table_idx in update_table_indices:
        D_bytes = rounded_row_size_in_bytes(
            self.embedding_specs[table_idx][2],
            self.embedding_specs[table_idx][3],
            self.row_alignment,
            self.scale_bias_size_in_bytes,
        )
        update_offsets.append(update_offset)
        update_offset += D_bytes
    update_offsets.append(update_offset)

    # From here on the list arguments are rebound to tensors on the
    # current device.
    # pyre-fixme[9]: update_table_indices has type `List[int]`; used as `Tensor`.
    update_table_indices = torch.tensor(
        update_table_indices,
        device=self.current_device,
        dtype=torch.int32,
    )
    # pyre-fixme[9]: update_row_indices has type `List[int]`; used as `Tensor`.
    update_row_indices = torch.tensor(
        update_row_indices,
        device=self.current_device,
        dtype=torch.int64,
    )
    update_offsets = torch.tensor(
        update_offsets,
        device=self.current_device,
        dtype=torch.int64,
    )

    # Only support array based pruning for now.
    assert self.index_remapping_hash_table_cpu is None
    assert self.index_remapping_hash_table.numel() == 0
    # NOTE(review): numel() is never negative, so this assert cannot fail.
    assert self.index_remappings_array.numel() >= 0

    if self.index_remappings_array.numel() > 0:
        # Translate the caller's row indices through the pruning array.
        update_row_indices = torch.ops.fbgemm.pruned_array_lookup_from_row_idx(
            update_row_indices,
            update_table_indices,
            self.index_remappings_array,
            self.index_remappings_array_offsets,
        )

    lxu_cache_locations = None
    # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
    #  a function.
    if self.lxu_cache_weights.numel() > 0:
        linear_cache_indices = (
            torch.ops.fbgemm.linearize_cache_indices_from_row_idx(
                self.cache_hash_size_cumsum,
                update_table_indices,
                update_row_indices,
            )
        )

        if self.cache_assoc in [32, 64]:
            # 64 for AMD
            self.prefetch_32way(linear_cache_indices)
        elif self.cache_assoc == 1:
            self.prefetch_1way(linear_cache_indices)
        else:
            raise ValueError(f"{self.cache_assoc} not in [1, 32, 64]")

        # Take the cache locations from the queue after the prefetch above
        # (presumably the entry that prefetch enqueued - TODO confirm).
        lxu_cache_locations = self.lxu_cache_locations_list.pop()

    # The host buffer is passed as dev_weights whenever host_size > 0.
    torch.ops.fbgemm.emb_inplace_update(
        dev_weights=self.weights_host if self.host_size > 0 else self.weights_dev,
        uvm_weights=self.weights_uvm,
        weights_placements=self.weights_placements,
        weights_offsets=self.weights_offsets,
        weights_tys=self.weights_tys,
        D_offsets=self.D_offsets,
        update_weights=update_weights,
        update_table_indices=update_table_indices,
        update_row_indices=update_row_indices,
        update_offsets=update_offsets,
        row_alignment=self.row_alignment,
        lxu_cache_weights=self.lxu_cache_weights,
        lxu_cache_locations=lxu_cache_locations,
    )
|