fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -0,0 +1,2042 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+
10
+ # pyre-ignore-all-errors[56]
11
+
12
+ import logging
13
+ import uuid
14
+ from itertools import accumulate
15
+ from typing import Optional, Union
16
+
17
+ import fbgemm_gpu # noqa: F401
18
+ import torch # usort:skip
19
+ from torch import nn, Tensor # usort:skip
20
+
21
+ from fbgemm_gpu.config import FeatureGateName
22
+ from fbgemm_gpu.split_embedding_configs import sparse_type_to_int, SparseType
23
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
24
+ BoundsCheckMode,
25
+ CacheAlgorithm,
26
+ CacheState,
27
+ construct_cache_state,
28
+ DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
29
+ EmbeddingLocation,
30
+ EmbeddingSpecInfo,
31
+ get_bounds_check_version_for_platform,
32
+ get_new_embedding_location,
33
+ MAX_PREFETCH_DEPTH,
34
+ PoolingMode,
35
+ RecordCacheMetrics,
36
+ round_up,
37
+ SplitState,
38
+ tensor_to_device,
39
+ )
40
+ from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
41
+
42
+ try:
43
+ load_torch_module(
44
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_inference_gpu",
45
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_inference",
46
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_inference",
47
+ )
48
+ except Exception:
49
+ pass
50
+
51
+ try:
52
+ load_torch_module_bc(
53
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_inference_cpu",
54
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_inference",
55
+ )
56
+ except Exception:
57
+ pass
58
+
59
+ import fbgemm_gpu # noqa
60
+
61
+
62
+ def rounded_row_size_in_bytes(
63
+ dim: int,
64
+ weight_ty: SparseType,
65
+ row_alignment: int,
66
+ scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
67
+ ) -> int:
68
+ r = unpadded_row_size_in_bytes(dim, weight_ty, scale_bias_size_in_bytes)
69
+ # Align each row to `row_alignment` bytes (16-byte boundaries by default on GPU).
70
+ return round_up(r, row_alignment)
71
+
72
+
73
+ def unpadded_row_size_in_bytes(
74
+ dim: int,
75
+ weight_ty: SparseType,
76
+ scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
77
+ ) -> int:
78
+ r = {
79
+ SparseType.FP32.value: dim * 4,
80
+ SparseType.FP16.value: dim * 2,
81
+ SparseType.FP8.value: dim,
82
+ SparseType.INT8.value: dim + scale_bias_size_in_bytes,
83
+ SparseType.INT4.value: dim // 2 + scale_bias_size_in_bytes,
84
+ SparseType.INT2.value: dim // 4 + scale_bias_size_in_bytes,
85
+ }[weight_ty.value]
86
+ return r
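# A minimal sketch of the row-size math above, assuming the default 4-byte
# fp16 scale/bias header and the 16-byte row alignment used on GPU:
#
#     unpadded_row_size_in_bytes(128, SparseType.INT4)       # 128 // 2 + 4 = 68
#     rounded_row_size_in_bytes(128, SparseType.INT4, 16)    # round_up(68, 16) = 80
#     unpadded_row_size_in_bytes(128, SparseType.FP16)       # 128 * 2 = 256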
87
+
88
+
89
+ def align_to_cacheline(a: int) -> int:
90
+ # Align each table to a 128-byte cache line boundary.
91
+ return round_up(a, 128)
92
+
93
+
94
+ def nbit_construct_split_state(
95
+ embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]],
96
+ cacheable: bool,
97
+ row_alignment: int,
98
+ scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
99
+ cacheline_alignment: bool = True,
100
+ ) -> SplitState:
101
+ placements = torch.jit.annotate(list[EmbeddingLocation], [])
102
+ offsets = torch.jit.annotate(list[int], [])
103
+ dev_size = 0
104
+ host_size = 0
105
+ uvm_size = 0
106
+ for _, num_embeddings, embedding_dim, weight_ty, location in embedding_specs:
107
+ embedding_dim = rounded_row_size_in_bytes(
108
+ embedding_dim, weight_ty, row_alignment, scale_bias_size_in_bytes
109
+ )
110
+ state_size = num_embeddings * embedding_dim
111
+ if cacheline_alignment:
112
+ state_size = align_to_cacheline(state_size)
113
+ if location == EmbeddingLocation.HOST:
114
+ placements.append(EmbeddingLocation.HOST)
115
+ offsets.append(host_size)
116
+ host_size += state_size
117
+ elif location == EmbeddingLocation.DEVICE or location == EmbeddingLocation.MTIA:
118
+ placements.append(location)
119
+ offsets.append(dev_size)
120
+ dev_size += state_size
121
+ else:
122
+ if cacheable and location == EmbeddingLocation.MANAGED_CACHING:
123
+ placements.append(EmbeddingLocation.MANAGED_CACHING)
124
+ else:
125
+ placements.append(EmbeddingLocation.MANAGED)
126
+ offsets.append(uvm_size)
127
+ uvm_size += state_size
128
+ assert len(placements) == len(offsets)
129
+ return SplitState(
130
+ dev_size=dev_size,
131
+ host_size=host_size,
132
+ uvm_size=uvm_size,
133
+ placements=placements,
134
+ offsets=offsets,
135
+ )
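# An illustrative sketch of how the split is computed for two hypothetical
# tables, one resident in HBM and one in UVM with caching enabled:
#
#     example_split = nbit_construct_split_state(
#         embedding_specs=[
#             ("t0", 100, 64, SparseType.INT8, EmbeddingLocation.DEVICE),
#             ("t1", 100, 64, SparseType.INT4, EmbeddingLocation.MANAGED_CACHING),
#         ],
#         cacheable=True,
#         row_alignment=16,
#     )
#     # example_split.dev_size holds t0's bytes, example_split.uvm_size holds t1's,
#     # and example_split.offsets gives each table's byte offset within its region.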
136
+
137
+
138
+ def random_quant_scaled_tensor(
139
+ shape: torch.Size,
140
+ device: torch.device,
141
+ output_tensor: Optional[torch.Tensor] = None,
142
+ ) -> torch.Tensor:
143
+ if output_tensor is not None:
144
+ return torch.randint(
145
+ 0,
146
+ 255,
147
+ size=shape,
148
+ out=output_tensor,
149
+ dtype=torch.uint8,
150
+ device=device,
151
+ )
152
+ else:
153
+ return torch.randint(
154
+ 0,
155
+ 255,
156
+ size=shape,
157
+ dtype=torch.uint8,
158
+ device=device,
159
+ )
160
+
161
+
162
+ @torch.fx.wrap
163
+ def inputs_to_device(
164
+ indices: torch.Tensor,
165
+ offsets: torch.Tensor,
166
+ per_sample_weights: Optional[torch.Tensor],
167
+ bounds_check_warning: torch.Tensor,
168
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
169
+ if bounds_check_warning.device.type == "meta":
170
+ return indices, offsets, per_sample_weights
171
+
172
+ non_blocking = bounds_check_warning.device.type != "cpu"
173
+ if indices.device != bounds_check_warning.device:
174
+ indices = indices.to(bounds_check_warning.device, non_blocking=non_blocking)
175
+ if offsets.device != bounds_check_warning.device:
176
+ offsets = offsets.to(bounds_check_warning.device, non_blocking=non_blocking)
177
+ if (
178
+ per_sample_weights is not None
179
+ and per_sample_weights.device != bounds_check_warning.device
180
+ ):
181
+ per_sample_weights = per_sample_weights.to(
182
+ bounds_check_warning.device, non_blocking=non_blocking
183
+ )
184
+ return indices, offsets, per_sample_weights
185
+
186
+
187
+ # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
188
+ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
189
+ """
190
+ Table-batched version of nn.EmbeddingBag(sparse=False)
191
+ Inference version, with support for FP32/FP16/FP8/INT8/INT4/INT2 weights
192
+
193
+ Args:
194
+ embedding_specs (List[Tuple[str, int, int, SparseType, EmbeddingLocation]]):
195
+ A list of embedding specifications. Each spec describes a physical
196
+ embedding table and is a tuple of table name, number of
197
+ embedding rows, embedding dimension, weight data type
198
+ (`SparseType`), and table placement
199
+ (`EmbeddingLocation`).
200
+
201
+ Available `EmbeddingLocation` options are
202
+
203
+ (1) `DEVICE` = placing an embedding table in the GPU global memory
204
+ (HBM)
205
+
206
+ (2) `MANAGED` = placing an embedding in the unified virtual memory
207
+ (accessible from both GPU and CPU)
208
+
209
+ (3) `MANAGED_CACHING` = placing an embedding table in the unified
210
+ virtual memory and using the GPU global memory (HBM) as a cache
211
+
212
+ (4) `HOST` = placing an embedding table in the CPU memory (DRAM)
213
+
214
+ (5) `MTIA` = placing an embedding table in the MTIA memory
215
+
216
+ Available `ComputeDevice` options are
217
+
218
+ (1) `CPU` = performing table lookup on CPU
219
+
220
+ (2) `CUDA` = performing table lookup on GPU
221
+
222
+ (3) `MTIA` = performing table lookup on MTIA
223
+
224
+ feature_table_map (Optional[List[int]] = None): An optional list that
225
+ specifies feature-table mapping. feature_table_map[i] indicates the
226
+ physical embedding table that feature i maps to.
227
+
228
+ index_remapping (Optional[List[Tensor]] = None): Index remapping for pruning
229
+
230
+ pooling_mode (PoolingMode = PoolingMode.SUM): Pooling mode. Available
231
+ `PoolingMode` options are
232
+
233
+ (1) `SUM` = Sum pooling
234
+
235
+ (2) `MEAN` = Mean pooling
236
+
237
+ (3) `NONE` = No pooling (sequence embedding)
238
+
239
+ device (Optional[Union[str, int, torch.device]] = None): The current
240
+ device to place tensors on
241
+
242
+ bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING): Input
243
+ checking mode. Available `BoundsCheckMode` options are
244
+
245
+ (1) `NONE` = skip bounds check
246
+
247
+ (2) `FATAL` = throw an error when encountering an invalid
248
+ index/offset
249
+
250
+ (3) `WARNING` = print a warning message when encountering an
251
+ invalid index/offset and fix it (setting an invalid index to
252
+ zero and adjusting an invalid offset to be within the bound)
253
+
254
+ (4) `IGNORE` = silently fix an invalid index/offset (setting an
255
+ invalid index to zero and adjusting an invalid offset to be
256
+ within the bound)
257
+
258
+ weight_lists (Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None):
259
+ A list of (weights, optional scale/bias) tensor pairs, one per table, used to initialize the embedding weights.
260
+
261
+ pruning_hash_load_factor (float = 0.5):
262
+ Load factor for pruning hash
263
+
264
+ use_array_for_index_remapping (bool = True):
265
+ If True, use array for index remapping. Otherwise, use hash map.
266
+
267
+ output_dtype (SparseType = SparseType.FP16): The data type of an output
268
+ tensor.
269
+
270
+ cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU): The cache
271
+ algorithm (used when `EmbeddingLocation` is set to
272
+ `MANAGED_CACHING`). Options are
273
+
274
+ (1) `LRU` = least recently used
275
+
276
+ (2) `LFU` = least frequently used
277
+
278
+ cache_load_factor (float = 0.2): A factor used for determining the
279
+ cache capacity when `EmbeddingLocation.MANAGED_CACHING` is used.
280
+ The cache capacity is `cache_load_factor` * the total number of
281
+ rows in all embedding tables
282
+
283
+ cache_sets (int = 0): The number of cache sets (used when
284
+ `EmbeddingLocation` is set to `MANAGED_CACHING`)
285
+
286
+ cache_reserved_memory (float = 0.0): The amount of memory reserved in
287
+ HBM for non-cache purpose (used when `EmbeddingLocation` is set to
288
+ `MANAGED_CACHING`).
289
+
290
+ enforce_hbm (bool = False): If True, place all weights/momentums in HBM
291
+ when using `EmbeddingLocation.MANAGED_CACHING`
292
+
293
+ record_cache_metrics (Optional[RecordCacheMetrics] = None): Record
294
+ the number of hits, the number of requests, etc., if
295
+ `RecordCacheMetrics.record_cache_miss_counter` is True, and record
296
+ similar metrics per table if
297
+ `RecordCacheMetrics.record_tablewise_cache_miss` is True
298
+
299
+ gather_uvm_cache_stats (Optional[bool] = False): If True, collect the
300
+ cache statistics when `EmbeddingLocation` is set to
301
+ `MANAGED_CACHING`
302
+
303
+ row_alignment (Optional[int] = None): Row alignment
304
+
305
+ fp8_exponent_bits (Optional[int] = None): Exponent bits when using FP8
306
+
307
+ fp8_exponent_bias (Optional[int] = None): Exponent bias when using FP8
308
+
309
+ cache_assoc (int = 32): Number of ways for cache
310
+
311
+ scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES): Size
312
+ of scale and bias in bytes
313
+
314
+ cacheline_alignment (bool = True): If True, align each table to a 128-byte
315
+ cache line boundary
316
+
317
+ uvm_host_mapped (bool = False): If True, allocate every UVM tensor
318
+ using `malloc` + `cudaHostRegister`. Otherwise use
319
+ `cudaMallocManaged`
320
+
321
+ reverse_qparam (bool = False): If True, load `qparams` at end of each
322
+ row. Otherwise, load `qparams` at the beginning of each row.
323
+
324
+ feature_names_per_table (Optional[List[List[str]]] = None): An optional
325
+ list that specifies feature names per table. `feature_names_per_table[t]`
326
+ indicates the feature names of table `t`.
327
+
328
+ indices_dtype (torch.dtype = torch.int32): The expected dtype of the
329
+ indices tensor that will be passed to the `forward()` call. This
330
+ information will be used to construct the remap_indices array/hash.
331
+ Options are `torch.int32` and `torch.int64`.
332
+ """
333
+
334
+ embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]]
335
+ record_cache_metrics: RecordCacheMetrics
336
+ # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
337
+ cache_miss_counter: torch.Tensor
338
+ # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
339
+ uvm_cache_stats: torch.Tensor
340
+ # pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
341
+ local_uvm_cache_stats: torch.Tensor
342
+ # pyre-fixme[13]: Attribute `weights_offsets` is never initialized.
343
+ weights_offsets: torch.Tensor
344
+ # pyre-fixme[13]: Attribute `weights_placements` is never initialized.
345
+ weights_placements: torch.Tensor
346
+
347
+ def __init__( # noqa C901
348
+ self,
349
+ embedding_specs: list[
350
+ tuple[str, int, int, SparseType, EmbeddingLocation]
351
+ ], # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
352
+ feature_table_map: Optional[list[int]] = None, # [T]
353
+ index_remapping: Optional[list[Tensor]] = None,
354
+ pooling_mode: PoolingMode = PoolingMode.SUM,
355
+ device: Optional[Union[str, int, torch.device]] = None,
356
+ bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
357
+ weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
358
+ pruning_hash_load_factor: float = 0.5,
359
+ use_array_for_index_remapping: bool = True,
360
+ output_dtype: SparseType = SparseType.FP16,
361
+ cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
362
+ cache_load_factor: float = 0.2,
363
+ cache_sets: int = 0,
364
+ cache_reserved_memory: float = 0.0,
365
+ enforce_hbm: bool = False, # place all weights/momentums in HBM when using cache
366
+ record_cache_metrics: Optional[RecordCacheMetrics] = None,
367
+ gather_uvm_cache_stats: Optional[bool] = False,
368
+ row_alignment: Optional[int] = None,
369
+ fp8_exponent_bits: Optional[int] = None,
370
+ fp8_exponent_bias: Optional[int] = None,
371
+ cache_assoc: int = 32,
372
+ scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
373
+ cacheline_alignment: bool = True,
374
+ uvm_host_mapped: bool = False, # True to use cudaHostAlloc; False to use cudaMallocManaged.
375
+ reverse_qparam: bool = False, # True to load qparams at the end of each row; False to load qparams at the beginning of each row.
376
+ feature_names_per_table: Optional[list[list[str]]] = None,
377
+ indices_dtype: torch.dtype = torch.int32, # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
378
+ ) -> None: # noqa C901 # tuple of (rows, dims,)
379
+ super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()
380
+ self.uuid = str(uuid.uuid4())
381
+ self.log(
382
+ f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
383
+ )
384
+
385
+ # 64 for AMD
386
+ if cache_assoc == 32 and torch.version.hip is not None:
387
+ cache_assoc = 64
388
+
389
+ if device is None:
390
+ self.current_device: torch.device = torch.device(
391
+ torch.cuda.current_device()
392
+ )
393
+ elif isinstance(device, torch.device):
394
+ self.current_device = device
395
+ else:
396
+ self.current_device = torch.device(device)
397
+ self.use_cpu: bool = self.current_device.type == "cpu"
398
+
399
+ self.scale_bias_size_in_bytes = scale_bias_size_in_bytes
400
+ self.pooling_mode = pooling_mode
401
+ self.bounds_check_mode_int: int = bounds_check_mode.value
402
+ self.embedding_specs = embedding_specs
403
+ self.output_dtype: int = output_dtype.as_int()
404
+ self.uvm_host_mapped = uvm_host_mapped
405
+ self.feature_names_per_table = feature_names_per_table
406
+ self.indices_dtype = indices_dtype
407
+ # (feature_names, rows, dims, weights_tys, locations) = zip(*embedding_specs)
408
+ # Pyre workaround
409
+ self.feature_names: list[str] = [e[0] for e in embedding_specs]
410
+ self.cache_load_factor: float = cache_load_factor
411
+ self.cache_sets: int = cache_sets
412
+ self.cache_reserved_memory: float = cache_reserved_memory
413
+ rows: list[int] = [e[1] for e in embedding_specs]
414
+ dims: list[int] = [e[2] for e in embedding_specs]
415
+ weights_tys: list[SparseType] = [e[3] for e in embedding_specs]
416
+ locations: list[EmbeddingLocation] = [e[4] for e in embedding_specs]
417
+ # if target device is meta then we set use_cpu based on the embedding location
418
+ # information in embedding_specs.
419
+ if self.current_device.type == "meta":
420
+ self.use_cpu = all(loc == EmbeddingLocation.HOST for loc in locations)
421
+
422
+ if row_alignment is None:
423
+ self.row_alignment: int = 1 if self.use_cpu else 16
424
+ else:
425
+ self.row_alignment = row_alignment
426
+
427
+ if record_cache_metrics is not None:
428
+ self.record_cache_metrics = record_cache_metrics
429
+ else:
430
+ self.record_cache_metrics = RecordCacheMetrics(False, False)
431
+
432
+ self.gather_uvm_cache_stats = gather_uvm_cache_stats
433
+ # Define the size of uvm cache stats as class variable
434
+ # to make it work with torch jit script.
435
+ self.uvm_cache_stats_size = 6
436
+ # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
437
+ # 4: N_conflict_unique_misses, 5: N_conflict_misses
438
+
439
+ # Mixed D (tables with different dims) is not supported by the no-bag kernels
440
+ mixed_D = not all(d == dims[0] for d in dims)
441
+ if mixed_D:
442
+ assert (
443
+ self.pooling_mode != PoolingMode.NONE
444
+ ), "Mixed dimension tables are only supported for pooling tables."
445
+
446
+ assert not self.use_cpu or all(
447
+ loc == EmbeddingLocation.HOST for loc in locations
448
+ ), "CPU device requires EmbeddingLocation.HOST for location!"
449
+ assert self.use_cpu or all(
450
+ loc != EmbeddingLocation.HOST for loc in locations
451
+ ), "EmbeddingLocation.HOST doesn't work for CUDA device!"
452
+
453
+ T_ = len(self.embedding_specs)
454
+ assert T_ > 0
455
+
456
+ self.feature_table_map: list[int] = (
457
+ feature_table_map if feature_table_map is not None else list(range(T_))
458
+ )
459
+ T = len(self.feature_table_map)
460
+ assert T_ <= T
461
+
462
+ table_has_feature = [False] * T_
463
+ for t in self.feature_table_map:
464
+ table_has_feature[t] = True
465
+ assert all(table_has_feature), "Each table must have at least one feature!"
466
+ D_offsets = [dims[t] for t in self.feature_table_map]
467
+ D_offsets = [0] + list(accumulate(D_offsets))
468
+ self.total_D: int = D_offsets[-1]
469
+ for dim, weight_ty in zip(dims, weights_tys):
470
+ if not weight_ty.is_float():
471
+ assert (
472
+ dim % (8 / weight_ty.bit_rate()) == 0
473
+ ), f"For quantized types we need to at least pack at byte granularity, dim: {dim}, weight_ty: {weight_ty}"
474
+
475
+ def max_ty_D(ty: SparseType) -> int:
476
+ return max(
477
+ [
478
+ dim
479
+ for dim, weight_ty in zip(dims, weights_tys)
480
+ if weight_ty == ty or weight_ty.value == ty.value
481
+ ],
482
+ default=0,
483
+ )
484
+
485
+ self.max_int2_D: int = max_ty_D(SparseType.INT2)
486
+ self.max_int4_D: int = max_ty_D(SparseType.INT4)
487
+ self.max_int8_D: int = max_ty_D(SparseType.INT8)
488
+ self.max_float8_D: int = max_ty_D(SparseType.FP8)
489
+ self.max_float16_D: int = max_ty_D(SparseType.FP16)
490
+ self.max_float32_D: int = max_ty_D(SparseType.FP32)
491
+
492
+ self.register_buffer(
493
+ "D_offsets",
494
+ torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
495
+ )
496
+ assert self.D_offsets.numel() == T + 1
497
+
498
+ self.register_buffer(
499
+ "rows_per_table",
500
+ torch.tensor(
501
+ [rows[t] for t in self.feature_table_map],
502
+ device=self.current_device,
503
+ dtype=torch.int64,
504
+ ),
505
+ )
506
+ self.register_buffer(
507
+ "bounds_check_warning",
508
+ torch.tensor([0], device=self.current_device, dtype=torch.int64),
509
+ )
510
+
511
+ weights_tys_int = [weights_tys[t].as_int() for t in self.feature_table_map]
512
+ self.register_buffer(
513
+ "weights_tys",
514
+ torch.tensor(
515
+ weights_tys_int, device=self.current_device, dtype=torch.uint8
516
+ ),
517
+ )
518
+ self.weight_initialized: bool = False
519
+
520
+ self.weights_dev: torch.Tensor = torch.zeros(
521
+ 0,
522
+ device=self.current_device,
523
+ dtype=torch.uint8,
524
+ )
525
+
526
+ self.weights_host: torch.Tensor = torch.zeros(
527
+ 0, device=self.current_device, dtype=torch.uint8
528
+ )
529
+
530
+ self.weights_uvm: torch.Tensor = torch.empty(
531
+ 0, device=self.current_device, dtype=torch.uint8
532
+ )
533
+
534
+ cached_dims = [
535
+ rounded_row_size_in_bytes(
536
+ embedding_spec[2], embedding_spec[3], 16, self.scale_bias_size_in_bytes
537
+ )
538
+ for embedding_spec in self.embedding_specs
539
+ if embedding_spec[4] == EmbeddingLocation.MANAGED_CACHING
540
+ ]
541
+ self.max_D_cache: int = max(cached_dims) if len(cached_dims) > 0 else 0
542
+
543
+ self.initialize_physical_weights_placements_and_offsets(cacheline_alignment)
544
+ self.enforce_hbm: bool = enforce_hbm
545
+
546
+ self.reverse_qparam = reverse_qparam
547
+ # Assign weights after weights and weights_offsets are initialized.
548
+ if weight_lists:
549
+ self._apply_split(
550
+ self.dev_size,
551
+ self.host_size,
552
+ self.uvm_size,
553
+ self.weights_physical_placements,
554
+ self.weights_physical_offsets,
555
+ self.enforce_hbm,
556
+ )
557
+ self.assign_embedding_weights(weight_lists)
558
+
559
+ # Handle index remapping for embedding pruning.
560
+ # All buffers are int64 in order to support both int32 and int64 indices.
561
+ self.register_buffer(
562
+ "index_remappings_array_offsets",
563
+ torch.empty(0, device=self.current_device, dtype=torch.int64),
564
+ )
565
+ self.register_buffer(
566
+ "index_remappings_array",
567
+ torch.empty(0, device=self.current_device, dtype=self.indices_dtype),
568
+ )
569
+ self.register_buffer(
570
+ "index_remapping_hash_table_offsets",
571
+ torch.empty(0, device=self.current_device, dtype=torch.int64),
572
+ )
573
+ self.register_buffer(
574
+ "index_remapping_hash_table",
575
+ torch.empty(0, device=self.current_device, dtype=self.indices_dtype),
576
+ )
577
+ self.register_buffer(
578
+ "original_rows_per_table",
579
+ torch.empty(0, device=self.current_device, dtype=torch.int64),
580
+ )
581
+ # pyre-fixme[4]: Attribute must be annotated.
582
+ self.index_remapping_hash_table_cpu = None
583
+
584
+ if index_remapping:
585
+ self.set_index_remappings(
586
+ index_remapping, pruning_hash_load_factor, use_array_for_index_remapping
587
+ )
588
+
589
+ # Currently only support cache_precision == embedding_precision.
590
+ # Both are represented as uint8_t
591
+ cache_state = construct_cache_state(rows, locations, self.feature_table_map)
592
+
593
+ if self.record_cache_metrics.record_tablewise_cache_miss:
594
+ num_tables = len(cache_state.cache_hash_size_cumsum) - 1
595
+ self.register_buffer(
596
+ "table_wise_cache_miss",
597
+ torch.zeros(
598
+ num_tables,
599
+ device=self.current_device,
600
+ dtype=torch.int64,
601
+ ),
602
+ )
603
+ # NOTE: make TorchScript work!
604
+ else:
605
+ self.register_buffer(
606
+ "table_wise_cache_miss",
607
+ torch.zeros(
608
+ 0,
609
+ device=self.current_device,
610
+ dtype=torch.int64,
611
+ ),
612
+ )
613
+
614
+ self.cache_assoc = cache_assoc
615
+ self._apply_cache_state(
616
+ cache_state,
617
+ cache_algorithm,
618
+ cache_load_factor,
619
+ cache_sets,
620
+ cache_reserved_memory,
621
+ )
622
+
623
+ if self.max_float8_D > 0:
624
+ default_config = SparseType.FP8.default_config()
625
+ self.fp8_exponent_bits: int = (
626
+ default_config.get("exponent_bits")
627
+ if fp8_exponent_bits is None
628
+ else fp8_exponent_bits
629
+ )
630
+ self.fp8_exponent_bias: int = (
631
+ default_config.get("exponent_bias")
632
+ if fp8_exponent_bias is None
633
+ else fp8_exponent_bias
634
+ )
635
+ else:
636
+ self.fp8_exponent_bits = -1
637
+ self.fp8_exponent_bias = -1
638
+
639
+ self.bounds_check_version: int = get_bounds_check_version_for_platform()
640
+
641
+ @torch.jit.ignore
642
+ def log(self, msg: str) -> None:
643
+ """
644
+ Log with TBE id prefix to distinguish between multiple TBE instances
645
+ per process
646
+
647
+ Args:
648
+ msg (str): The message to print
649
+
650
+ Returns:
651
+ None
652
+ """
653
+ logging.info(f"[TBE={self.uuid}] {msg}")
654
+
655
+ def get_cache_miss_counter(self) -> Tensor:
656
+ # cache_miss_counter[0]: cache_miss_forward_count, the total number of forwards that had at least one cache miss
657
+ # cache_miss_counter[1]: unique_cache_miss_count, the total number of unique (dedup) cache misses
658
+ # cache_miss_counter[2]: total unique (dedup) access count
659
+ # cache_miss_counter[3]: total non-dedup access count
660
+
661
+ # How to get cache miss ratio
662
+ # cache miss ratio (# of missed entries / # of unique requests): ( cache_miss_counter[1] / cache_miss_counter[2] )
663
+ # cache miss ratio (# of missed entries / # of total access): ( cache_miss_counter[1] / cache_miss_counter[3] )
664
+ assert (
665
+ self.record_cache_metrics.record_cache_miss_counter
666
+ ), "record_cache_miss_counter should be true to access counter values"
667
+
668
+ return self.cache_miss_counter
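# A small sketch of the ratio computation described above; `tbe` is a
# hypothetical instance constructed with record_cache_miss_counter enabled
# and already exercised via prefetch():
#
#     counter = tbe.get_cache_miss_counter()
#     unique_miss_rate = counter[1] / counter[2]  # missed entries / unique requests
#     total_miss_rate = counter[1] / counter[3]   # missed entries / total accesses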
669
+
670
+ @torch.jit.export
671
+ def get_table_wise_cache_miss(self) -> Tensor:
672
+ assert (
673
+ self.record_cache_metrics.record_tablewise_cache_miss
674
+ ), "record_tablewise_cache_miss should be true to access counter values"
675
+ # table_wise_cache_miss contains all the cache miss count for each table in this embedding table object:
676
+ return self.table_wise_cache_miss
677
+
678
+ @torch.jit.export
679
+ def get_feature_num_per_table(self) -> list[int]:
680
+ if self.feature_names_per_table is None:
681
+ return []
682
+ return [len(feature_names) for feature_names in self.feature_names_per_table]
683
+
684
+ def reset_cache_miss_counter(self) -> None:
685
+ assert (
686
+ self.record_cache_metrics.record_cache_miss_counter
687
+ ), "record_cache_miss_counter should be true to access counter values"
688
+ self.cache_miss_counter = torch.tensor(
689
+ [0, 0, 0, 0], device=self.current_device, dtype=torch.int64
690
+ )
691
+
692
+ def reset_uvm_cache_stats(self) -> None:
693
+ assert (
694
+ self.gather_uvm_cache_stats
695
+ ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
696
+ self.uvm_cache_stats.zero_()
697
+ self.local_uvm_cache_stats.zero_()
698
+
699
+ def print_cache_miss_counter(self) -> None:
700
+ assert (
701
+ self.record_cache_metrics.record_cache_miss_counter
702
+ ), "record_cache_miss_counter should be true to access counter values"
703
+ self.log(
704
+ f"\n"
705
+ f"Miss counter value [0] - # of miss occured iters : {self.cache_miss_counter[0]}, \n"
706
+ f"Miss counter value [1] - # of unique misses : {self.cache_miss_counter[1]}, \n"
707
+ f"Miss counter value [2] - # of unique requested indices : {self.cache_miss_counter[2]}, \n"
708
+ f"Miss counter value [3] - # of total requested indices : {self.cache_miss_counter[3]}, "
709
+ )
710
+ self.log(
711
+ f"unique_miss_rate using counter : {self.cache_miss_counter[1] / self.cache_miss_counter[2]}, \n"
712
+ )
713
+ self.log(
714
+ f"total_miss_rate using counter : {self.cache_miss_counter[1] / self.cache_miss_counter[3]}, \n"
715
+ )
716
+
717
+ def get_uvm_cache_stats(self) -> Tensor:
718
+ assert (
719
+ self.gather_uvm_cache_stats
720
+ ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
721
+ return self.uvm_cache_stats
722
+
723
+ def print_uvm_cache_stats(self) -> None:
724
+ assert (
725
+ self.gather_uvm_cache_stats
726
+ ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
727
+ uvm_cache_stats = self.uvm_cache_stats.tolist()
728
+ self.log(
729
+ f"N_called: {uvm_cache_stats[0]}\n"
730
+ f"N_requested_indices: {uvm_cache_stats[1]}\n"
731
+ f"N_unique_indices: {uvm_cache_stats[2]}\n"
732
+ f"N_unique_misses: {uvm_cache_stats[3]}\n"
733
+ f"N_conflict_unique_misses: {uvm_cache_stats[4]}\n"
734
+ f"N_conflict_misses: {uvm_cache_stats[5]}\n"
735
+ )
736
+ if uvm_cache_stats[1]:
737
+ self.log(
738
+ f"unique indices / requested indices: {uvm_cache_stats[2] / uvm_cache_stats[1]}\n"
739
+ f"unique misses / requested indices: {uvm_cache_stats[3] / uvm_cache_stats[1]}\n"
740
+ )
741
+
742
+ @torch.jit.export
743
+ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
744
+ self.timestep_counter.increment()
745
+ self.timestep_prefetch_size.increment()
746
+ # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
747
+ # a function.
748
+ if not self.lxu_cache_weights.numel():
749
+ return
750
+
751
+ linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
752
+ self.cache_hash_size_cumsum,
753
+ indices,
754
+ offsets,
755
+ )
756
+
757
+ if (
758
+ self.record_cache_metrics.record_cache_miss_counter
759
+ or self.record_cache_metrics.record_tablewise_cache_miss
760
+ ):
761
+ lxu_cache_locations = (
762
+ torch.ops.fbgemm.lxu_cache_lookup(
763
+ linear_cache_indices,
764
+ self.lxu_cache_state,
765
+ self.total_cache_hash_size,
766
+ )
767
+ if self.cache_assoc in [32, 64]
768
+ else torch.ops.fbgemm.direct_mapped_lxu_cache_lookup(
769
+ linear_cache_indices,
770
+ self.lxu_cache_state,
771
+ self.total_cache_hash_size,
772
+ )
773
+ )
774
+ if self.record_cache_metrics.record_cache_miss_counter:
775
+ self._update_cache_miss_counter(
776
+ lxu_cache_locations, linear_cache_indices
777
+ )
778
+ if self.record_cache_metrics.record_tablewise_cache_miss:
779
+ self._update_tablewise_cache_miss(
780
+ lxu_cache_locations, linear_cache_indices, offsets
781
+ )
782
+
783
+ if self.cache_assoc in [32, 64]:
784
+ # 64 for AMD
785
+ self.prefetch_32way(linear_cache_indices)
786
+ elif self.cache_assoc == 1:
787
+ self.prefetch_1way(linear_cache_indices)
788
+ else:
789
+ raise ValueError(f"{self.cache_assoc} not in [1, 32, 64]")
790
+
791
+ def prefetch_32way(self, linear_cache_indices: Tensor) -> None:
792
+ if self.cache_algorithm == CacheAlgorithm.LRU:
793
+ torch.ops.fbgemm.lru_cache_populate_byte(
794
+ self.weights_uvm,
795
+ self.cache_hash_size_cumsum,
796
+ self.total_cache_hash_size,
797
+ self.cache_index_table_map,
798
+ self.weights_offsets,
799
+ self.weights_tys,
800
+ self.D_offsets,
801
+ linear_cache_indices,
802
+ self.lxu_cache_state,
803
+ self.lxu_cache_weights,
804
+ self.timestep_counter.get(),
805
+ self.lxu_state,
806
+ 16, # row_alignment; using default value.
807
+ self.gather_uvm_cache_stats,
808
+ self.local_uvm_cache_stats,
809
+ )
810
+ elif self.cache_algorithm == CacheAlgorithm.LFU:
811
+ torch.ops.fbgemm.lfu_cache_populate_byte(
812
+ self.weights_uvm,
813
+ self.cache_hash_size_cumsum,
814
+ self.total_cache_hash_size,
815
+ self.cache_index_table_map,
816
+ self.weights_offsets,
817
+ self.weights_tys,
818
+ self.D_offsets,
819
+ linear_cache_indices,
820
+ self.lxu_cache_state,
821
+ self.lxu_cache_weights,
822
+ self.lxu_state,
823
+ )
824
+
825
+ assert (
826
+ self.lxu_cache_locations_list.size() < self.max_prefetch_depth
827
+ ), f"self.lxu_cache_locations_list has grown to size: {self.lxu_cache_locations_list.size()}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"
828
+ self.lxu_cache_locations_list.push(
829
+ torch.ops.fbgemm.lxu_cache_lookup(
830
+ linear_cache_indices,
831
+ self.lxu_cache_state,
832
+ self.total_cache_hash_size,
833
+ self.gather_uvm_cache_stats,
834
+ self.local_uvm_cache_stats,
835
+ )
836
+ )
837
+ if self.gather_uvm_cache_stats:
838
+ self._accumulate_uvm_cache_stats()
839
+
840
+ def prefetch_1way(self, linear_cache_indices: Tensor) -> None:
841
+ if self.cache_algorithm == CacheAlgorithm.LRU:
842
+ torch.ops.fbgemm.direct_mapped_lru_cache_populate_byte(
843
+ self.weights_uvm,
844
+ self.cache_hash_size_cumsum,
845
+ self.total_cache_hash_size,
846
+ self.cache_index_table_map,
847
+ self.weights_offsets,
848
+ self.weights_tys,
849
+ self.D_offsets,
850
+ linear_cache_indices,
851
+ self.lxu_cache_state,
852
+ self.lxu_cache_weights,
853
+ self.timestep_counter.get(),
854
+ self.lxu_state,
855
+ self.lxu_cache_miss_timestamp,
856
+ 16, # row_alignment; using default value.
857
+ self.gather_uvm_cache_stats,
858
+ self.local_uvm_cache_stats,
859
+ )
860
+ else:
861
+ raise ValueError("Direct Mapped for LRU only")
862
+
863
+ assert (
864
+ self.lxu_cache_locations_list.size() < self.max_prefetch_depth
865
+ ), f"self.lxu_cache_locations_list has grown to size: {self.lxu_cache_locations_list.size()}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"
866
+ self.lxu_cache_locations_list.push(
867
+ torch.ops.fbgemm.direct_mapped_lxu_cache_lookup(
868
+ linear_cache_indices,
869
+ self.lxu_cache_state,
870
+ self.total_cache_hash_size,
871
+ self.gather_uvm_cache_stats,
872
+ self.local_uvm_cache_stats,
873
+ )
874
+ )
875
+ if self.gather_uvm_cache_stats:
876
+ self._accumulate_uvm_cache_stats()
877
+
878
+ def _accumulate_uvm_cache_stats(self) -> None:
879
+ # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
880
+ # We may want to do this accumulation atomically, but since it is only for
881
+ # monitoring, a slightly inaccurate result is acceptable.
882
+ self.uvm_cache_stats = torch.add(
883
+ self.uvm_cache_stats, self.local_uvm_cache_stats
884
+ )
885
+ self.local_uvm_cache_stats.zero_()
886
+
887
+ def _update_cache_miss_counter(
888
+ self,
889
+ lxu_cache_locations: Tensor,
890
+ linear_cache_indices: Tensor,
891
+ ) -> None:
892
+ CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32)
893
+ CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32)
894
+
895
+ cache_missed_locations = torch.where(
896
+ lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT
897
+ )
898
+ unique_ids_list = torch.unique(cache_missed_locations)
899
+ unique_ids_count_list = torch.where(unique_ids_list == CACHE_HIT, 0, 1)
900
+
901
+ miss_count = torch.sum(unique_ids_count_list)
902
+
903
+ self.cache_miss_counter[0] += (miss_count > 0).to(torch.int64)
904
+
905
+ self.cache_miss_counter[1] += miss_count
906
+
907
+ # Number of unique requests
908
+ assert (
909
+ len(linear_cache_indices.size()) == 1
910
+ ), f"linear_cache_indices should be 1-D was {len(linear_cache_indices.size())}-D"
911
+
912
+ assert (
913
+ self.cache_miss_counter.size()[0] == 4
914
+ ), f"self.cache_miss_counter should be 4-D was {self.cache_miss_counter.size()[0]}-D"
915
+
916
+ self.cache_miss_counter[2] += torch.unique(linear_cache_indices).size()[0]
917
+
918
+ # Number of total requests
919
+ self.cache_miss_counter[3] += linear_cache_indices.size()[0]
920
+
921
+ def _update_tablewise_cache_miss(
922
+ self,
923
+ lxu_cache_locations: Tensor,
924
+ linear_cache_indices: Tensor,
925
+ offsets: Tensor,
926
+ ) -> None:
927
+ CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32)
928
+ CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32)
929
+
930
+ # pyre-fixme[6]: For 1st argument expected
931
+ # `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Module, Tensor]`.
932
+ num_tables = len(self.cache_hash_size_cumsum) - 1
933
+ num_offsets_per_table = (len(offsets) - 1) // num_tables
934
+ cache_missed_locations = torch.where(
935
+ lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT
936
+ )
937
+
938
+ for i in range(num_tables):
939
+ start = offsets[i * num_offsets_per_table]
940
+ end = offsets[(i + 1) * num_offsets_per_table]
941
+
942
+ current_cache_missed_locations = cache_missed_locations[start:end]
943
+ unique_ids_list = torch.unique(current_cache_missed_locations)
944
+ unique_ids_count_list = torch.where(unique_ids_list == CACHE_HIT, 0, 1)
945
+
946
+ miss_count = torch.sum(unique_ids_count_list)
947
+
948
+ self.table_wise_cache_miss[i] += miss_count
949
+
950
+ def _forward_impl(
951
+ self,
952
+ indices: Tensor,
953
+ offsets: Tensor,
954
+ per_sample_weights: Optional[Tensor] = None,
955
+ ) -> Tensor:
956
+ assert (
957
+ self.weight_initialized
958
+ ), "weight needs to be initialized before forward function"
959
+
960
+ indices, offsets, per_sample_weights = inputs_to_device(
961
+ indices, offsets, per_sample_weights, self.bounds_check_warning
962
+ )
963
+
964
+ # First bound check: check if the indices/offsets are within the boundary
965
+ # of the original embedding rows before pruning.
966
+ # Note that this is only applied when we enable pruning (if the perf becomes
967
+ # an issue, we can fuse it inside the remapping kernel).
968
+ if (
969
+ self.index_remapping_hash_table_cpu is not None
970
+ or self.index_remapping_hash_table.numel() > 0
971
+ or self.index_remappings_array.numel() > 0
972
+ ):
973
+ if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
974
+ torch.ops.fbgemm.bounds_check_indices(
975
+ self.original_rows_per_table,
976
+ indices,
977
+ offsets,
978
+ self.bounds_check_mode_int,
979
+ self.bounds_check_warning,
980
+ per_sample_weights,
981
+ bounds_check_version=self.bounds_check_version,
982
+ )
983
+
984
+ # Index remapping changes the input indices, and some of them become -1 (pruned rows).
985
+ # Hence, remapping should be done before prefetch and embedding lookup
986
+ # so that these operations work on the remapped indices.
987
+ if self.index_remapping_hash_table_cpu is not None:
988
+ indices = self.index_remapping_hash_table_cpu.lookup(indices, offsets)
989
+ elif self.index_remapping_hash_table.numel() > 0:
990
+ # Convert from raw indices to pruned indices
991
+ indices = torch.ops.fbgemm.pruned_hashmap_lookup(
992
+ indices,
993
+ offsets,
994
+ self.index_remapping_hash_table,
995
+ self.index_remapping_hash_table_offsets,
996
+ )
997
+ elif self.index_remappings_array.numel() > 0:
998
+ indices = torch.ops.fbgemm.pruned_array_lookup(
999
+ indices,
1000
+ offsets,
1001
+ self.index_remappings_array,
1002
+ self.index_remappings_array_offsets,
1003
+ )
1004
+ # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
1005
+ # a function.
1006
+ if self.lxu_cache_weights.numel() > 0:
1007
+ if self.timestep_prefetch_size.get() <= 0:
1008
+ self.prefetch(indices, offsets)
1009
+ self.timestep_prefetch_size.decrement()
1010
+
1011
+ lxu_cache_locations = self.lxu_cache_locations_list.pop()
1012
+
1013
+ # Second bound check: check if the indices/offsets are within the boundary
1014
+ # of the pruned embedding rows after pruning.
1015
+ # Note: we cast to int as a TorchScript workaround.
1016
+ if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
1017
+ torch.ops.fbgemm.bounds_check_indices(
1018
+ self.rows_per_table,
1019
+ indices,
1020
+ offsets,
1021
+ self.bounds_check_mode_int,
1022
+ self.bounds_check_warning,
1023
+ per_sample_weights,
1024
+ bounds_check_version=self.bounds_check_version,
1025
+ )
1026
+ # Note: CPU and CUDA ops use the same interface to facilitate JIT IR
1027
+ # generation for CUDA/CPU. For CPU op, we don't need weights_uvm and
1028
+ # weights_placements
1029
+ return torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(
1030
+ dev_weights=self.weights_host if self.host_size > 0 else self.weights_dev,
1031
+ uvm_weights=self.weights_uvm,
1032
+ weights_placements=self.weights_placements,
1033
+ weights_offsets=self.weights_offsets,
1034
+ weights_tys=self.weights_tys,
1035
+ D_offsets=self.D_offsets,
1036
+ total_D=self.total_D,
1037
+ max_int2_D=self.max_int2_D,
1038
+ max_int4_D=self.max_int4_D,
1039
+ max_int8_D=self.max_int8_D,
1040
+ max_float16_D=self.max_float16_D,
1041
+ max_float32_D=self.max_float32_D,
1042
+ indices=indices,
1043
+ offsets=offsets,
1044
+ pooling_mode=int(self.pooling_mode),
1045
+ indice_weights=per_sample_weights,
1046
+ output_dtype=self.output_dtype,
1047
+ lxu_cache_weights=self.lxu_cache_weights,
1048
+ lxu_cache_locations=lxu_cache_locations,
1049
+ row_alignment=self.row_alignment,
1050
+ max_float8_D=self.max_float8_D,
1051
+ fp8_exponent_bits=self.fp8_exponent_bits,
1052
+ fp8_exponent_bias=self.fp8_exponent_bias,
1053
+ )
1054
+
1055
+ def forward(
1056
+ self,
1057
+ indices: Tensor,
1058
+ offsets: Tensor,
1059
+ per_sample_weights: Optional[Tensor] = None,
1060
+ ) -> Tensor:
1061
+ return self._forward_impl(
1062
+ indices=indices, offsets=offsets, per_sample_weights=per_sample_weights
1063
+ )
1064
+
1065
+ def initialize_logical_weights_placements_and_offsets(
1066
+ self,
1067
+ ) -> None:
1068
+ assert len(self.weights_physical_offsets) == len(self.embedding_specs)
1069
+ assert len(self.weights_physical_offsets) == len(
1070
+ self.weights_physical_placements
1071
+ )
1072
+ offsets = [self.weights_physical_offsets[t] for t in self.feature_table_map]
1073
+ placements = [
1074
+ self.weights_physical_placements[t] for t in self.feature_table_map
1075
+ ]
1076
+ self.weights_offsets = torch.tensor(
1077
+ offsets, device=self.current_device, dtype=torch.int64
1078
+ )
1079
+ self.weights_placements = torch.tensor(
1080
+ placements, device=self.current_device, dtype=torch.int32
1081
+ )
1082
+
1083
+ def initialize_physical_weights_placements_and_offsets(
1084
+ self,
1085
+ cacheline_alignment: bool = True,
1086
+ ) -> None:
1087
+ # Initialize physical weights placements and offsets
1088
+ # and host/dev/uvm sizes
1089
+ weight_split: SplitState = nbit_construct_split_state(
1090
+ self.embedding_specs,
1091
+ cacheable=True,
1092
+ row_alignment=self.row_alignment,
1093
+ scale_bias_size_in_bytes=self.scale_bias_size_in_bytes,
1094
+ cacheline_alignment=cacheline_alignment,
1095
+ )
1096
+ self.weights_physical_placements = [t.value for t in weight_split.placements]
1097
+ self.weights_physical_offsets = weight_split.offsets
1098
+ self.host_size = weight_split.host_size
1099
+ self.dev_size = weight_split.dev_size
1100
+ self.uvm_size = weight_split.uvm_size
1101
+
1102
+ @torch.jit.export
1103
+ def reset_weights_placements_and_offsets(
1104
+ self, device: torch.device, location: int
1105
+ ) -> None:
1106
+ # Overwrite location in embedding_specs with new location
1107
+ # Use a map since TorchScript can't script the enum call (i.e. EmbeddingLocation(value))
1108
+ INT_TO_EMBEDDING_LOCATION = {
1109
+ EmbeddingLocation.DEVICE.value: EmbeddingLocation.DEVICE,
1110
+ EmbeddingLocation.MANAGED.value: EmbeddingLocation.MANAGED,
1111
+ EmbeddingLocation.MANAGED_CACHING.value: EmbeddingLocation.MANAGED_CACHING,
1112
+ EmbeddingLocation.HOST.value: EmbeddingLocation.HOST,
1113
+ EmbeddingLocation.MTIA.value: EmbeddingLocation.MTIA,
1114
+ }
1115
+ # Reset device/location denoted in embedding specs
1116
+ target_location = INT_TO_EMBEDDING_LOCATION[location]
1117
+ if target_location == EmbeddingLocation.MTIA:
1118
+ self.scale_bias_size_in_bytes = 8
1119
+ self.reset_embedding_spec_location(device, target_location)
1120
+ # Initialize all physical/logical weights placements and offsets without initializing large dev weights tensor
1121
+ self.initialize_physical_weights_placements_and_offsets(
1122
+ cacheline_alignment=target_location != EmbeddingLocation.MTIA
1123
+ )
1124
+ self.initialize_logical_weights_placements_and_offsets()
1125
+
1126
+ def reset_embedding_spec_location(
1127
+ self, device: torch.device, target_location: EmbeddingLocation
1128
+ ) -> None:
1129
+ self.current_device = device
1130
+ self.row_alignment = (
1131
+ 1
1132
+ if target_location == EmbeddingLocation.HOST
1133
+ or target_location == EmbeddingLocation.MTIA
1134
+ else 16
1135
+ )
1136
+ self.embedding_specs = [
1137
+ (spec[0], spec[1], spec[2], spec[3], target_location)
1138
+ for spec in self.embedding_specs
1139
+ ]
1140
+
1141
+ @torch.jit.export
1142
+ def recompute_module_buffers(self) -> None:
1143
+ """
1144
+ Compute module buffers that are on the meta device and are not materialized
1145
+ in reset_weights_placements_and_offsets(). Currently those buffers are
1146
+ `weights_tys`, `rows_per_table`, `D_offsets` and `bounds_check_warning`.
1147
+ Pruning-related and UVM-related buffers are not recomputed right now.
1148
+ """
1149
+ if (
1150
+ self.weights_tys.device == self.current_device
1151
+ or self.current_device.type == "meta"
1152
+ ):
1153
+ return
1154
+
1155
+ weights_tys_int = [sparse_type_to_int(e[3]) for e in self.embedding_specs]
1156
+ self.weights_tys = torch.tensor(
1157
+ [weights_tys_int[t] for t in self.feature_table_map],
1158
+ device=self.current_device,
1159
+ dtype=torch.uint8,
1160
+ )
1161
+ rows = [e[1] for e in self.embedding_specs]
1162
+ self.rows_per_table = torch.tensor(
1163
+ [rows[t] for t in self.feature_table_map],
1164
+ device=self.current_device,
1165
+ dtype=torch.int64,
1166
+ )
1167
+ dims = [e[2] for e in self.embedding_specs]
1168
+ D_offsets_list = [0]
1169
+ for t in self.feature_table_map:
1170
+ D_offsets_list.append(dims[t] + D_offsets_list[-1])
1171
+ self.D_offsets = torch.tensor(
1172
+ D_offsets_list, device=self.current_device, dtype=torch.int32
1173
+ )
1174
+ self.bounds_check_warning = torch.tensor(
1175
+ [0], device=self.current_device, dtype=torch.int64
1176
+ )
1177
+
1178
+ # For pruning related or uvm related buffers, we just set them as empty tensors.
1179
+ self.index_remapping_hash_table = torch.empty_like(
1180
+ self.index_remapping_hash_table, device=self.current_device
1181
+ )
1182
+ self.index_remapping_hash_table_offsets = torch.empty_like(
1183
+ self.index_remapping_hash_table_offsets, device=self.current_device
1184
+ )
1185
+ self.index_remappings_array = torch.empty_like(
1186
+ self.index_remappings_array, device=self.current_device
1187
+ )
1188
+ self.index_remappings_array_offsets = torch.empty_like(
1189
+ self.index_remappings_array_offsets, device=self.current_device
1190
+ )
1191
+ # pyre-fixme[16]: `IntNBitTableBatchedEmbeddingBagsCodegen` has no attribute
1192
+ # `lxu_cache_weights`.
1193
+ self.lxu_cache_weights = torch.empty_like(
1194
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
1195
+ # `Union[Module, Tensor]`.
1196
+ self.lxu_cache_weights,
1197
+ device=self.current_device,
1198
+ )
1199
+ self.original_rows_per_table = torch.empty_like(
1200
+ self.original_rows_per_table, device=self.current_device
1201
+ )
1202
+ self.table_wise_cache_miss = torch.empty_like(
1203
+ self.table_wise_cache_miss, device=self.current_device
1204
+ )
1205
+ self.weights_uvm = torch.empty_like(
1206
+ self.weights_uvm, device=self.current_device
1207
+ )
1208
+
1209
+ def _apply_split(
1210
+ self,
1211
+ dev_size: int,
1212
+ host_size: int,
1213
+ uvm_size: int,
1214
+ placements: list[int],
1215
+ offsets: list[int],
1216
+ enforce_hbm: bool,
1217
+ ) -> None:
1218
+ assert not self.weight_initialized, "Weights have already been initialized."
1219
+ self.weight_initialized = True
1220
+ self.weights_physical_placements = placements
1221
+ self.weights_physical_offsets = offsets
1222
+
1223
+ self.host_size = host_size
1224
+ self.dev_size = dev_size
1225
+ self.uvm_size = uvm_size
1226
+
1227
+ self.initialize_logical_weights_placements_and_offsets()
1228
+
1229
+ if dev_size > 0:
1230
+ self.weights_dev = torch.zeros(
1231
+ dev_size,
1232
+ device=self.current_device,
1233
+ dtype=torch.uint8,
1234
+ )
1235
+
1236
+ if host_size > 0:
1237
+ self.weights_host = torch.zeros(
1238
+ host_size, device=self.current_device, dtype=torch.uint8
1239
+ )
1240
+
1241
+ if uvm_size > 0:
1242
+ assert not self.use_cpu
1243
+ if enforce_hbm:
1244
+ if not torch.jit.is_scripting():
1245
+ self.log("Enforce hbm for the cache location")
1246
+ self.weights_uvm = torch.zeros(
1247
+ uvm_size,
1248
+ device=self.current_device,
1249
+ dtype=torch.uint8,
1250
+ )
1251
+ else:
1252
+ self.weights_uvm = torch.zeros(
1253
+ uvm_size,
1254
+ out=torch.ops.fbgemm.new_unified_tensor(
1255
+ torch.zeros(1, device=self.D_offsets.device, dtype=torch.uint8),
1256
+ [uvm_size],
1257
+ self.uvm_host_mapped,
1258
+ ),
1259
+ )
1260
+
1261
+ def _apply_cache_state(
1262
+ self,
1263
+ cache_state: CacheState,
1264
+ cache_algorithm: CacheAlgorithm,
1265
+ cache_load_factor: float,
1266
+ cache_sets: int,
1267
+ cache_reserved_memory: float,
1268
+ ) -> None:
1269
+ assert self.cache_assoc in [
1270
+ 1,
1271
+ 32,
1272
+ 64,
1273
+ ], "Only 1-way or 32-way(64-way for AMD) implmeneted for now"
1274
+
1275
+ self.cache_algorithm = cache_algorithm
1276
+ # pyre-ignore[16]
1277
+ self.timestep_counter = torch.classes.fbgemm.AtomicCounter()
1278
+ # pyre-ignore[16]
1279
+ self.timestep_prefetch_size = torch.classes.fbgemm.AtomicCounter()
1280
+
1281
+ self.max_prefetch_depth = MAX_PREFETCH_DEPTH
1282
+
1283
+ if self.current_device.type == "meta":
1284
+ # To resolve the "Cannot copy out of meta tensor; no data!" error
1285
+ lxu_cache_locations_empty = torch.empty(0, dtype=torch.int32).fill_(-1)
1286
+ else:
1287
+ lxu_cache_locations_empty = torch.empty(
1288
+ 0, device=self.current_device, dtype=torch.int32
1289
+ ).fill_(-1)
1290
+ # pyre-ignore[16]
1291
+ self.lxu_cache_locations_list = torch.classes.fbgemm.TensorQueue(
1292
+ lxu_cache_locations_empty
1293
+ )
1294
+
1295
+ # NOTE: no cache for CPU mode!
1296
+ if cache_state.total_cache_hash_size == 0 or self.use_cpu:
1297
+ self.register_buffer(
1298
+ "lxu_cache_weights",
1299
+ torch.zeros(0, 0, device=self.current_device, dtype=torch.uint8),
1300
+ )
1301
+ # NOTE: make TorchScript work!
1302
+ self.register_buffer(
1303
+ "cache_hash_size_cumsum",
1304
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1305
+ persistent=False,
1306
+ )
1307
+ self.total_cache_hash_size = 0
1308
+ self.register_buffer(
1309
+ "cache_index_table_map",
1310
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1311
+ persistent=False,
1312
+ )
1313
+ self.register_buffer(
1314
+ "lxu_cache_state",
1315
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1316
+ persistent=False,
1317
+ )
1318
+ self.register_buffer(
1319
+ "lxu_state",
1320
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1321
+ persistent=False,
1322
+ )
1323
+ self.register_buffer(
1324
+ "lxu_cache_miss_timestamp",
1325
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1326
+ persistent=False,
1327
+ )
1328
+ self.register_buffer(
1329
+ "cache_miss_counter",
1330
+ torch.tensor(
1331
+ [0, 0, 0, 0], dtype=torch.int64, device=self.current_device
1332
+ ),
1333
+ persistent=False,
1334
+ )
1335
+ self.register_buffer(
1336
+ "uvm_cache_stats",
1337
+ torch.zeros(
1338
+ size=(self.uvm_cache_stats_size,),
1339
+ device=self.current_device,
1340
+ dtype=torch.int64,
1341
+ ),
1342
+ persistent=False,
1343
+ )
1344
+ self.register_buffer(
1345
+ "local_uvm_cache_stats",
1346
+ torch.zeros(
1347
+ size=(self.uvm_cache_stats_size,),
1348
+ device=self.current_device,
1349
+ dtype=torch.int32,
1350
+ ),
1351
+ persistent=False,
1352
+ )
1353
+ return
1354
+
1355
+ assert cache_load_factor > 0
1356
+ if cache_sets <= 0:
1357
+ total_memory = torch.cuda.get_device_properties(
1358
+ self.current_device
1359
+ ).total_memory
1360
+ free_memory = (
1361
+ total_memory
1362
+ - torch.cuda.memory_reserved(self.current_device)
1363
+ - int(cache_reserved_memory)
1364
+ )
1365
+ assert free_memory > 0
1366
+ cache_sets = (
1367
+ int(cache_state.total_cache_hash_size * cache_load_factor)
1368
+ + self.cache_assoc
1369
+ - 1
1370
+ ) // self.cache_assoc
1371
+ # Note that element_size has already been included in max_D_cache (in bytes)
1372
+ cache_size = cache_sets * self.cache_assoc * self.max_D_cache
1373
+ if cache_size > free_memory:
1374
+ cache_sets = (
1375
+ int(1.0 * free_memory / self.max_D_cache) + self.cache_assoc - 1
1376
+ ) // self.cache_assoc
1377
+ cache_sets = 1 if cache_sets == 0 else cache_sets
1378
+ cache_load_factor = (
1379
+ 1.0 * cache_sets * self.cache_assoc / int(cache_state.total_cache_hash_size)
1380
+ )
1381
+ assert cache_sets > 0
1382
+ if cache_algorithm == CacheAlgorithm.LFU:
1383
+ assert cache_sets < 2**24 - 1
1384
+ cache_size = cache_sets * self.cache_assoc * self.max_D_cache
1385
+ self.log(
1386
+ f"Using on-device cache with admission algorithm "
1387
+ f"{cache_algorithm}, {cache_sets} sets, "
1388
+ f"cache_load_factor: {cache_load_factor : .3f}, "
1389
+ f"{cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB"
1390
+ )
1391
+
1392
+ self.total_cache_hash_size = cache_state.total_cache_hash_size
1393
+ self.register_buffer(
1394
+ "cache_hash_size_cumsum",
1395
+ torch.tensor(
1396
+ cache_state.cache_hash_size_cumsum,
1397
+ device=self.current_device,
1398
+ dtype=torch.int64,
1399
+ ),
1400
+ )
1401
+ self.register_buffer(
1402
+ "cache_index_table_map",
1403
+ torch.tensor(
1404
+ cache_state.cache_index_table_map,
1405
+ device=self.current_device,
1406
+ dtype=torch.int32,
1407
+ ),
1408
+ )
1409
+ self.register_buffer(
1410
+ "lxu_cache_state",
1411
+ torch.zeros(
1412
+ cache_sets,
1413
+ self.cache_assoc,
1414
+ device=self.current_device,
1415
+ dtype=torch.int64,
1416
+ ).fill_(-1),
1417
+ )
1418
+ self.register_buffer(
1419
+ "lxu_cache_weights",
1420
+ torch.zeros(
1421
+ cache_sets * self.cache_assoc,
1422
+ self.max_D_cache,
1423
+ device=self.current_device,
1424
+ dtype=torch.uint8,
1425
+ ),
1426
+ )
1427
+ self.register_buffer(
1428
+ "lxu_state",
1429
+ torch.zeros(
1430
+ size=(
1431
+ (self.total_cache_hash_size + 1,)
1432
+ if cache_algorithm == CacheAlgorithm.LFU
1433
+ else (cache_sets, self.cache_assoc)
1434
+ ),
1435
+ device=self.current_device,
1436
+ dtype=torch.int64,
1437
+ ),
1438
+ )
1439
+ if self.cache_assoc == 1:
1440
+ self.register_buffer(
1441
+ "lxu_cache_miss_timestamp",
1442
+ torch.zeros(
1443
+ cache_sets,
1444
+ self.cache_assoc,
1445
+ device=self.current_device,
1446
+ dtype=torch.int64,
1447
+ ),
1448
+ )
1449
+ else:
1450
+ # make TorchScript work
1451
+ self.register_buffer(
1452
+ "lxu_cache_miss_timestamp",
1453
+ torch.zeros(1, device=self.current_device, dtype=torch.int64),
1454
+ persistent=False,
1455
+ )
1456
+ self.register_buffer(
1457
+ "cache_miss_counter",
1458
+ torch.tensor([0, 0, 0, 0], device=self.current_device, dtype=torch.int64),
1459
+ )
1460
+ self.register_buffer(
1461
+ "uvm_cache_stats",
1462
+ torch.zeros(
1463
+ size=(self.uvm_cache_stats_size,),
1464
+ device=self.current_device,
1465
+ dtype=torch.int64,
1466
+ ),
1467
+ persistent=False,
1468
+ )
1469
+ self.register_buffer(
1470
+ "local_uvm_cache_stats",
1471
+ torch.zeros(
1472
+ size=(self.uvm_cache_stats_size,),
1473
+ device=self.current_device,
1474
+ dtype=torch.int32,
1475
+ ),
1476
+ persistent=False,
1477
+ )
1478
+ if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU):
1479
+ raise ValueError(
1480
+ f"cache_algorithm must be {CacheAlgorithm.LRU} "
1481
+ f"or {CacheAlgorithm.LFU}"
1482
+ )
1483
+
1484
+ if self.gather_uvm_cache_stats:
1485
+ self.reset_uvm_cache_stats()
1486
+
1487
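# --- Editor's note: worked example of the cache sizing above; the numbers are ---
# --- hypothetical and this block is not part of the released file.           ---
total_cache_hash_size = 1_000_000   # rows eligible for caching
cache_load_factor = 0.2
cache_assoc = 32                    # 32-way set-associative cache
max_D_cache = 256                   # bytes per cached row (element size included)

# ceil-divide the target number of cached rows by the associativity
cache_sets = (
    int(total_cache_hash_size * cache_load_factor) + cache_assoc - 1
) // cache_assoc                                      # -> 6250 sets
cache_size = cache_sets * cache_assoc * max_D_cache   # -> 51_200_000 bytes (~48.8 MiB)
# If cache_size exceeded free HBM, cache_sets would be recomputed from
# free_memory / max_D_cache, as in the branch above.
# --------------------------------------------------------------------------------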
+ def reset_cache_states(self) -> None:
1488
+ # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
1489
+ # a function.
1490
+ if not self.lxu_cache_weights.numel():
1491
+ return
1492
+ self.lxu_cache_state.fill_(-1)
1493
+ self.lxu_state.fill_(0)
1494
+ self.timestep_counter.reset()
1495
+
1496
+ def move_to_device_with_cache(
1497
+ self, device: torch.device, cache_load_factor: float
1498
+ ) -> None:
1499
+ """
1500
+ Moves the TBE to the specified device, and updates the cache state accordingly.
1501
+ """
1502
+ if (
1503
+ self.current_device == device
1504
+ and self.cache_load_factor == cache_load_factor
1505
+ ):
1506
+ return
1507
+
1508
+ location = get_new_embedding_location(device, cache_load_factor)
1509
+ if device.type != "cpu":
1510
+ self.use_cpu = False
1511
+
1512
+ weights = self.split_embedding_weights()
1513
+ is_meta = self.current_device.type == "meta"
1514
+ index_remapping_array: torch.Tensor
1515
+ index_remappings_array_offsets: torch.Tensor
1516
+ original_rows_per_table: torch.Tensor
1517
+ if not is_meta:
1518
+ # Record the weights and pruning tensors so they can be
1519
+ # restored on the TBE after it is moved to the new device
1520
+ if device.type == "cpu":
1521
+ for i, weight in enumerate(weights):
1522
+ weights[i] = (
1523
+ weight[0].to(device),
1524
+ # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `to`.
1525
+ weight[1].to(device) if weight[1] is not None else None,
1526
+ )
1527
+ (
1528
+ index_remapping_array,
1529
+ index_remappings_array_offsets,
1530
+ original_rows_per_table,
1531
+ ) = (
1532
+ self.index_remappings_array.to(device),
1533
+ self.index_remappings_array_offsets.to(device),
1534
+ self.original_rows_per_table.to(device),
1535
+ )
1536
+
1537
+ self.reset_weights_placements_and_offsets(device, location.value)
1538
+ self.recompute_module_buffers()
1539
+ self.weight_initialized = False
1540
+ self.initialize_weights()
1541
+
1542
+ # Ensure all weights are on the same device
1543
+ if device.type != "cpu":
1544
+ self.weights_host = torch.zeros(0, device=device, dtype=torch.uint8)
1545
+
1546
+ if location != EmbeddingLocation.DEVICE:
1547
+ self.weights_dev = torch.zeros(0, device=device, dtype=torch.uint8)
1548
+
1549
+ for name, buf in self.named_buffers():
1550
+ if buf.is_meta:
1551
+ self.register_buffer(name, tensor_to_device(buf, device))
1552
+
1553
+ self.current_device = device
1554
+
1555
+ if not is_meta:
1556
+ self.assign_embedding_weights(weights)
1557
+ self.index_remappings_array = index_remapping_array
1558
+ self.index_remappings_array_offsets = index_remappings_array_offsets
1559
+ self.original_rows_per_table = original_rows_per_table
1560
+
1561
+ if cache_load_factor is not None:
1562
+ self.update_cache_load_factor(cache_load_factor)
1563
+
1564
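# --- Editor's note: hypothetical usage sketch, not part of the released file. ---
# `tbe` stands for an already constructed IntNBitTableBatchedEmbeddingBagsCodegen.
import torch

def relocate(tbe) -> None:
    # Move weights, pruning tensors, and cache buffers to GPU 0 and rebuild the
    # UVM cache with a 20% load factor (a no-op if nothing changes).
    tbe.move_to_device_with_cache(torch.device("cuda:0"), cache_load_factor=0.2)
# --------------------------------------------------------------------------------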
+ def update_cache_load_factor(self, cache_load_factor: float = 0.2) -> None:
1565
+ """
1566
+ Updates cache_load_factor and the embedding location for weights after the TBE has already been initialized.
1567
+ Assumes that the location of the weights has already been set correctly.
1568
+ """
1569
+ rows = [
1570
+ embedding_spec[EmbeddingSpecInfo.rows]
1571
+ for embedding_spec in self.embedding_specs
1572
+ ]
1573
+ locations = [
1574
+ embedding_spec[EmbeddingSpecInfo.embedding_location]
1575
+ for embedding_spec in self.embedding_specs
1576
+ ]
1577
+ # pyre-ignore[6]
1578
+ cache_state = construct_cache_state(rows, locations, self.feature_table_map)
1579
+
1580
+ cached_dims = [
1581
+ rounded_row_size_in_bytes(
1582
+ embedding_spec[EmbeddingSpecInfo.dims], # pyre-ignore[6]
1583
+ embedding_spec[EmbeddingSpecInfo.sparse_type], # pyre-ignore[6]
1584
+ 16,
1585
+ self.scale_bias_size_in_bytes,
1586
+ )
1587
+ for embedding_spec in self.embedding_specs
1588
+ if embedding_spec[EmbeddingSpecInfo.embedding_location]
1589
+ == EmbeddingLocation.MANAGED_CACHING
1590
+ ]
1591
+
1592
+ self.max_D_cache: int = max(cached_dims) if len(cached_dims) > 0 else 0
1593
+
1594
+ self._apply_cache_state(
1595
+ cache_state,
1596
+ self.cache_algorithm,
1597
+ cache_load_factor,
1598
+ self.cache_sets,
1599
+ self.cache_reserved_memory,
1600
+ )
1601
+
1602
+ @torch.jit.export
1603
+ def split_embedding_weights_with_scale_bias(
1604
+ self, split_scale_bias_mode: int = 1
1605
+ ) -> list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
1606
+ """
1607
+ Returns a list of weights, split by table.
1608
+ split_scale_bias_mode:
1609
+ 0: return the whole row (no split);
1610
+ 1: return (weights, scale_bias);
1611
+ 2: return (weights, scale, bias).
1612
+ """
1613
+ assert self.weight_initialized
1614
+ splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
1615
+ for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
1616
+ placement = self.weights_physical_placements[t]
1617
+ if (
1618
+ placement == EmbeddingLocation.DEVICE.value
1619
+ or placement == EmbeddingLocation.MTIA.value
1620
+ ):
1621
+ weights = self.weights_dev
1622
+ elif placement == EmbeddingLocation.HOST.value:
1623
+ weights = self.weights_host
1624
+ else:
1625
+ weights = self.weights_uvm
1626
+ offset = self.weights_physical_offsets[t]
1627
+ weights_shifts = weights.detach()[
1628
+ offset : offset
1629
+ + rows
1630
+ * rounded_row_size_in_bytes(
1631
+ dim, weight_ty, self.row_alignment, self.scale_bias_size_in_bytes
1632
+ )
1633
+ ].view(
1634
+ rows,
1635
+ rounded_row_size_in_bytes(
1636
+ dim, weight_ty, self.row_alignment, self.scale_bias_size_in_bytes
1637
+ ),
1638
+ )
1639
+
1640
+ if split_scale_bias_mode == 1 or split_scale_bias_mode == 2:
1641
+ # remove the padding at the end of each row.
1642
+ weights_shifts = weights_shifts[
1643
+ :,
1644
+ : unpadded_row_size_in_bytes(
1645
+ dim, weight_ty, self.scale_bias_size_in_bytes
1646
+ ),
1647
+ ]
1648
+ if (
1649
+ weight_ty.value == SparseType.INT8.value
1650
+ or weight_ty.value == SparseType.INT4.value
1651
+ or weight_ty.value == SparseType.INT2.value
1652
+ ):
1653
+ if split_scale_bias_mode == 1:
1654
+ if self.reverse_qparam:
1655
+ splits.append(
1656
+ (
1657
+ weights_shifts[
1658
+ :, 0 : (0 - self.scale_bias_size_in_bytes)
1659
+ ],
1660
+ weights_shifts[
1661
+ :, (0 - self.scale_bias_size_in_bytes) :
1662
+ ],
1663
+ None,
1664
+ )
1665
+ )
1666
+ else:
1667
+ splits.append(
1668
+ (
1669
+ weights_shifts[:, self.scale_bias_size_in_bytes :],
1670
+ weights_shifts[:, : self.scale_bias_size_in_bytes],
1671
+ None,
1672
+ )
1673
+ )
1674
+ elif split_scale_bias_mode == 2:
1675
+ if self.reverse_qparam:
1676
+ # weights_shifts: [0:-4] is real weights; [-4:-2] is scale; [-2:] is bias
1677
+ splits.append(
1678
+ (
1679
+ weights_shifts[
1680
+ :, 0 : (0 - self.scale_bias_size_in_bytes)
1681
+ ],
1682
+ weights_shifts[
1683
+ :,
1684
+ (0 - self.scale_bias_size_in_bytes) : (
1685
+ 0 - self.scale_bias_size_in_bytes // 2
1686
+ ),
1687
+ ].view(torch.float16),
1688
+ weights_shifts[
1689
+ :, (0 - self.scale_bias_size_in_bytes // 2) :
1690
+ ].view(torch.float16),
1691
+ )
1692
+ )
1693
+ else:
1694
+ # weights_shifts: [0:2] is scale; [2:4] is bias; [4:] is real weights
1695
+ splits.append(
1696
+ (
1697
+ weights_shifts[:, self.scale_bias_size_in_bytes :],
1698
+ weights_shifts[
1699
+ :, : self.scale_bias_size_in_bytes // 2
1700
+ ].view(torch.float16),
1701
+ weights_shifts[
1702
+ :,
1703
+ self.scale_bias_size_in_bytes
1704
+ // 2 : self.scale_bias_size_in_bytes,
1705
+ ].view(torch.float16),
1706
+ )
1707
+ )
1708
+ else:
1709
+ raise ValueError("split_scale_bias_mode is not supported")
1710
+
1711
+ elif (
1712
+ weight_ty.value == SparseType.FP8.value
1713
+ or weight_ty.value == SparseType.FP16.value
1714
+ or weight_ty.value == SparseType.FP32.value
1715
+ ):
1716
+ splits.append(
1717
+ (
1718
+ weights_shifts,
1719
+ None,
1720
+ None,
1721
+ )
1722
+ )
1723
+ else:
1724
+ raise ValueError("weight_ty is not supported")
1725
+
1726
+ else: # split_scale_bias_mode == 0:
1727
+ splits.append((weights_shifts, None, None))
1728
+
1729
+ return splits
1730
+
1731
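# --- Editor's note: worked example of the row split above, not part of the ------
# --- released file.  Assumes the default layout (reverse_qparam == False)  ------
# --- and scale_bias_size_in_bytes == 4 (fp16 scale + fp16 bias).           ------
import torch

scale_bias_size_in_bytes = 4
row = torch.arange(12, dtype=torch.uint8).view(1, 12)   # one unpadded INT8 row

# split_scale_bias_mode == 1: (quantized weights, packed scale+bias)
weights = row[:, scale_bias_size_in_bytes:]              # bytes 4..11
scale_bias = row[:, :scale_bias_size_in_bytes]           # bytes 0..3

# split_scale_bias_mode == 2: (quantized weights, fp16 scale, fp16 bias)
scale = row[:, : scale_bias_size_in_bytes // 2].view(torch.float16)  # bytes 0..1
bias = row[:, scale_bias_size_in_bytes // 2 : scale_bias_size_in_bytes].view(
    torch.float16
)                                                        # bytes 2..3
# --------------------------------------------------------------------------------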
+ @torch.jit.export
1732
+ def split_embedding_weights(
1733
+ self,
1734
+ split_scale_shifts: bool = True,
1735
+ # When True, return a list of (weights, scale_bias) pairs, one per table;
1736
+ # when False, each table's rows are returned whole (second element is None).
1737
+ # This parameter should have been named split_scale_bias;
1738
+ # it is kept as-is for backward compatibility.
1739
+ ) -> list[tuple[Tensor, Optional[Tensor]]]:
1740
+ """
1741
+ Returns a list of weights, split by table
1742
+ """
1743
+ # fmt: off
1744
+ splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
1745
+ self.split_embedding_weights_with_scale_bias(
1746
+ split_scale_bias_mode=(1 if split_scale_shifts else 0)
1747
+ )
1748
+ )
1749
+ # fmt: on
1750
+ return [
1751
+ (split_weight_scale_bias[0], split_weight_scale_bias[1])
1752
+ for split_weight_scale_bias in splits
1753
+ ]
1754
+
1755
+ @torch.jit.export
1756
+ def initialize_weights(self) -> None:
1757
+ if not self.weight_initialized:
1758
+ self._apply_split(
1759
+ self.dev_size,
1760
+ self.host_size,
1761
+ self.uvm_size,
1762
+ self.weights_physical_placements,
1763
+ self.weights_physical_offsets,
1764
+ self.enforce_hbm,
1765
+ )
1766
+ self.weight_initialized = True
1767
+
1768
+ def fill_random_weights(self) -> None:
1769
+ """
1770
+ Fill the buffer with random weights, table by table
1771
+ """
1772
+ self.initialize_weights()
1773
+ weights = self.split_embedding_weights()
1774
+ for dest_weight in weights:
1775
+ random_quant_scaled_tensor(
1776
+ shape=dest_weight[0].shape,
1777
+ device=self.current_device,
1778
+ output_tensor=dest_weight[0],
1779
+ )
1780
+
1781
+ def assign_embedding_weights(
1782
+ self, q_weight_list: list[tuple[Tensor, Optional[Tensor]]]
1783
+ ) -> None:
1784
+ """
1785
+ Copies the values from the input list of weights and scale_shifts into self.split_embedding_weights().
1786
+ """
1787
+ weights = self.split_embedding_weights()
1788
+ assert len(q_weight_list) == len(weights)
1789
+
1790
+ for dest_weight, input_weight in zip(weights, q_weight_list):
1791
+ dest_weight[0].copy_(input_weight[0])
1792
+ if input_weight[1] is not None:
1793
+ assert dest_weight[1] is not None
1794
+ # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `copy_`.
1795
+ dest_weight[1].copy_(input_weight[1])
1796
+ else:
1797
+ assert dest_weight[1] is None
1798
+
1799
+ @torch.jit.export
1800
+ def set_index_remappings_array(
1801
+ self,
1802
+ index_remapping: list[Tensor],
1803
+ ) -> None:
1804
+ rows: list[int] = [e[1] for e in self.embedding_specs]
1805
+ index_remappings_array_offsets = [0]
1806
+ original_feature_rows = torch.jit.annotate(list[int], [])
1807
+ last_offset = 0
1808
+ for t, mapping in enumerate(index_remapping):
1809
+ if mapping is not None:
1810
+ current_original_row = mapping.numel()
1811
+ last_offset += current_original_row
1812
+ original_feature_rows.append(current_original_row)
1813
+ else:
1814
+ original_feature_rows.append(rows[t])
1815
+ index_remappings_array_offsets.append(last_offset)
1816
+
1817
+ self.index_remappings_array_offsets = torch.tensor(
1818
+ index_remappings_array_offsets,
1819
+ device=self.current_device,
1820
+ dtype=torch.int64,
1821
+ )
1822
+ if len(original_feature_rows) == 0:
1823
+ original_feature_rows = rows
1824
+ self.original_rows_per_table = torch.tensor(
1825
+ [original_feature_rows[t] for t in self.feature_table_map],
1826
+ device=self.current_device,
1827
+ dtype=torch.int64,
1828
+ )
1829
+
1830
+ index_remappings_filter_nones = []
1831
+ for mapping in index_remapping:
1832
+ if mapping is not None:
1833
+ index_remappings_filter_nones.append(mapping)
1834
+ if len(index_remappings_filter_nones) == 0:
1835
+ self.index_remappings_array = torch.empty(
1836
+ 0, dtype=self.indices_dtype, device=self.current_device
1837
+ )
1838
+ else:
1839
+ self.index_remappings_array = torch.cat(index_remappings_filter_nones).to(
1840
+ dtype=self.indices_dtype, device=self.current_device
1841
+ )
1842
+
1843
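# --- Editor's note: worked example of the offset bookkeeping above, not part ----
# --- of the released file.  The per-table remappings are hypothetical.       ----
import torch

index_remapping = [
    torch.tensor([3, 0, 2, 1]),   # table 0: pruned, 4 original rows
    None,                         # table 1: not pruned
    torch.tensor([1, 0]),         # table 2: pruned, 2 original rows
]
offsets, last = [0], 0
for mapping in index_remapping:
    if mapping is not None:
        last += mapping.numel()
    offsets.append(last)
# offsets == [0, 4, 4, 6]: an unpruned table owns an empty slice of the
# concatenated remapping array built below.
remap_array = torch.cat([m for m in index_remapping if m is not None])  # 6 entries
# --------------------------------------------------------------------------------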
+ def set_index_remappings(
1844
+ self,
1845
+ index_remapping: list[Tensor],
1846
+ pruning_hash_load_factor: float = 0.5,
1847
+ use_array_for_index_remapping: bool = True,
1848
+ ) -> None:
1849
+ rows: list[int] = [e[1] for e in self.embedding_specs]
1850
+ T = len(self.embedding_specs)
1851
+ # Hash mapping pruning
1852
+ if not use_array_for_index_remapping:
1853
+ capacities = [
1854
+ (
1855
+ round_up(int(row * 1.0 / pruning_hash_load_factor), 32)
1856
+ if index_remap is not None
1857
+ else 0
1858
+ )
1859
+ for (index_remap, row) in zip(index_remapping, rows)
1860
+ ]
1861
+ hash_table = torch.empty(
1862
+ (sum(capacities), 2),
1863
+ dtype=self.indices_dtype,
1864
+ )
1865
+ hash_table[:, :] = -1
1866
+ hash_table_offsets = torch.tensor([0] + list(accumulate(capacities))).long()
1867
+
1868
+ merged_index_remappings = [
1869
+ mapping if mapping is not None else Tensor(list(range(row)))
1870
+ for (mapping, row) in zip(index_remapping, rows)
1871
+ ]
1872
+ original_feature_rows = [
1873
+ mapping.numel() for mapping in merged_index_remappings
1874
+ ]
1875
+ if len(original_feature_rows) == 0:
1876
+ original_feature_rows = rows
1877
+ self.original_rows_per_table = torch.tensor(
1878
+ [original_feature_rows[t] for t in self.feature_table_map],
1879
+ device=self.current_device,
1880
+ dtype=torch.int64,
1881
+ )
1882
+ dense_indices = torch.cat(merged_index_remappings, dim=0).int()
1883
+ indices = torch.cat(
1884
+ [torch.arange(row) for row in original_feature_rows], dim=0
1885
+ ).int()
1886
+ offsets = torch.tensor([0] + list(accumulate(original_feature_rows))).int()
1887
+
1888
+ if self.use_cpu:
1889
+ self.index_remapping_hash_table_cpu = (
1890
+ # pyre-ignore[16]
1891
+ torch.classes.fbgemm.PrunedMapCPU()
1892
+ )
1893
+ self.index_remapping_hash_table_cpu.insert(
1894
+ indices, dense_indices, offsets, T
1895
+ )
1896
+ else:
1897
+ # pruned_hashmap_insert only has a CPU implementation: move dense_indices to CPU
1898
+ torch.ops.fbgemm.pruned_hashmap_insert(
1899
+ indices,
1900
+ dense_indices.cpu(),
1901
+ offsets,
1902
+ hash_table,
1903
+ hash_table_offsets,
1904
+ )
1905
+ self.index_remapping_hash_table = hash_table.to(
1906
+ dtype=self.indices_dtype, device=self.current_device
1907
+ )
1908
+ self.index_remapping_hash_table_offsets = hash_table_offsets.to(
1909
+ self.current_device
1910
+ )
1911
+ self.index_remapping_hash_table_cpu = None
1912
+ # Array mapping pruning
1913
+ else:
1914
+ self.set_index_remappings_array(index_remapping)
1915
+
1916
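# --- Editor's note: worked example of the hash-table sizing above, not part -----
# --- of the released file.  `round_up` is a local stand-in for the helper   -----
# --- imported by this module (round up to the next multiple).               -----
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

rows = [100, 50]                     # hypothetical table sizes, both pruned
pruning_hash_load_factor = 0.5
capacities = [round_up(int(r * 1.0 / pruning_hash_load_factor), 32) for r in rows]
# capacities == [224, 128]: each pruned table gets ~2x its row count of
# (key, value) hash slots, padded to a multiple of 32 and initialized to -1.
# --------------------------------------------------------------------------------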
+ def _embedding_inplace_update_per_table(
1917
+ self,
1918
+ update_table_idx: int,
1919
+ update_row_indices: list[int],
1920
+ update_weights: Tensor,
1921
+ ) -> None:
1922
+ row_size = len(update_row_indices)
1923
+ if row_size == 0:
1924
+ return
1925
+ # pyre-fixme[9]: update_row_indices has type `List[int]`; used as `Tensor`.
1926
+ update_row_indices = torch.tensor(
1927
+ update_row_indices,
1928
+ device=self.current_device,
1929
+ dtype=torch.int64,
1930
+ )
1931
+ table_values = self.split_embedding_weights(split_scale_shifts=False)[
1932
+ update_table_idx
1933
+ ]
1934
+ table_values[0].scatter_(
1935
+ dim=0,
1936
+ # pyre-fixme[16]: `List` has no attribute `view`.
1937
+ index=update_row_indices.view(row_size, 1).expand_as(update_weights),
1938
+ src=update_weights,
1939
+ )
1940
+
1941
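# --- Editor's note: minimal sketch of the scatter_-based row update above, ------
# --- not part of the released file.                                        ------
import torch

table = torch.zeros(4, 3, dtype=torch.uint8)            # 4 rows of 3 bytes each
update_row_indices = torch.tensor([2, 0])
update_weights = torch.tensor([[9, 9, 9], [7, 7, 7]], dtype=torch.uint8)
table.scatter_(
    dim=0,
    index=update_row_indices.view(2, 1).expand_as(update_weights),
    src=update_weights,
)
# Rows 2 and 0 of `table` now hold [9, 9, 9] and [7, 7, 7] respectively.
# --------------------------------------------------------------------------------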
+ @torch.jit.export
1942
+ def embedding_inplace_update(
1943
+ self,
1944
+ update_table_indices: list[int],
1945
+ update_row_indices: list[list[int]],
1946
+ update_weights: list[Tensor],
1947
+ ) -> None:
1948
+ for i in range(len(update_table_indices)):
1949
+ self._embedding_inplace_update_per_table(
1950
+ update_table_indices[i],
1951
+ update_row_indices[i],
1952
+ update_weights[i],
1953
+ )
1954
+
1955
+ def embedding_inplace_update_internal(
1956
+ self,
1957
+ update_table_indices: list[int],
1958
+ update_row_indices: list[int],
1959
+ update_weights: Tensor,
1960
+ ) -> None:
1961
+ assert len(update_table_indices) == len(update_row_indices)
1962
+ update_offsets = []
1963
+ update_offset = 0
1964
+ for table_idx in update_table_indices:
1965
+ D_bytes = rounded_row_size_in_bytes(
1966
+ self.embedding_specs[table_idx][2],
1967
+ self.embedding_specs[table_idx][3],
1968
+ self.row_alignment,
1969
+ self.scale_bias_size_in_bytes,
1970
+ )
1971
+ update_offsets.append(update_offset)
1972
+ update_offset += D_bytes
1973
+ update_offsets.append(update_offset)
1974
+
1975
+ # pyre-fixme[9]: update_table_indices has type `List[int]`; used as `Tensor`.
1976
+ update_table_indices = torch.tensor(
1977
+ update_table_indices,
1978
+ device=self.current_device,
1979
+ dtype=torch.int32,
1980
+ )
1981
+ # pyre-fixme[9]: update_row_indices has type `List[int]`; used as `Tensor`.
1982
+ update_row_indices = torch.tensor(
1983
+ update_row_indices,
1984
+ device=self.current_device,
1985
+ dtype=torch.int64,
1986
+ )
1987
+ update_offsets = torch.tensor(
1988
+ update_offsets,
1989
+ device=self.current_device,
1990
+ dtype=torch.int64,
1991
+ )
1992
+
1993
+ # Only array-based pruning is supported for now.
1994
+ assert self.index_remapping_hash_table_cpu is None
1995
+ assert self.index_remapping_hash_table.numel() == 0
1996
+ assert self.index_remappings_array.numel() >= 0
1997
+
1998
+ if self.index_remappings_array.numel() > 0:
1999
+ update_row_indices = torch.ops.fbgemm.pruned_array_lookup_from_row_idx(
2000
+ update_row_indices,
2001
+ update_table_indices,
2002
+ self.index_remappings_array,
2003
+ self.index_remappings_array_offsets,
2004
+ )
2005
+
2006
+ lxu_cache_locations = None
2007
+ # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
2008
+ # a function.
2009
+ if self.lxu_cache_weights.numel() > 0:
2010
+ linear_cache_indices = (
2011
+ torch.ops.fbgemm.linearize_cache_indices_from_row_idx(
2012
+ self.cache_hash_size_cumsum,
2013
+ update_table_indices,
2014
+ update_row_indices,
2015
+ )
2016
+ )
2017
+
2018
+ if self.cache_assoc in [32, 64]:
2019
+ # 64 for AMD
2020
+ self.prefetch_32way(linear_cache_indices)
2021
+ elif self.cache_assoc == 1:
2022
+ self.prefetch_1way(linear_cache_indices)
2023
+ else:
2024
+ raise ValueError(f"{self.cache_assoc} not in [1, 32, 64]")
2025
+
2026
+ lxu_cache_locations = self.lxu_cache_locations_list.pop()
2027
+
2028
+ torch.ops.fbgemm.emb_inplace_update(
2029
+ dev_weights=self.weights_host if self.host_size > 0 else self.weights_dev,
2030
+ uvm_weights=self.weights_uvm,
2031
+ weights_placements=self.weights_placements,
2032
+ weights_offsets=self.weights_offsets,
2033
+ weights_tys=self.weights_tys,
2034
+ D_offsets=self.D_offsets,
2035
+ update_weights=update_weights,
2036
+ update_table_indices=update_table_indices,
2037
+ update_row_indices=update_row_indices,
2038
+ update_offsets=update_offsets,
2039
+ row_alignment=self.row_alignment,
2040
+ lxu_cache_weights=self.lxu_cache_weights,
2041
+ lxu_cache_locations=lxu_cache_locations,
2042
+ )