fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_table_batched_embeddings_ops_training.py

@@ -18,14 +18,14 @@ import uuid
 from dataclasses import dataclass, field
 from itertools import accumulate
 from math import log2
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, Union

 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip
+from torch.autograd.profiler import record_function  # usort:skip

 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
-
 from fbgemm_gpu.config import FeatureGate, FeatureGateName
 from fbgemm_gpu.runtime_monitor import (
     AsyncSeriesTimer,
@@ -37,8 +37,10 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
     CacheAlgorithm,
     CacheState,
+    ComputeDevice,
     construct_cache_state,
     EmbeddingLocation,
+    get_bounds_check_version_for_platform,
     MAX_PREFETCH_DEPTH,
     MultiPassPrefetchConfig,
     PoolingMode,
@@ -46,6 +48,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     SplitState,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    check_allocated_vbe_output,
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
@@ -55,8 +58,8 @@ from fbgemm_gpu.tbe_input_multiplexer import (
     TBEInputMultiplexer,
     TBEInputMultiplexerConfig,
 )
-
 from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
+from fbgemm_gpu.utils.writeback_util import writeback_gradient

 try:
     load_torch_module(
@@ -80,12 +83,6 @@ class DoesNotHavePrefix(Exception):
     pass


-class ComputeDevice(enum.IntEnum):
-    CPU = 0
-    CUDA = 1
-    MTIA = 2
-
-
 class WeightDecayMode(enum.IntEnum):
     NONE = 0
     L2 = 1
@@ -99,6 +96,7 @@ class CounterWeightDecayMode(enum.IntEnum):
     NONE = 0
     L2 = 1
     DECOUPLE = 2
+    ADAGRADW = 3


 class StepMode(enum.IntEnum):
@@ -160,6 +158,8 @@ class UserEnabledConfigDefinition:
     # This is used in Adam to perform rowwise bias correction using `row_counter`
     # More details can be found in D64848802.
     use_rowwise_bias_correction: bool = False
+    use_writeback_bwd_prehook: bool = False
+    writeback_first_feature_only: bool = False


 @dataclass(frozen=True)
@@ -188,16 +188,52 @@ class UVMCacheStatsIndex(enum.IntEnum):
     num_conflict_misses = 5


+@dataclass
+class RESParams:
+    res_server_port: int = 0  # the port of the res server
+    res_store_shards: int = 1  # the number of shards to store the raw embeddings
+    table_names: list[str] = field(default_factory=list)  # table names the TBE holds
+    table_offsets: list[int] = field(
+        default_factory=list
+    )  # table offsets for the global rows the TBE holds
+    table_sizes: list[int] = field(
+        default_factory=list
+    )  # table sizes for the global rows the TBE holds
+
+
+class PrefetchedInfo:
+    """
+    Container for prefetched cache information.
+
+    This class is explicitly defined (not using @dataclass) to be compatible with
+    TorchScript's inspect.getsource() requirements.
+    """
+
+    def __init__(
+        self,
+        linear_unique_indices: torch.Tensor,
+        linear_unique_cache_indices: torch.Tensor,
+        linear_unique_indices_length: torch.Tensor,
+        hash_zch_identities: Optional[torch.Tensor],
+        hash_zch_runtime_meta: Optional[torch.Tensor],
+    ) -> None:
+        self.linear_unique_indices = linear_unique_indices
+        self.linear_unique_cache_indices = linear_unique_cache_indices
+        self.linear_unique_indices_length = linear_unique_indices_length
+        self.hash_zch_identities = hash_zch_identities
+        self.hash_zch_runtime_meta = hash_zch_runtime_meta
+
+
 def construct_split_state(
-    embedding_specs:
+    embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]],
     rowwise: bool,
     cacheable: bool,
     precision: SparseType = SparseType.FP32,
     int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
     placement: Optional[EmbeddingLocation] = None,
 ) -> SplitState:
-    placements:
-    offsets:
+    placements: list[EmbeddingLocation] = []
+    offsets: list[int] = []
     dev_size: int = 0
     host_size: int = 0
     uvm_size: int = 0
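The hunk above introduces the `RESParams` config and the `PrefetchedInfo` container used by raw embedding streaming (RES). As a rough illustration of how the new config might be filled in — field meanings are taken from the inline comments in the diff, and the table names and values below are hypothetical:

```python
# Hypothetical RESParams construction; assumes the class is importable from
# fbgemm_gpu.split_table_batched_embeddings_ops_training as added in this diff.
from fbgemm_gpu.split_table_batched_embeddings_ops_training import RESParams

res_params = RESParams(
    res_store_shards=4,                    # shards used to store the raw embeddings
    table_names=["user_id", "item_id"],    # table names the TBE holds
    table_offsets=[0, 1_000_000],          # table offsets for the global rows
)
# res_server_port and table_sizes are filled in by the TBE constructor itself
# (from the LOCAL_RES_PORT env var and the embedding_specs rows) per a later hunk.
```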
@@ -239,18 +275,18 @@ def construct_split_state(
 def apply_split_helper(
     persistent_state_fn: Callable[[str, Tensor], None],
     set_attr_fn: Callable[
-        [str, Union[Tensor,
+        [str, Union[Tensor, list[int], list[EmbeddingLocation]]], None
     ],
     current_device: torch.device,
     use_cpu: bool,
-    feature_table_map:
+    feature_table_map: list[int],
     split: SplitState,
     prefix: str,
-    dtype:
+    dtype: type[torch.dtype],
     enforce_hbm: bool = False,
     make_dev_param: bool = False,
-    dev_reshape: Optional[
-    uvm_tensors_log: Optional[
+    dev_reshape: Optional[tuple[int, ...]] = None,
+    uvm_tensors_log: Optional[list[str]] = None,
     uvm_host_mapped: bool = False,
 ) -> None:
     set_attr_fn(f"{prefix}_physical_placements", split.placements)
@@ -335,6 +371,7 @@ def apply_split_helper(
             f"{prefix}_uvm",
             torch.zeros(
                 split.uvm_size,
+                device=current_device,
                 out=torch.ops.fbgemm.new_unified_tensor(
                     # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]`
                     # for 3rd param but got `Type[Type[torch._dtype]]`.
@@ -603,13 +640,19 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         embedding_table_offset_type (torch.dtype = torch.int64): The data type of
             the embedding table offset tensor. Options are `torch.int32` and
             `torch.int64`
+
+        embedding_shard_info (Optional[List[Tuple[int, int, int, int]]] = None): the
+            information about shard position and pre-sharded table size. If not set,
+            the table is not sharded.
+            (preshard_table_height, preshard_table_dim, height_offset, dim_offset)
     """

-    embedding_specs:
+    embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]]
     optimizer_args: invokers.lookup_args.OptimizerArgs
-    lxu_cache_locations_list:
+    lxu_cache_locations_list: list[Tensor]
     lxu_cache_locations_empty: Tensor
-    timesteps_prefetched:
+    timesteps_prefetched: list[int]
+    prefetched_info_list: list[PrefetchedInfo]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
     uvm_cache_stats: torch.Tensor
@@ -623,10 +666,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

     def __init__(  # noqa C901
         self,
-        embedding_specs:
-
+        embedding_specs: list[
+            tuple[int, int, EmbeddingLocation, ComputeDevice]
         ],  # tuple of (rows, dims, placements, compute_devices)
-        feature_table_map: Optional[
+        feature_table_map: Optional[list[int]] = None,  # [T]
         cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
         cache_load_factor: float = 0.2,
         cache_sets: int = 0,
@@ -664,8 +707,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         use_experimental_tbe: bool = False,
         prefetch_pipeline: bool = False,
         stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
-        table_names: Optional[
-        optimizer_state_dtypes: Optional[
+        table_names: Optional[list[str]] = None,
+        optimizer_state_dtypes: Optional[dict[str, SparseType]] = None,
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
         global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
         uvm_host_mapped: bool = False,
@@ -673,23 +716,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
         embedding_table_index_type: torch.dtype = torch.int64,
         embedding_table_offset_type: torch.dtype = torch.int64,
+        embedding_shard_info: Optional[list[tuple[int, int, int, int]]] = None,
+        enable_raw_embedding_streaming: bool = False,
+        res_params: Optional[RESParams] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
         self.uuid = str(uuid.uuid4())
-
+        self.log("SplitTableBatchedEmbeddingBagsCodegen API: V2")
         self.log(f"SplitTableBatchedEmbeddingBagsCodegen Arguments: {locals()}")
         self.log(
             f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
         )

+        self.table_names: Optional[list[str]] = table_names
         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
+        self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
         self.pooling_mode = pooling_mode
         self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
+
         # If environment variable is set, it overwrites the default bounds check mode.
-        self.bounds_check_version: int =
+        self.bounds_check_version: int = (
+            2
+            if self._feature_is_enabled(FeatureGateName.BOUNDS_CHECK_INDICES_V2)
+            else get_bounds_check_version_for_platform()
+        )
         self.bounds_check_mode_int: int = int(
             os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
         )
+        # Check if bounds_check_indices_v2 is enabled via the feature gate
         bounds_check_mode = BoundsCheckMode(self.bounds_check_mode_int)
         if bounds_check_mode.name.startswith("V2_"):
             self.bounds_check_version = 2
@@ -699,10 +753,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 bounds_check_mode = BoundsCheckMode.WARNING
             elif bounds_check_mode == BoundsCheckMode.V2_FATAL:
                 bounds_check_mode = BoundsCheckMode.FATAL
-
-
-
-
+
+        if bounds_check_mode not in (
+            BoundsCheckMode.IGNORE,
+            BoundsCheckMode.WARNING,
+            BoundsCheckMode.FATAL,
+            BoundsCheckMode.NONE,
+        ):
+            raise NotImplementedError(
+                f"SplitTableBatchedEmbeddingBagsCodegen bounds_check_mode={bounds_check_mode} is not supported"
+            )
+
+        self.bounds_check_mode_int = bounds_check_mode.value
+
+        self.log(
+            f"SplitTableBatchedEmbeddingBagsCodegen bounds_check_mode={bounds_check_mode} bounds_check_version={self.bounds_check_version}"
+        )

         self.weights_precision = weights_precision

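Per the constructor changes above, the bounds-check mode can be overridden through the `FBGEMM_TBE_BOUNDS_CHECK_MODE` environment variable and must resolve to one of IGNORE, WARNING, FATAL, or NONE, otherwise a `NotImplementedError` is raised. A minimal sketch of the override, assuming `BoundsCheckMode` is importable as in the imports hunk above:

```python
import os

from fbgemm_gpu.split_table_batched_embeddings_ops_common import BoundsCheckMode

# Must be set before the TBE module is constructed; the constructor reads it via
# os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value).
os.environ["FBGEMM_TBE_BOUNDS_CHECK_MODE"] = str(BoundsCheckMode.WARNING.value)
```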
@@ -715,6 +781,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             # See:
             # https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/
             cache_precision = SparseType.FP32
+            self.log("Override cache_precision=SparseType.FP32 on ROCm")
         else:
             # NOTE: The changes from D65865527 are retained here until we can
             # test that the the hack also works for non-ROCm environments.
@@ -757,9 +824,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         ), "Unique cache miss counters are not accurate in multipass prefetch and therefore not supported"

         self.embedding_specs = embedding_specs
-
+        rows, dims, locations, compute_devices = zip(*embedding_specs)
         T_ = len(self.embedding_specs)
-        self.dims:
+        self.dims: list[int] = dims
         assert T_ > 0
         # mixed D is not supported by no bag kernels
         mixed_D = False
@@ -772,7 +839,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 assert (
                     self.pooling_mode != PoolingMode.NONE
                 ), "Mixed dimension tables only supported for pooling tables."
-
+        self.mixed_D: bool = mixed_D
         assert all(
             cd == compute_devices[0] for cd in compute_devices
         ), "Heterogenous compute_devices are NOT supported!"
@@ -836,7 +903,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self.stats_reporter: Optional[TBEStatsReporter] = (
             stats_reporter_config.create_reporter() if stats_reporter_config else None
         )
-        self._uvm_tensors_log:
+        self._uvm_tensors_log: list[str] = []

         self.bwd_wait_prefetch_timer: Optional[AsyncSeriesTimer] = None
         self.prefetch_duration_timer: Optional[AsyncSeriesTimer] = None
@@ -863,10 +930,20 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

         self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET

-        self.feature_table_map:
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )

+        if embedding_shard_info:
+            full_table_heights, full_table_dims, row_offset, col_offset = zip(
+                *embedding_shard_info
+            )
+        else:
+            # Just assume the table is unsharded
+            full_table_heights = rows
+            full_table_dims = dims
+            row_offset = [0] * len(rows)
+            col_offset = [0] * len(rows)
         self.tbe_input_multiplexer: Optional[TBEInputMultiplexer] = (
             tbe_input_multiplexer_config.create_tbe_input_multiplexer(
                 tbe_info=TBEInfo(
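The block above consumes the new `embedding_shard_info` argument; per the docstring hunk earlier, each entry is `(preshard_table_height, preshard_table_dim, height_offset, dim_offset)`. A data-only sketch with hypothetical sizes:

```python
# Hypothetical shard info for two tables; tuple layout follows the docstring added
# in this diff: (preshard_table_height, preshard_table_dim, height_offset, dim_offset).
embedding_shard_info = [
    (1_000_000, 128, 250_000, 0),  # this shard starts at row 250k of a 1M x 128 table
    (500_000, 64, 0, 0),           # unsharded second table
]
# When embedding_shard_info is None, the constructor treats each table as unsharded:
# full heights/dims equal the local rows/dims and both offsets are 0.
```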
@@ -878,6 +955,11 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     table_heights=rows,
                     tbe_uuid=self.uuid,
                     feature_table_map=self.feature_table_map,
+                    table_dims=dims,
+                    full_table_heights=full_table_heights,
+                    full_table_dims=full_table_dims,
+                    row_offset=row_offset,
+                    col_offset=col_offset,
                 )
             )
             if tbe_input_multiplexer_config is not None
@@ -888,7 +970,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         table_has_feature = [False] * T_
         for t in self.feature_table_map:
             table_has_feature[t] = True
-        assert all(table_has_feature),
+        assert all(table_has_feature), (
+            "Each table must have at least one feature!"
+            + f"{[(i, x) for i, x in enumerate(table_has_feature)]}"
+        )

         feature_dims = [dims[t] for t in self.feature_table_map]
         D_offsets = [0] + list(accumulate(feature_dims))
@@ -940,6 +1025,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             "feature_dims",
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
+        _info_B_num_bits, _info_B_mask = torch.ops.fbgemm.get_infos_metadata(
+            self.D_offsets,  # unused tensor
+            1,  # max_B
+            T,  # T
+        )
+        self.info_B_num_bits: int = _info_B_num_bits
+        self.info_B_mask: int = _info_B_mask

         # A flag for indicating whether all embedding tables are placed in the
         # same locations
@@ -1047,13 +1139,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

             if ensemble_mode is None:
                 ensemble_mode = EnsembleModeDefinition()
-            self._ensemble_mode:
+            self._ensemble_mode: dict[str, float] = {
                 key: float(fval) for key, fval in ensemble_mode.__dict__.items()
             }

             if emainplace_mode is None:
                 emainplace_mode = EmainplaceModeDefinition()
-            self._emainplace_mode:
+            self._emainplace_mode: dict[str, float] = {
                 key: float(fval) for key, fval in emainplace_mode.__dict__.items()
             }

@@ -1085,26 +1177,37 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             # and CowClipDefinition are not used
             counter_halflife = -1

-        # TO DO: Enable this on the new interface
-        # learning_rate_tensor = torch.tensor(
-        #     learning_rate, device=torch.device("cpu"), dtype=torch.float
-        # )
         if extra_optimizer_config is None:
             extra_optimizer_config = UserEnabledConfigDefinition()
         self.use_rowwise_bias_correction: bool = (
             extra_optimizer_config.use_rowwise_bias_correction
         )
+        self.use_writeback_bwd_prehook: bool = (
+            extra_optimizer_config.use_writeback_bwd_prehook
+        )
+
+        writeback_first_feature_only: bool = (
+            extra_optimizer_config.writeback_first_feature_only
+        )
+        self.log(f"self.extra_optimizer_config is {extra_optimizer_config}")
         if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM:
             raise AssertionError(
                 "`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
             )
+        if self.use_writeback_bwd_prehook and not self.optimizer == OptimType.EXACT_SGD:
+            raise AssertionError(
+                "`use_writeback_bwd_prehook` is only supported for OptimType.EXACT_SGD",
+            )
+
+        self.learning_rate_tensor: torch.Tensor = torch.tensor(
+            learning_rate, device=torch.device("cpu"), dtype=torch.float32
+        )

         self.optimizer_args = invokers.lookup_args.OptimizerArgs(
             stochastic_rounding=stochastic_rounding,
             gradient_clipping=gradient_clipping,
             max_gradient=max_gradient,
             max_norm=max_norm,
-            learning_rate=learning_rate,
             eps=eps,
             beta1=beta1,
             beta2=beta2,
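Note that the scalar `learning_rate=` field is dropped from `OptimizerArgs` above; the constructor now keeps the learning rate as a CPU float32 tensor (`self.learning_rate_tensor`), which a later hunk threads through to the lookup invokers via `learning_rate_tensor=`. A one-line sketch of the same construction:

```python
import torch

learning_rate = 0.01
# Same construction as the new self.learning_rate_tensor attribute in the hunk above.
learning_rate_tensor = torch.tensor(
    learning_rate, device=torch.device("cpu"), dtype=torch.float32
)
```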
@@ -1351,7 +1454,11 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

         self.step = 0
         self.last_reported_step = 0
-        self.last_reported_uvm_stats:
+        self.last_reported_uvm_stats: list[float] = []
+        # Track number of times detailed memory breakdown has been reported
+        self.detailed_mem_breakdown_report_count = 0
+        # Set max number of reports for detailed memory breakdown
+        self.max_detailed_mem_breakdown_reports = 10

         # Check whether to use TBE v2
         is_experimental = False
@@ -1370,18 +1477,25 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         #     self.log("TBE_V2 Knob is set to True; Using experimental TBE")

         self.is_experimental: bool = is_experimental
+        self._writeback_first_feature_only: bool = writeback_first_feature_only

         # Get a debug function pointer
         self._debug_print_input_stats: Callable[..., None] = (
             self._debug_print_input_stats_factory()
         )

-        #
-        self.
-
+        # Get a reporter function pointer
+        self._report_input_params: Callable[..., None] = (
+            self.__report_input_params_factory()
         )
-
-
+
+        if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
+            # Register writeback hook for Exact_SGD optimizer
+            self.log(
+                f"SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled with first feature only={self._writeback_first_feature_only}"
+            )
+            # pyre-fixme[6]: Expected `typing.Callable[[Module, Union[Tensor, typing.Tuple[Tensor, ...]]], Union[None, Tensor, typing.Tuple[Tensor, ...]]]`
+            self.register_full_backward_pre_hook(self.writeback_hook)

         if embedding_table_index_type not in [torch.int32, torch.int64]:
             raise ValueError(
@@ -1394,6 +1508,30 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             )
         self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

+        self.prefetched_info_list: list[PrefetchedInfo] = torch.jit.annotate(
+            list[PrefetchedInfo], []
+        )
+        if self.enable_raw_embedding_streaming:
+            self.res_params: RESParams = res_params or RESParams()
+            self.res_params.table_sizes = [0] + list(accumulate(rows))
+            res_port_from_env = os.getenv("LOCAL_RES_PORT")
+            self.res_params.res_server_port = (
+                int(res_port_from_env) if res_port_from_env else 0
+            )
+            # pyre-fixme[4]: Attribute must be annotated.
+            self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+                self.uuid,
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
+            )
+            logging.info(
+                f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
+            )
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
@@ -1437,7 +1575,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         )

     @staticmethod
-    def get_table_name_for_logging(table_names: Optional[
+    def get_table_name_for_logging(table_names: Optional[list[str]]) -> str:
         """
         Given a list of all table names in the TBE, generate a string to
         represent them in logging. If there is more than one table, this method
@@ -1453,17 +1591,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             return "<Unknown>"
         # Do this because sometimes multiple shards of the same table could appear
         # in one TBE.
-        table_name_set = set(table_names)
+        table_name_set = sorted(set(table_names))
         if len(table_name_set) == 1:
             return next(iter(table_name_set))
-        return f"<{len(table_name_set)} tables
+        return f"<{len(table_name_set)} tables>: {table_name_set}"

     @staticmethod
     def get_prefetch_passes(
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig],
         input_tensor: Tensor,
         output_tensor: Tensor,
-    ) ->
+    ) -> list[tuple[Tensor, Tensor, int]]:
         """
         Given inputs (the indices to forward), partition the input and output
         into smaller chunks and return them as a list of tuples
@@ -1511,7 +1649,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             )
         )

-    def get_states(self, prefix: str) ->
+    def get_states(self, prefix: str) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
         """
         Get a state of a given tensor (`prefix`)

@@ -1550,7 +1688,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             torch.tensor(offsets, dtype=torch.int64),
         )

-    def get_all_states(self) ->
+    def get_all_states(self) -> list[tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
         """
         Get all states in the TBE (`weights`, `momentum1`, `momentum2`,
         `prev_iter`, and `row_counter`)
@@ -1614,10 +1752,161 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             tbe_id=self.uuid,
         )

-
-
+    def _get_tensor_memory(self, tensor_name: str) -> int:
+        """Get memory usage of a tensor in bytes."""
+        if not hasattr(self, tensor_name):
+            self.log(f"Tensor '{tensor_name}' not found, using 0 bytes")
+            return 0
+        tensor = getattr(self, tensor_name)
+        return tensor.numel() * tensor.element_size()
+
+    def _categorize_memory_by_location(
+        self, tensor_names: list[str]
+    ) -> tuple[int, int]:
+        """Categorize memory into HBM and UVM for given tensors.
+
+        Returns:
+            (hbm_bytes, uvm_bytes)
+        """
+        uvm_set = set(self._uvm_tensors_log)
+        hbm_bytes = 0
+        uvm_bytes = 0
+
+        for name in tensor_names:
+            size = self._get_tensor_memory(name)
+            if name in uvm_set:
+                uvm_bytes += size
+            else:
+                hbm_bytes += size
+
+        return hbm_bytes, uvm_bytes
+
+    def _report_hbm_breakdown(
         self,
+        stats_reporter: TBEStatsReporter,
+        embeddings: int,
+        optimizer_states: int,
+        cache: int,
+        total_static_sparse: int,
+        ephemeral: int,
+        cache_weights: int = 0,
+        cache_aux: int = 0,
     ) -> None:
+        """Report HBM memory breakdown to stats reporter."""
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.embeddings",
+            data_bytes=embeddings,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.optimizer_states",
+            data_bytes=optimizer_states,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache",
+            data_bytes=cache,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache_weights",
+            data_bytes=cache_weights,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache_aux",
+            data_bytes=cache_aux,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.total_static_sparse",
+            data_bytes=total_static_sparse,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.ephemeral",
+            data_bytes=ephemeral,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+
+    def _report_uvm_breakdown(
+        self,
+        stats_reporter: TBEStatsReporter,
+        embeddings: int,
+        optimizer_states: int,
+        cache: int,
+        total_static_sparse: int,
+        ephemeral: int,
+        cache_weights: int = 0,
+        cache_aux: int = 0,
+    ) -> None:
+        """Report UVM memory breakdown to stats reporter."""
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.embeddings",
+            data_bytes=embeddings,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.optimizer_states",
+            data_bytes=optimizer_states,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache",
+            data_bytes=cache,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache_weights",
+            data_bytes=cache_weights,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache_aux",
+            data_bytes=cache_aux,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.total_static_sparse",
+            data_bytes=total_static_sparse,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.ephemeral",
+            data_bytes=ephemeral,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+
+    @torch.jit.ignore
+    def _report_tbe_mem_usage(self) -> None:
         if self.stats_reporter is None:
             return

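The helpers added above account for a tensor's footprint as `numel() * element_size()` and split tensors between HBM and UVM using the `_uvm_tensors_log` list. A standalone sketch of the same arithmetic (illustrative only, not fbgemm-specific):

```python
import torch

def tensor_bytes(t: torch.Tensor) -> int:
    # Same accounting as _get_tensor_memory: element count times bytes per element.
    return t.numel() * t.element_size()

weights = torch.zeros(1_000_000, 128, dtype=torch.float16)
print(tensor_bytes(weights))  # 1_000_000 * 128 * 2 = 256,000,000 bytes
```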
@@ -1625,22 +1914,24 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
        if not stats_reporter.should_report(self.step):
            return

+        # Calculate total memory from all parameters and buffers (always needed)
         total_mem_usage = sum(
-
-        ) + sum(
+            p.numel() * p.element_size() for p in self.parameters()
+        ) + sum(b.numel() * b.element_size() for b in self.buffers())
+
+        # Calculate total HBM and UVM usage (always needed)
         if self.use_cpu:
             total_hbm_usage = 0
             total_uvm_usage = total_mem_usage
         else:
-            # hbm usage is total usage minus uvm usage
             total_uvm_usage = sum(
-
-
-
-                if hasattr(self, tensor_name)
+                self._get_tensor_memory(name)
+                for name in self._uvm_tensors_log
+                if hasattr(self, name)
             )
             total_hbm_usage = total_mem_usage - total_uvm_usage

+        # Report total memory usage metrics (always reported for backward compatibility)
         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="tbe.total_hbm_usage",
@@ -1656,6 +1947,96 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             tbe_id=self.uuid,
         )

+        # Only report detailed breakdown for the first max_detailed_mem_breakdown_reports reportable
+        # steps since static sparse memory (weights, optimizer states, cache) is constant
+        if (
+            self.detailed_mem_breakdown_report_count
+            >= self.max_detailed_mem_breakdown_reports
+        ):
+            return
+        self.detailed_mem_breakdown_report_count += 1
+
+        # Tensor groups for sparse memory categorization
+        weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]
+        optimizer_tensors = [
+            "momentum1_dev",
+            "momentum1_host",
+            "momentum1_uvm",
+            "momentum2_dev",
+            "momentum2_host",
+            "momentum2_uvm",
+        ]
+        # Cache weights tensor (the actual cached embeddings in HBM)
+        cache_weight_tensors = [
+            "lxu_cache_weights",
+        ]
+        # Cache auxiliary state tensors (metadata for cache management, excluding weights)
+        # Sizes scale with hash_size or cache_slots (hash_size × clf)
+        # Excludes constant-size tensors: cache_hash_size_cumsum, cache_miss_counter, etc.
+        cache_aux_tensors = [
+            "cache_index_table_map",  # int32, 4B × hash_size
+            "lxu_cache_state",  # int64, 8B × cache_slots
+            "lxu_state",  # int64, 8B × cache_slots (LRU) or hash_size (LFU)
+            "lxu_cache_locking_counter",  # int32, 4B × cache_slots (only if prefetch_pipeline)
+        ]
+
+        # Calculate total memory for each component
+        weights_total = sum(self._get_tensor_memory(t) for t in weight_tensors)
+        optimizer_total = sum(self._get_tensor_memory(t) for t in optimizer_tensors)
+        cache_weights_total = sum(
+            self._get_tensor_memory(t) for t in cache_weight_tensors
+        )
+        cache_aux_total = sum(self._get_tensor_memory(t) for t in cache_aux_tensors)
+
+        # Categorize memory by location (HBM vs UVM)
+        if self.use_cpu:
+            weights_hbm, weights_uvm = 0, weights_total
+            opt_hbm, opt_uvm = 0, optimizer_total
+            cache_weights_hbm, cache_weights_uvm = 0, cache_weights_total
+            cache_aux_hbm, cache_aux_uvm = 0, cache_aux_total
+        else:
+            weights_hbm, weights_uvm = self._categorize_memory_by_location(
+                weight_tensors
+            )
+            opt_hbm, opt_uvm = self._categorize_memory_by_location(optimizer_tensors)
+            cache_weights_hbm, cache_weights_uvm = self._categorize_memory_by_location(
+                cache_weight_tensors
+            )
+            cache_aux_hbm, cache_aux_uvm = self._categorize_memory_by_location(
+                cache_aux_tensors
+            )
+
+        # Calculate ephemeral memory split between HBM and UVM
+        # Total cache = cache weights + cache auxiliary state
+        cache_hbm = cache_weights_hbm + cache_aux_hbm
+        cache_uvm = cache_weights_uvm + cache_aux_uvm
+        static_sparse_hbm = weights_hbm + opt_hbm + cache_hbm
+        static_sparse_uvm = weights_uvm + opt_uvm + cache_uvm
+        ephemeral_hbm = total_hbm_usage - static_sparse_hbm
+        ephemeral_uvm = total_uvm_usage - static_sparse_uvm
+
+        # Report granular memory breakdowns
+        self._report_hbm_breakdown(
+            stats_reporter,
+            weights_hbm,
+            opt_hbm,
+            cache_hbm,
+            static_sparse_hbm,
+            ephemeral_hbm,
+            cache_weights_hbm,
+            cache_aux_hbm,
+        )
+        self._report_uvm_breakdown(
+            stats_reporter,
+            weights_uvm,
+            opt_uvm,
+            cache_uvm,
+            static_sparse_uvm,
+            ephemeral_uvm,
+            cache_weights_uvm,
+            cache_aux_uvm,
+        )
+
     @torch.jit.ignore
     def _report_io_size_count(self, event: str, data: Tensor) -> Tensor:
         if self.stats_reporter is None:
@@ -1682,7 +2063,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     def _generate_vbe_metadata(
         self,
         offsets: Tensor,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]],
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
@@ -1705,6 +2088,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.pooling_mode,
             self.feature_dims,
             self.current_device,
+            vbe_output,
+            vbe_output_offsets,
         )

     @torch.jit.ignore
@@ -1713,14 +2098,30 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         # This allows models using this class to compile correctly
         return FeatureGate.is_enabled(feature)

+    # pyre-fixme[2]: For 1st argument expected not ANY
+    def writeback_hook(self, module: Any, grad: Tensor) -> tuple[Tensor]:
+        indices = self._indices
+        offsets = self._offsets
+        return writeback_gradient(
+            grad,
+            indices,
+            offsets,
+            self.feature_table_map,
+            self._writeback_first_feature_only,
+        )
+
     def forward(  # noqa: C901
         self,
         indices: Tensor,
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
         total_unique_indices: Optional[int] = None,
+        hash_zch_identities: Optional[Tensor] = None,
+        hash_zch_runtime_meta: Optional[Tensor] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> Tensor:
         """
         The forward pass function that
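An earlier hunk registers `writeback_hook` with `register_full_backward_pre_hook` when `use_writeback_bwd_prehook` is enabled for `EXACT_SGD`; the hook above then delegates to the new `writeback_gradient` utility. As background, a full backward pre-hook receives the module's output gradients and may return replacements. A toy, non-fbgemm sketch of that PyTorch mechanism:

```python
import torch
from torch import nn

layer = nn.Linear(4, 4)

def scale_grad_hook(module: nn.Module, grad_output: tuple[torch.Tensor, ...]):
    # Returning a tuple here replaces the gradients flowing into the module's backward.
    return tuple(g * 0.5 for g in grad_output)

layer.register_full_backward_pre_hook(scale_grad_hook)
out = layer(torch.randn(2, 4))
out.sum().backward()
```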
@@ -1773,7 +2174,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 be set when using `OptimType.NONE`. This is because TBE
                 requires this information for allocating the weight gradient
                 tensor in the backward pass.
-
+            hash_zch_identities (Optional[Tensor]): The original raw IDs before
+                remapping to ZCH (Zero-Collision Hash) table slots. This tensor is
+                populated when using Multi-Probe Zero Collision Hash (MPZCH) modules
+                and is required for Raw Embedding Streaming (RES) to maintain
+                consistency between training and inference.
+            vbe_output (Optional[Tensor]): An optional 2-D tensor of size that
+                contains output for TBE VBE. The shape of the tensor is
+                [1, total_vbe_output_size] where total_vbe_output_size is the
+                output size across all ranks and all embedding tables.
+                If this tensor is not None, the TBE VBE forward output is written
+                to this tensor at the locations specified by `vbe_output_offsets`.
+            vbe_output_offsets (Optional[Tensor]): An optional 2-D tensor that
+                contains VBE output offsets to `vbe_output`. The shape of the
+                tensor is [num_ranks, num_features].
+                vbe_output_offsets[r][f] represents the starting offset for rank `r`
+                and feature `f`.
         Returns:
             A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` =
             batch size and `total_D` = the sum of all embedding dimensions in the
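The docstring above describes the new preallocated VBE output path. A shape-only sketch with hypothetical sizes (the actual total size and offsets come from the caller's sharding plan):

```python
import torch

num_ranks, num_features = 2, 3
total_vbe_output_size = 4096  # hypothetical: sum of per-rank, per-feature output sizes

vbe_output = torch.zeros(1, total_vbe_output_size)
# vbe_output_offsets[r][f] = starting offset of rank r / feature f inside vbe_output.
vbe_output_offsets = torch.zeros(num_ranks, num_features, dtype=torch.int64)

# These would be passed as:
#   tbe.forward(indices, offsets, ..., vbe_output=vbe_output, vbe_output_offsets=vbe_output_offsets)
```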
@@ -1846,11 +2262,35 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             per_sample_weights,
             batch_size_per_feature_per_rank,
             force_cast_input_types=True,
+            prefetch_pipeline=False,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )

+        # Only enable VBE if batch_size_per_feature_per_rank is not None
+        assert not (
+            batch_size_per_feature_per_rank is not None
+            and self.use_writeback_bwd_prehook
+        ), "VBE is not supported with writeback_bwd_prehook"
+
         # Print input stats if enable (for debugging purpose only)
         self._debug_print_input_stats(indices, offsets, per_sample_weights)

+        # Extract and Write input stats if enable
+        if self._report_input_params is not None:
+            self._report_input_params(
+                feature_rows=self.rows_per_table,
+                feature_dims=self.feature_dims,
+                iteration=self.iter_cpu.item() if hasattr(self, "iter_cpu") else 0,
+                indices=indices,
+                offsets=offsets,
+                op_id=self.uuid,
+                per_sample_weights=per_sample_weights,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                embedding_specs=[(s[0], s[1]) for s in self.embedding_specs],
+                feature_table_map=self.feature_table_map,
+            )
+
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time

@@ -1878,7 +2318,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             # to be as fast as possible and memory usage doesn't matter (will be recycled
             # by dense fwd/bwd)
             self._prefetch(
-                indices,
+                indices,
+                offsets,
+                vbe_metadata,
+                multipass_prefetch_config=None,
+                hash_zch_identities=hash_zch_identities,
+                hash_zch_runtime_meta=hash_zch_runtime_meta,
             )

         if len(self.timesteps_prefetched) > 0:
@@ -1936,6 +2381,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             is_experimental=self.is_experimental,
             use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
             use_homogeneous_placements=self.use_homogeneous_placements,
+            learning_rate_tensor=self.learning_rate_tensor,
+            info_B_num_bits=self.info_B_num_bits,
+            info_B_mask=self.info_B_mask,
         )

         if self.optimizer == OptimType.NONE:
@@ -2048,7 +2496,6 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     momentum1,
                     momentum2,
                     iter_int,
-                    self.use_rowwise_bias_correction,
                     row_counter=(
                         row_counter if self.use_rowwise_bias_correction else None
                     ),
@@ -2158,6 +2605,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     row_counter,
                     iter_int,
                     self.max_counter.item(),
+                    mixed_D=self.mixed_D,
                 ),
             )
         elif self._used_rowwise_adagrad_with_global_weight_decay:
@@ -2176,6 +2624,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     # `Optional[Tensor]` but got `Union[Module, Tensor]`.
                     prev_iter_dev=self.prev_iter_dev,
                     gwd_lower_bound=self.gwd_lower_bound,
+                    mixed_D=self.mixed_D,
                 ),
             )
         else:
@@ -2185,12 +2634,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     common_args,
                     self.optimizer_args,
                     momentum1,
+                    mixed_D=self.mixed_D,
                 ),
             )

         raise ValueError(f"Invalid OptimType: {self.optimizer}")

-    def ema_inplace(self, emainplace_mode:
+    def ema_inplace(self, emainplace_mode: dict[str, float]) -> None:
         """
         Perform ema operations on the full sparse embedding tables.
         We organize the sparse table, in the following way.
@@ -2220,7 +2670,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             emainplace_mode["step_ema_coef"],
         )

-    def ensemble_and_swap(self, ensemble_mode:
+    def ensemble_and_swap(self, ensemble_mode: dict[str, float]) -> None:
         """
         Perform ensemble and swap operations on the full sparse embedding tables.

@@ -2268,7 +2718,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
         return self.local_uvm_cache_stats if use_local_cache else self.uvm_cache_stats

-    def _get_uvm_cache_print_state(self, use_local_cache: bool = False) ->
+    def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> list[float]:
         snapshot = self.get_uvm_cache_stats(use_local_cache)
         if use_local_cache:
             return snapshot.tolist()
@@ -2281,7 +2731,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.ignore
     def print_uvm_cache_stats(self, use_local_cache: bool = False) -> None:
         # TODO: Create a separate reporter class to unify the stdlog reporting
-        uvm_cache_stats:
+        uvm_cache_stats: list[float] = self._get_uvm_cache_print_state(use_local_cache)
         N = max(1, uvm_cache_stats[0])
         m = {
             "N_called": uvm_cache_stats[UVMCacheStatsIndex.num_calls],
@@ -2325,14 +2775,14 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if not stats_reporter.should_report(self.step):
             return

-        uvm_cache_stats:
+        uvm_cache_stats: list[float] = self.get_uvm_cache_stats(
             use_local_cache=False
         ).tolist()
         self.last_reported_step = self.step

         if len(self.last_reported_uvm_stats) == 0:
             self.last_reported_uvm_stats = [0.0] * len(uvm_cache_stats)
-        uvm_cache_stats_delta:
+        uvm_cache_stats_delta: list[float] = [0.0] * len(uvm_cache_stats)
         for i in range(len(uvm_cache_stats)):
             uvm_cache_stats_delta[i] = (
                 uvm_cache_stats[i] - self.last_reported_uvm_stats[i]
@@ -2361,7 +2811,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         indices: Tensor,
         offsets: Tensor,
         forward_stream: Optional[torch.cuda.Stream] = None,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
     ) -> None:
         if self.prefetch_stream is None and forward_stream is not None:
             self.prefetch_stream = torch.cuda.current_stream()
@@ -2369,19 +2819,21 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.prefetch_stream != forward_stream
         ), "prefetch_stream and forward_stream should not be the same stream"

-        indices, offsets, _, vbe_metadata = self.prepare_inputs(
-            indices,
-            offsets,
-            per_sample_weights=None,
-            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
-            force_cast_input_types=False,
-        )
-
         with self._recording_to_timer(
             self.prefetch_duration_timer,
             context=self.step,
             stream=torch.cuda.current_stream(),
         ):
+
+            indices, offsets, _, vbe_metadata = self.prepare_inputs(
+                indices,
+                offsets,
+                per_sample_weights=None,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                force_cast_input_types=False,
+                prefetch_pipeline=self.prefetch_pipeline,
+            )
+
             self._prefetch(
                 indices,
                 offsets,
@@ -2398,6 +2850,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         offsets: Tensor,
         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
+        hash_zch_identities: Optional[Tensor] = None,
+        hash_zch_runtime_meta: Optional[Tensor] = None,
     ) -> None:
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
@@ -2416,7 +2870,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.local_uvm_cache_stats.zero_()
         self._report_io_size_count("prefetch_input", indices)
 
+        # streaming before updating the cache
+        self.raw_embedding_stream()
+
         final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
+        linear_cache_indices_merged = torch.zeros(
+            0, dtype=indices.dtype, device=indices.device
+        )
         for (
             partial_indices,
             partial_lxu_cache_locations,
@@ -2432,6 +2892,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 vbe_metadata.max_B if vbe_metadata is not None else -1,
                 base_offset,
             )
+            linear_cache_indices_merged = torch.cat(
+                [linear_cache_indices_merged, linear_cache_indices]
+            )
 
             if (
                 self.record_cache_metrics.record_cache_miss_counter
@@ -2512,6 +2975,16 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if self.should_log():
             self.print_uvm_cache_stats(use_local_cache=False)
 
+        self._store_prefetched_tensors(
+            indices,
+            offsets,
+            vbe_metadata,
+            linear_cache_indices_merged,
+            final_lxu_cache_locations,
+            hash_zch_identities,
+            hash_zch_runtime_meta,
+        )
+
     def should_log(self) -> bool:
         """Determines if we should log for this step, using exponentially decreasing frequency.
 
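
The prefetch path above now stores a PrefetchedInfo entry per prefetch call, and the raw_embedding_stream method added later in this diff consumes those entries with a one- or two-iteration lag, depending on whether the prefetch pipeline overlaps prefetch with the previous backward pass. A minimal stand-alone sketch of that consume-with-lag pattern (maybe_stream and the deque are hypothetical names, not TBE internals):

    from collections import deque

    prefetched_info_list = deque()

    def maybe_stream(prefetch_pipeline: bool) -> None:
        # With pipelining, iteration i's prefetch runs before iteration i-1's
        # sparse backward, so only entries from iteration i-2 are safe to stream.
        target_prev_iter = 2 if prefetch_pipeline else 1
        if len(prefetched_info_list) <= target_prev_iter - 1:
            return  # not enough buffered iterations yet
        info = prefetched_info_list.popleft()
        # hand `info` to the streaming backend here
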
@@ -2596,12 +3069,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 tmp_emb.uniform_(min_val, max_val)
                 tmp_emb_i8 = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(tmp_emb)
                 emb.data.copy_(tmp_emb_i8)
+        # Torch doesnt implement direct fp8 distribution functions, so we need to start in higher precision.
+        elif self.weights_precision == SparseType.NFP8:
+            assert (
+                self.current_device.type == "cuda"
+            ), "NFP8 is currently only supportd on GPU."
+            assert self.optimizer in [
+                OptimType.EXACT_ADAGRAD,
+                OptimType.ROWWISE_ADAGRAD,
+                OptimType.EXACT_ROWWISE_ADAGRAD,
+                OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+                OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
+            ], "NFP8 is currently only supportd with adagrad optimizers."
+            for param in splits:
+                tmp_param = torch.zeros(param.shape, device=self.current_device)
+                # Create initialized weights and cast to fp8.
+                fp8_dtype = (
+                    torch.float8_e4m3fnuz
+                    if torch.version.hip is not None
+                    else torch.float8_e4m3fn
+                )
+                tmp_param.uniform_(min_val, max_val).to(fp8_dtype)
+                param.data.copy_(tmp_param)
         else:
             for param in splits:
                 param.uniform_(min_val, max_val)
 
     @torch.jit.ignore
-    def split_embedding_weights(self) -> List[Tensor]:
+    def split_embedding_weights(self) -> list[Tensor]:
         """
         Returns a list of embedding weights (view), split by table
 
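
PyTorch does not expose in-place random fills for its float8 dtypes, which is why the NFP8 branch above initializes an fp32 buffer and copies it into the fp8 storage. A minimal sketch of the same idea (assumes a GPU build of PyTorch with float8 support; shape and bounds are illustrative):

    import torch

    fp8_dtype = (
        torch.float8_e4m3fnuz if torch.version.hip is not None else torch.float8_e4m3fn
    )
    tmp = torch.zeros(16, 8, device="cuda")  # start in fp32, which supports uniform_()
    tmp.uniform_(-0.05, 0.05)
    weights_fp8 = tmp.to(fp8_dtype)          # cast down to the fp8 storage dtype
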
@@ -2643,7 +3138,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         raise ValueError(f"Optimizer buffer {state} not found")
 
     @torch.jit.export
-    def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
+    def get_optimizer_state(self) -> list[dict[str, torch.Tensor]]:
         r"""
         Get the optimizer state dict that matches the OSS Pytorch optims
         TODO: populate the supported list of optimizers
@@ -2656,7 +3151,23 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         ):
             list_of_state_dict = [
                 (
-
+                    (
+                        {
+                            "sum": states[0],
+                            "prev_iter": states[1],
+                            "row_counter": states[2],
+                            "iter": self.iter,
+                        }
+                        if self.optimizer_args.regularization_mode
+                        == WeightDecayMode.COUNTER.value
+                        and self.optimizer_args.weight_decay_mode
+                        == CounterWeightDecayMode.ADAGRADW.value
+                        else {
+                            "sum": states[0],
+                            "prev_iter": states[1],
+                            "row_counter": states[2],
+                        }
+                    )
                     if self._used_rowwise_adagrad_with_counter
                     else (
                         {
@@ -2711,7 +3222,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.ignore
     def split_optimizer_states(
         self,
-    ) -> List[List[torch.Tensor]]:
+    ) -> list[list[torch.Tensor]]:
         """
         Returns a list of optimizer states (view), split by table
 
@@ -2759,7 +3270,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             state_offsets: Tensor,
             state_placements: Tensor,
             rowwise: bool,
-        ) -> List[torch.Tensor]:
+        ) -> list[torch.Tensor]:
             splits = []
             for t, (rows, dim, _, _) in enumerate(self.embedding_specs):
                 offset = state_offsets[t]
@@ -2778,7 +3289,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 splits.append(state.detach()[offset : offset + rows].view(rows))
             return splits
 
-        states: List[List[torch.Tensor]] = []
+        states: list[list[torch.Tensor]] = []
         if self.optimizer not in (OptimType.EXACT_SGD,):
             states.append(
                 get_optimizer_states(
@@ -2899,15 +3410,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
     def get_learning_rate(self) -> float:
         """
-
-
-        Args:
-            lr (float): The learning rate value to set to
+        Get and return the learning rate.
         """
-        return self.
+        return self.learning_rate_tensor.item()
 
     @torch.jit.ignore
-    def update_hyper_parameters(self, params_dict: Dict[str, float]) -> None:
+    def update_hyper_parameters(self, params_dict: dict[str, float]) -> None:
         """
         Sets hyper-parameters from external control flow.
 
@@ -2943,7 +3451,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         Helper function to script `set_learning_rate`.
         Note that returning None does not work.
         """
-        self.
+        self.learning_rate_tensor.fill_(lr)
         return 0.0
 
     @torch.jit.ignore
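
Both hunks above route the learning rate through a learning_rate_tensor so a TorchScript-friendly helper can read and update it in place. A minimal stand-alone sketch of the pattern (the free functions here are hypothetical, not the TBE API):

    import torch

    learning_rate_tensor = torch.tensor(0.01, dtype=torch.float32)

    def get_learning_rate() -> float:
        return learning_rate_tensor.item()

    def set_learning_rate(lr: float) -> None:
        learning_rate_tensor.fill_(lr)  # in-place update, no attribute rebinding

    set_learning_rate(0.005)
    assert abs(get_learning_rate() - 0.005) < 1e-6
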
@@ -2983,10 +3491,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self,
         split: SplitState,
         prefix: str,
-        dtype: Type[torch.dtype],
+        dtype: type[torch.dtype],
         enforce_hbm: bool = False,
         make_dev_param: bool = False,
-        dev_reshape: Optional[Tuple[int, ...]] = None,
+        dev_reshape: Optional[tuple[int, ...]] = None,
         uvm_host_mapped: bool = False,
     ) -> None:
         apply_split_helper(
@@ -3036,6 +3544,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             dtype = torch.float32
         elif cache_precision == SparseType.FP16:
             dtype = torch.float16
+        elif cache_precision == SparseType.NFP8:
+            # NFP8 weights use floating point cache.
+            dtype = torch.float16
         else:
             dtype = torch.float32  # not relevant, but setting it to keep linter happy
         if not self.use_cpu > 0:
@@ -3229,7 +3740,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     def _update_cache_counter_and_locations(
         self,
         module: nn.Module,
-        grad_input: Union[Tuple[Tensor, ...], Tensor],
+        grad_input: Union[tuple[Tensor, ...], Tensor],
     ) -> None:
         """
         Backward prehook function when prefetch_pipeline is enabled.
@@ -3425,9 +3936,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         indices: Tensor,
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
-        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
         force_cast_input_types: bool = True,
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
+        prefetch_pipeline: bool = False,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
+    ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
         """
         Prepare TBE inputs as follows:
 
@@ -3453,11 +3967,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             metadata
         """
 
+        if vbe_output is not None or vbe_output_offsets is not None:
+            assert (
+                not self.use_cpu
+            ), "[TBE API v2] Using pre-allocated vbe_output is not supported on CPU"
+            check_allocated_vbe_output(
+                self.output_dtype,
+                batch_size_per_feature_per_rank,
+                vbe_output,
+                vbe_output_offsets,
+            )
+
         # Generate VBE metadata
         vbe_metadata = self._generate_vbe_metadata(
-            offsets, batch_size_per_feature_per_rank
+            offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
         )
 
+        vbe = vbe_metadata.B_offsets is not None
+        # Note this check has already been done in C++ side
+        # TODO: max_B <= self.info_B_mask in python
+        # We cannot use assert as it breaks pt2 compile for dynamic shape
+        # and need to use torch._check for dynamic shape and cannot construct fstring, use constant string.
+        # torch._check(
+        #     max_B <= self.info_B_mask,
+        #     "Not enough infos bits to accommodate T and B.",
+        # )
+        # We cannot use lambda as it fails jit script.
+        # torch._check is also not supported in jitscript
+
         # TODO: remove this and add an assert after updating
         # bounds_check_indices to support different indices type and offset
         # type
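
prepare_inputs now accepts a pre-allocated vbe_output / vbe_output_offsets pair and validates it through check_allocated_vbe_output before VBE metadata is generated. The sketch below only illustrates the kind of dtype/size validation such a check implies; it is a hypothetical stand-in, not the actual check_allocated_vbe_output implementation:

    from typing import Optional

    import torch

    def validate_preallocated_output(
        output_dtype: torch.dtype,
        expected_numel: int,
        vbe_output: Optional[torch.Tensor],
        vbe_output_offsets: Optional[torch.Tensor],
    ) -> None:
        # Both pieces must be supplied together.
        assert (vbe_output is None) == (vbe_output_offsets is None)
        if vbe_output is not None:
            assert vbe_output.dtype == output_dtype, "dtype mismatch with output_dtype"
            assert vbe_output.numel() >= expected_numel, "pre-allocated output too small"
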
@@ -3485,10 +4022,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             per_sample_weights = per_sample_weights.float()
 
         if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
+            # Override the bounds check version based on prefetch_pipeline
+            use_bounds_check_v2 = self.bounds_check_version == 2 or prefetch_pipeline
+            bounds_check_version = (
+                2 if use_bounds_check_v2 else self.bounds_check_version
+            )
+
             vbe = vbe_metadata.B_offsets is not None
+
             # Compute B info and VBE metadata for bounds_check_indices only if
             # VBE and bounds check indices v2 are used
-            if vbe and self.bounds_check_version == 2:
+            if vbe and use_bounds_check_v2:
                 B_offsets = vbe_metadata.B_offsets
                 B_offsets_rank_per_feature = vbe_metadata.B_offsets_rank_per_feature
                 output_offsets_feature_rank = vbe_metadata.output_offsets_feature_rank
@@ -3499,11 +4043,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 assert isinstance(
                     output_offsets_feature_rank, Tensor
                 ), "output_offsets_feature_rank must be tensor"
-                info_B_num_bits, info_B_mask = torch.ops.fbgemm.get_infos_metadata(
-                    B_offsets,  # unused tensor
-                    vbe_metadata.max_B,
-                    B_offsets.numel() - 1,  # T
-                )
+
                 row_output_offsets, b_t_map = torch.ops.fbgemm.generate_vbe_metadata(
                     B_offsets,
                     B_offsets_rank_per_feature,
@@ -3512,13 +4052,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     self.max_D,
                     self.is_nobag,
                     vbe_metadata.max_B_feature_rank,
-                    info_B_num_bits,
-                    offsets.numel() - 1,  # total_B
+                    self.info_B_num_bits,
+                    offsets.numel() - 1,  # total_B,
+                    vbe_output_offsets,
                 )
             else:
                 b_t_map = None
-                info_B_num_bits = -1
-                info_B_mask = -1
 
             torch.ops.fbgemm.bounds_check_indices(
                 self.rows_per_table,
@@ -3530,9 +4069,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 B_offsets=vbe_metadata.B_offsets,
                 max_B=vbe_metadata.max_B,
                 b_t_map=b_t_map,
-                info_B_num_bits=info_B_num_bits,
-                info_B_mask=info_B_mask,
-                bounds_check_version=self.bounds_check_version,
+                info_B_num_bits=self.info_B_num_bits,
+                info_B_mask=self.info_B_mask,
+                bounds_check_version=bounds_check_version,
+                prefetch_pipeline=prefetch_pipeline,
             )
 
         return indices, offsets, per_sample_weights, vbe_metadata
@@ -3603,7 +4143,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             # Counts of indices that segment lengths > 1024
            counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024]
 
-            def compute_numel_and_avg(counts: Tensor) -> Tuple[int, float]:
+            def compute_numel_and_avg(counts: Tensor) -> tuple[int, float]:
                 numel = counts.numel()
                 avg = (counts.sum().item() / numel) if numel != 0 else -1.0
                 return numel, avg
@@ -3671,6 +4211,240 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null
 
+    @torch.jit.ignore
+    def raw_embedding_stream(self) -> None:
+        if not self.enable_raw_embedding_streaming:
+            return None
+        # when pipelining is enabled
+        # prefetch in iter i happens before the backward sparse in iter i - 1
+        # so embeddings for iter i - 1's changed ids are not updated.
+        # so we can only fetch the indices from the iter i - 2
+        # when pipelining is disabled
+        # prefetch in iter i happens before forward iter i
+        # so we can get the iter i - 1's changed ids safely.
+        target_prev_iter = 1
+        if self.prefetch_pipeline:
+            target_prev_iter = 2
+        if not len(self.prefetched_info_list) > (target_prev_iter - 1):
+            return None
+        with record_function(
+            "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+        ):
+            prefetched_info = self.prefetched_info_list.pop(0)
+            updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
+                prefetched_info.linear_unique_cache_indices,
+                self.lxu_cache_state,
+                self.total_cache_hash_size,
+                gather_cache_stats=False,  # not collecting cache stats
+                num_uniq_cache_indices=prefetched_info.linear_unique_indices_length,
+            )
+            updated_weights = torch.empty(
+                [
+                    prefetched_info.linear_unique_cache_indices.size()[0],
+                    self.max_D_cache,
+                ],
+                # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
+                dtype=self.lxu_cache_weights.dtype,
+                # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
+                device=self.lxu_cache_weights.device,
+            )
+            torch.ops.fbgemm.masked_index_select(
+                updated_weights,
+                updated_locations,
+                self.lxu_cache_weights,
+                prefetched_info.linear_unique_indices_length,
+            )
+            # TODO: this statement triggers a sync
+            # added here to make this diff self-contained
+            # will remove in later change
+            cache_hit_mask_index = (
+                updated_locations.narrow(
+                    0, 0, prefetched_info.linear_unique_indices_length.item()
+                )
+                .not_equal(-1)
+                .nonzero()
+                .flatten()
+            )
+            # stream weights
+            self._raw_embedding_streamer.stream(
+                prefetched_info.linear_unique_indices.index_select(
+                    dim=0, index=cache_hit_mask_index
+                ).to(device=torch.device("cpu")),
+                updated_weights.index_select(dim=0, index=cache_hit_mask_index).to(
+                    device=torch.device("cpu")
+                ),
+                (
+                    prefetched_info.hash_zch_identities.index_select(
+                        dim=0, index=cache_hit_mask_index
+                    ).to(device=torch.device("cpu"))
+                    if prefetched_info.hash_zch_identities is not None
+                    else None
+                ),
+                (
+                    prefetched_info.hash_zch_runtime_meta.index_select(
+                        dim=0, index=cache_hit_mask_index
+                    ).to(device=torch.device("cpu"))
+                    if prefetched_info.hash_zch_runtime_meta is not None
+                    else None
+                ),
+                prefetched_info.linear_unique_indices_length.to(
+                    device=torch.device("cpu")
+                ),
+                False,  # require_tensor_copy
+                False,  # blocking_tensor_copy
+            )
+
+    @staticmethod
+    @torch.jit.ignore
+    def _get_prefetched_info(
+        linear_indices: torch.Tensor,
+        linear_cache_indices_merged: torch.Tensor,
+        total_cache_hash_size: int,
+        hash_zch_identities: Optional[torch.Tensor],
+        hash_zch_runtime_meta: Optional[torch.Tensor],
+        max_indices_length: int,
+    ) -> PrefetchedInfo:
+        (
+            linear_unique_cache_indices,
+            linear_unique_cache_indices_length,
+            linear_unique_cache_indices_cnt,
+            linear_unique_cache_inverse_indices,
+        ) = torch.ops.fbgemm.get_unique_indices_with_inverse(
+            linear_cache_indices_merged,
+            total_cache_hash_size,
+            compute_count=True,
+            compute_inverse_indices=True,
+        )
+        # pure cpu op, no need to sync, to avoid the indices out size the weights buffer
+        max_len = min(
+            max_indices_length,
+            linear_unique_cache_indices.size(0),
+        )
+        if max_len < linear_unique_cache_indices.size(0):
+            linear_unique_cache_indices_length.clamp_(max=max_len)
+        # linear_unique_indices is the result after deduplication and sorting
+        linear_unique_cache_indices = linear_unique_cache_indices.narrow(
+            0, 0, max_len
+        )
+        # Compute cumulative sum as indices for selecting unique elements to
+        # map hash_zch_identities and hash_zch_runtime_meta to linear_unique_indices
+        count_cum_sum = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            linear_unique_cache_indices_cnt
+        )
+        # count_cum_sum will be one more element than linear_unique_cache_indices_cnt
+        count_cum_sum = count_cum_sum.narrow(0, 0, max_len)
+        # clamp the uninitialized elements to avoid out of bound access
+        # the uninitialized elements will be sliced out by linear_unique_cache_indices_length
+        # directly using linear_unique_cache_indices_length requires a sync
+        count_cum_sum.clamp_(min=0, max=linear_unique_cache_inverse_indices.size(0) - 1)
+
+        # Select indices corresponding to first occurrence of each unique element
+        linear_unique_inverse_indices = (
+            linear_unique_cache_inverse_indices.index_select(dim=0, index=count_cum_sum)
+        )
+        # same as above clamp
+        linear_unique_inverse_indices.clamp_(min=0, max=linear_indices.size(0) - 1)
+        linear_unique_indices = linear_indices.index_select(
+            dim=0, index=linear_unique_inverse_indices
+        )
+        if hash_zch_identities is not None:
+            # Map hash_zch_identities to unique indices
+            hash_zch_identities = hash_zch_identities.index_select(
+                dim=0, index=linear_unique_inverse_indices
+            )
+
+        if hash_zch_runtime_meta is not None:
+            # Map hash_zch_runtime_meta to unique indices
+            hash_zch_runtime_meta = hash_zch_runtime_meta.index_select(
+                dim=0, index=linear_unique_inverse_indices
+            )
+
+        return PrefetchedInfo(
+            linear_unique_indices,
+            linear_unique_cache_indices,
+            linear_unique_cache_indices_length,
+            hash_zch_identities,
+            hash_zch_runtime_meta,
+        )
+
+    @torch.jit.ignore
+    def _store_prefetched_tensors(
+        self,
+        indices: torch.Tensor,
+        offsets: torch.Tensor,
+        vbe_metadata: Optional[invokers.lookup_args.VBEMetadata],
+        linear_cache_indices_merged: torch.Tensor,
+        final_lxu_cache_locations: torch.Tensor,
+        hash_zch_identities: Optional[torch.Tensor],
+        hash_zch_runtime_meta: Optional[torch.Tensor],
+    ) -> None:
+        """
+        NOTE: this needs to be a method with jit.ignore as the identities tensor is conditional.
+        This function stores the prefetched tensors for the raw embedding streaming.
+        """
+        if not self.enable_raw_embedding_streaming:
+            return
+
+        with record_function(
+            "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+        ):
+            found_in_cache_mask = final_lxu_cache_locations != -1
+            # only process the indices that are found in the cache
+            # this will filter out the indices from tables that doesn't have UVM_CACHE enabled
+            linear_cache_indices_merged_masked = torch.where(
+                found_in_cache_mask,
+                linear_cache_indices_merged,
+                self.total_cache_hash_size,
+            )
+            linearize_indices = torch.ops.fbgemm.linearize_cache_indices(
+                self.hash_size_cumsum,
+                indices,
+                offsets,
+                vbe_metadata.B_offsets if vbe_metadata is not None else None,
+                vbe_metadata.max_B if vbe_metadata is not None else -1,
+            )
+            # -1 indices are ignored in raw_embedding_streamer.
+            linearize_indices_masked = torch.where(
+                found_in_cache_mask,
+                linearize_indices,
+                -1,
+            )
+            # Process hash_zch_identities using helper function
+            prefetched_info = self._get_prefetched_info(
+                linearize_indices_masked,
+                linear_cache_indices_merged_masked,
+                self.total_cache_hash_size,
+                hash_zch_identities,
+                hash_zch_runtime_meta,
+                self.lxu_cache_weights.size(0),
+            )
+
+            self.prefetched_info_list.append(prefetched_info)
+
+    @torch.jit.ignore
+    def __report_input_params_factory(
+        self,
+    ) -> Optional[Callable[..., None]]:
+        """
+        This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.
+
+        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that:
+        - Reports input parameters (TBEDataConfig).
+        - Writes the output as a JSON file.
+
+        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy function pointer that performs no action.
+        """
+        try:
+            if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
+                from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
+
+                reporter = TBEBenchmarkParamsReporter.create()
+                return reporter.report_stats
+        except Exception:
+            return None
+
+        return None
+
 
 class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """
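
_store_prefetched_tensors masks cache misses with a sentinel before handing the merged cache indices to _get_prefetched_info, which deduplicates them and selects each unique value's first occurrence so per-row metadata can be gathered with index_select. A rough pure-PyTorch sketch of those two steps, using torch.unique in place of the fbgemm ops (values are illustrative):

    import torch

    # 1) Replace rows that missed the cache (location == -1) with a sentinel value.
    locations = torch.tensor([3, -1, 7, -1, 3])
    indices = torch.tensor([10, 11, 12, 13, 10])
    sentinel = 1_000_000  # stands in for total_cache_hash_size
    masked = torch.where(locations != -1, indices, torch.tensor(sentinel))

    # 2) Deduplicate and recover each unique value's first-occurrence position.
    uniq, inverse, counts = torch.unique(
        masked, sorted=True, return_inverse=True, return_counts=True
    )
    order = torch.argsort(inverse, stable=True)           # positions grouped by unique value
    first_offsets = torch.cumsum(counts, dim=0) - counts  # start offset of each group
    first_occurrence = order[first_offsets]               # original index of first occurrence
    per_row_meta = indices.index_select(0, first_occurrence)
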
@@ -3684,12 +4458,12 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     max_D: int
     hash_size_cumsum: Tensor
     total_hash_size_bits: int
-    embedding_specs: List[Tuple[int, int]]
+    embedding_specs: list[tuple[int, int]]
 
     def __init__(
         self,
-        embedding_specs: List[Tuple[int, int]],  # tuple of (rows, dims)
-        feature_table_map: Optional[List[int]] = None,  # [T]
+        embedding_specs: list[tuple[int, int]],  # tuple of (rows, dims)
+        feature_table_map: Optional[list[int]] = None,  # [T]
         weights_precision: SparseType = SparseType.FP32,
         pooling_mode: PoolingMode = PoolingMode.SUM,
         use_cpu: bool = False,
@@ -3732,7 +4506,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
         )
 
         self.embedding_specs = embedding_specs
-
+        rows, dims = zip(*embedding_specs)
         T_ = len(self.embedding_specs)
         assert T_ > 0
 
@@ -3802,7 +4576,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
             row for (row, _) in embedding_specs[:t]
         )
 
-        self.weights_physical_offsets: List[int] = weights_offsets
+        self.weights_physical_offsets: list[int] = weights_offsets
         weights_offsets = [weights_offsets[t] for t in feature_table_map]
         self.register_buffer(
             "weights_offsets",
@@ -3829,7 +4603,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     def _generate_vbe_metadata(
         self,
         offsets: Tensor,
-        batch_size_per_feature_per_rank: Optional[List[List[int]]],
+        batch_size_per_feature_per_rank: Optional[list[list[int]]],
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
@@ -3847,7 +4621,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
-        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
     ) -> Tensor:
         # Generate VBE metadata
         vbe_metadata = self._generate_vbe_metadata(
@@ -3886,7 +4660,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
         )
 
     @torch.jit.export
-    def split_embedding_weights(self) -> List[Tensor]:
+    def split_embedding_weights(self) -> list[Tensor]:
         """
         Returns a list of weights, split by table
         """