fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +112 -19
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +118 -0
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +190 -54
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
- fbgemm_gpu/split_embedding_configs.py +134 -37
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +6 -1
- fbgemm_gpu/tbe/bench/bench_config.py +14 -3
- fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
- fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
- fbgemm_gpu/tbe/ssd/common.py +1 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +1292 -267
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +15 -15
- fbgemm_gpu/tbe_input_multiplexer.py +10 -11
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +6 -2
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +1 -0
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_table_batched_embeddings_ops_training.py

@@ -18,14 +18,14 @@ import uuid
 from dataclasses import dataclass, field
 from itertools import accumulate
 from math import log2
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, Union

 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip
+from torch.autograd.profiler import record_function  # usort:skip

 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
-
 from fbgemm_gpu.config import FeatureGate, FeatureGateName
 from fbgemm_gpu.runtime_monitor import (
     AsyncSeriesTimer,
@@ -48,6 +48,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     SplitState,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    check_allocated_vbe_output,
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
@@ -57,8 +58,8 @@ from fbgemm_gpu.tbe_input_multiplexer import (
     TBEInputMultiplexer,
     TBEInputMultiplexerConfig,
 )
-
 from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
+from fbgemm_gpu.utils.writeback_util import writeback_gradient

 try:
     load_torch_module(
@@ -158,6 +159,7 @@ class UserEnabledConfigDefinition:
     # More details can be found in D64848802.
     use_rowwise_bias_correction: bool = False
     use_writeback_bwd_prehook: bool = False
+    writeback_first_feature_only: bool = False


 @dataclass(frozen=True)
@@ -190,25 +192,48 @@ class UVMCacheStatsIndex(enum.IntEnum):
 class RESParams:
     res_server_port: int = 0  # the port of the res server
     res_store_shards: int = 1  # the number of shards to store the raw embeddings
-    table_names:
-    table_offsets:
+    table_names: list[str] = field(default_factory=list)  # table names the TBE holds
+    table_offsets: list[int] = field(
         default_factory=list
     )  # table offsets for the global rows the TBE holds
-    table_sizes:
+    table_sizes: list[int] = field(
         default_factory=list
     )  # table sizes for the global rows the TBE holds


+class PrefetchedInfo:
+    """
+    Container for prefetched cache information.
+
+    This class is explicitly defined (not using @dataclass) to be compatible with
+    TorchScript's inspect.getsource() requirements.
+    """
+
+    def __init__(
+        self,
+        linear_unique_indices: torch.Tensor,
+        linear_unique_cache_indices: torch.Tensor,
+        linear_unique_indices_length: torch.Tensor,
+        hash_zch_identities: Optional[torch.Tensor],
+        hash_zch_runtime_meta: Optional[torch.Tensor],
+    ) -> None:
+        self.linear_unique_indices = linear_unique_indices
+        self.linear_unique_cache_indices = linear_unique_cache_indices
+        self.linear_unique_indices_length = linear_unique_indices_length
+        self.hash_zch_identities = hash_zch_identities
+        self.hash_zch_runtime_meta = hash_zch_runtime_meta
+
+
 def construct_split_state(
-    embedding_specs:
+    embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]],
     rowwise: bool,
     cacheable: bool,
     precision: SparseType = SparseType.FP32,
     int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
     placement: Optional[EmbeddingLocation] = None,
 ) -> SplitState:
-    placements:
-    offsets:
+    placements: list[EmbeddingLocation] = []
+    offsets: list[int] = []
     dev_size: int = 0
     host_size: int = 0
     uvm_size: int = 0
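The standalone snippet below is an illustration, not part of the package diff. It sketches how the new PrefetchedInfo container above might be filled during a prefetch pass; the import path and constructor signature come from this diff, while the tensor values are made-up placeholders and the torch.jit.annotate usage mirrors the list annotation shown later in the diff.

import torch

from fbgemm_gpu.split_table_batched_embeddings_ops_training import PrefetchedInfo

# Keep a TorchScript-friendly list of prefetched-cache records, as the module does.
prefetched_info_list: list[PrefetchedInfo] = torch.jit.annotate(list[PrefetchedInfo], [])

prefetched_info_list.append(
    PrefetchedInfo(
        linear_unique_indices=torch.tensor([3, 7, 42]),
        linear_unique_cache_indices=torch.tensor([0, 1, 2]),
        linear_unique_indices_length=torch.tensor([3]),
        hash_zch_identities=None,  # only populated when MPZCH remapping is in use
        hash_zch_runtime_meta=None,
    )
)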
@@ -250,18 +275,18 @@ def construct_split_state(
 def apply_split_helper(
     persistent_state_fn: Callable[[str, Tensor], None],
     set_attr_fn: Callable[
-        [str, Union[Tensor,
+        [str, Union[Tensor, list[int], list[EmbeddingLocation]]], None
     ],
     current_device: torch.device,
     use_cpu: bool,
-    feature_table_map:
+    feature_table_map: list[int],
     split: SplitState,
     prefix: str,
-    dtype:
+    dtype: type[torch.dtype],
     enforce_hbm: bool = False,
     make_dev_param: bool = False,
-    dev_reshape: Optional[
-    uvm_tensors_log: Optional[
+    dev_reshape: Optional[tuple[int, ...]] = None,
+    uvm_tensors_log: Optional[list[str]] = None,
     uvm_host_mapped: bool = False,
 ) -> None:
     set_attr_fn(f"{prefix}_physical_placements", split.placements)
@@ -346,6 +371,7 @@ def apply_split_helper(
             f"{prefix}_uvm",
             torch.zeros(
                 split.uvm_size,
+                device=current_device,
                 out=torch.ops.fbgemm.new_unified_tensor(
                     # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]`
                     # for 3rd param but got `Type[Type[torch._dtype]]`.
@@ -621,11 +647,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         (preshard_table_height, preshard_table_dim, height_offset, dim_offset)
     """

-    embedding_specs:
+    embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]]
     optimizer_args: invokers.lookup_args.OptimizerArgs
-    lxu_cache_locations_list:
+    lxu_cache_locations_list: list[Tensor]
     lxu_cache_locations_empty: Tensor
-    timesteps_prefetched:
+    timesteps_prefetched: list[int]
+    prefetched_info_list: list[PrefetchedInfo]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
     uvm_cache_stats: torch.Tensor
@@ -639,10 +666,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

     def __init__(  # noqa C901
         self,
-        embedding_specs:
-
+        embedding_specs: list[
+            tuple[int, int, EmbeddingLocation, ComputeDevice]
         ],  # tuple of (rows, dims, placements, compute_devices)
-        feature_table_map: Optional[
+        feature_table_map: Optional[list[int]] = None,  # [T]
         cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
         cache_load_factor: float = 0.2,
         cache_sets: int = 0,
@@ -680,8 +707,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         use_experimental_tbe: bool = False,
         prefetch_pipeline: bool = False,
         stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
-        table_names: Optional[
-        optimizer_state_dtypes: Optional[
+        table_names: Optional[list[str]] = None,
+        optimizer_state_dtypes: Optional[dict[str, SparseType]] = None,
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
         global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
         uvm_host_mapped: bool = False,
@@ -689,7 +716,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
         embedding_table_index_type: torch.dtype = torch.int64,
         embedding_table_offset_type: torch.dtype = torch.int64,
-        embedding_shard_info: Optional[
+        embedding_shard_info: Optional[list[tuple[int, int, int, int]]] = None,
+        enable_raw_embedding_streaming: bool = False,
+        res_params: Optional[RESParams] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
         self.uuid = str(uuid.uuid4())
@@ -699,7 +728,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
         )

+        self.table_names: Optional[list[str]] = table_names
         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
+        self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
         self.pooling_mode = pooling_mode
         self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE

@@ -793,9 +824,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         ), "Unique cache miss counters are not accurate in multipass prefetch and therefore not supported"

         self.embedding_specs = embedding_specs
-
+        rows, dims, locations, compute_devices = zip(*embedding_specs)
         T_ = len(self.embedding_specs)
-        self.dims:
+        self.dims: list[int] = dims
         assert T_ > 0
         # mixed D is not supported by no bag kernels
         mixed_D = False
@@ -808,7 +839,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             assert (
                 self.pooling_mode != PoolingMode.NONE
             ), "Mixed dimension tables only supported for pooling tables."
-
+        self.mixed_D: bool = mixed_D
         assert all(
             cd == compute_devices[0] for cd in compute_devices
         ), "Heterogenous compute_devices are NOT supported!"
@@ -872,7 +903,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self.stats_reporter: Optional[TBEStatsReporter] = (
             stats_reporter_config.create_reporter() if stats_reporter_config else None
         )
-        self._uvm_tensors_log:
+        self._uvm_tensors_log: list[str] = []

         self.bwd_wait_prefetch_timer: Optional[AsyncSeriesTimer] = None
         self.prefetch_duration_timer: Optional[AsyncSeriesTimer] = None
@@ -899,12 +930,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

         self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET

-        self.feature_table_map:
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )

         if embedding_shard_info:
-
+            full_table_heights, full_table_dims, row_offset, col_offset = zip(
                 *embedding_shard_info
             )
         else:
@@ -939,7 +970,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         table_has_feature = [False] * T_
         for t in self.feature_table_map:
             table_has_feature[t] = True
-        assert all(table_has_feature),
+        assert all(table_has_feature), (
+            "Each table must have at least one feature!"
+            + f"{[(i, x) for i, x in enumerate(table_has_feature)]}"
+        )

         feature_dims = [dims[t] for t in self.feature_table_map]
         D_offsets = [0] + list(accumulate(feature_dims))
@@ -991,7 +1025,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             "feature_dims",
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
-
+        _info_B_num_bits, _info_B_mask = torch.ops.fbgemm.get_infos_metadata(
             self.D_offsets,  # unused tensor
             1,  # max_B
             T,  # T
@@ -1105,13 +1139,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

         if ensemble_mode is None:
             ensemble_mode = EnsembleModeDefinition()
-        self._ensemble_mode:
+        self._ensemble_mode: dict[str, float] = {
             key: float(fval) for key, fval in ensemble_mode.__dict__.items()
         }

         if emainplace_mode is None:
             emainplace_mode = EmainplaceModeDefinition()
-        self._emainplace_mode:
+        self._emainplace_mode: dict[str, float] = {
             key: float(fval) for key, fval in emainplace_mode.__dict__.items()
         }

@@ -1151,6 +1185,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self.use_writeback_bwd_prehook: bool = (
             extra_optimizer_config.use_writeback_bwd_prehook
         )
+
+        writeback_first_feature_only: bool = (
+            extra_optimizer_config.writeback_first_feature_only
+        )
         self.log(f"self.extra_optimizer_config is {extra_optimizer_config}")
         if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM:
             raise AssertionError(
@@ -1416,7 +1454,11 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

         self.step = 0
         self.last_reported_step = 0
-        self.last_reported_uvm_stats:
+        self.last_reported_uvm_stats: list[float] = []
+        # Track number of times detailed memory breakdown has been reported
+        self.detailed_mem_breakdown_report_count = 0
+        # Set max number of reports for detailed memory breakdown
+        self.max_detailed_mem_breakdown_reports = 10

         # Check whether to use TBE v2
         is_experimental = False
@@ -1435,16 +1477,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             # self.log("TBE_V2 Knob is set to True; Using experimental TBE")

         self.is_experimental: bool = is_experimental
+        self._writeback_first_feature_only: bool = writeback_first_feature_only

         # Get a debug function pointer
         self._debug_print_input_stats: Callable[..., None] = (
             self._debug_print_input_stats_factory()
         )

+        # Get a reporter function pointer
+        self._report_input_params: Callable[..., None] = (
+            self.__report_input_params_factory()
+        )
+
         if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
             # Register writeback hook for Exact_SGD optimizer
             self.log(
-                "SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled."
+                f"SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled with first feature only={self._writeback_first_feature_only}"
             )
             # pyre-fixme[6]: Expected `typing.Callable[[Module, Union[Tensor, typing.Tuple[Tensor, ...]]], Union[None, Tensor, typing.Tuple[Tensor, ...]]]`
             self.register_full_backward_pre_hook(self.writeback_hook)
@@ -1460,6 +1508,30 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         )
         self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

+        self.prefetched_info_list: list[PrefetchedInfo] = torch.jit.annotate(
+            list[PrefetchedInfo], []
+        )
+        if self.enable_raw_embedding_streaming:
+            self.res_params: RESParams = res_params or RESParams()
+            self.res_params.table_sizes = [0] + list(accumulate(rows))
+            res_port_from_env = os.getenv("LOCAL_RES_PORT")
+            self.res_params.res_server_port = (
+                int(res_port_from_env) if res_port_from_env else 0
+            )
+            # pyre-fixme[4]: Attribute must be annotated.
+            self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+                self.uuid,
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
+            )
+            logging.info(
+                f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
+            )
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
@@ -1503,7 +1575,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         )

     @staticmethod
-    def get_table_name_for_logging(table_names: Optional[
+    def get_table_name_for_logging(table_names: Optional[list[str]]) -> str:
         """
         Given a list of all table names in the TBE, generate a string to
         represent them in logging. If there is more than one table, this method
@@ -1519,17 +1591,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             return "<Unknown>"
         # Do this because sometimes multiple shards of the same table could appear
         # in one TBE.
-        table_name_set = set(table_names)
+        table_name_set = sorted(set(table_names))
         if len(table_name_set) == 1:
             return next(iter(table_name_set))
-        return f"<{len(table_name_set)} tables
+        return f"<{len(table_name_set)} tables>: {table_name_set}"

     @staticmethod
     def get_prefetch_passes(
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig],
         input_tensor: Tensor,
         output_tensor: Tensor,
-    ) ->
+    ) -> list[tuple[Tensor, Tensor, int]]:
         """
         Given inputs (the indices to forward), partition the input and output
         into smaller chunks and return them as a list of tuples
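The short standalone sketch below is not from the package; it only illustrates the behaviour change in get_table_name_for_logging above, where the shard-deduplicated name set is now sorted so the log string is deterministic across runs. The None/empty guard is an assumption about surrounding code this hunk does not show.

from typing import Optional

def table_name_for_logging(table_names: Optional[list[str]]) -> str:
    if not table_names:
        return "<Unknown>"
    # Dedupe multiple shards of the same table, then sort for a stable log string.
    table_name_set = sorted(set(table_names))
    if len(table_name_set) == 1:
        return table_name_set[0]
    return f"<{len(table_name_set)} tables>: {table_name_set}"

assert table_name_for_logging(["t_b", "t_a", "t_a"]) == "<2 tables>: ['t_a', 't_b']"

Before this change the unsorted set iteration order could differ between processes, so two identical TBEs might log different strings; sorting removes that nondeterminism.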
@@ -1577,7 +1649,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             )
         )

-    def get_states(self, prefix: str) ->
+    def get_states(self, prefix: str) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
         """
         Get a state of a given tensor (`prefix`)

@@ -1616,7 +1688,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             torch.tensor(offsets, dtype=torch.int64),
         )

-    def get_all_states(self) ->
+    def get_all_states(self) -> list[tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
         """
         Get all states in the TBE (`weights`, `momentum1`, `momentum2`,
         `prev_iter`, and `row_counter`)
@@ -1680,10 +1752,161 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             tbe_id=self.uuid,
         )

-
-
+    def _get_tensor_memory(self, tensor_name: str) -> int:
+        """Get memory usage of a tensor in bytes."""
+        if not hasattr(self, tensor_name):
+            self.log(f"Tensor '{tensor_name}' not found, using 0 bytes")
+            return 0
+        tensor = getattr(self, tensor_name)
+        return tensor.numel() * tensor.element_size()
+
+    def _categorize_memory_by_location(
+        self, tensor_names: list[str]
+    ) -> tuple[int, int]:
+        """Categorize memory into HBM and UVM for given tensors.
+
+        Returns:
+            (hbm_bytes, uvm_bytes)
+        """
+        uvm_set = set(self._uvm_tensors_log)
+        hbm_bytes = 0
+        uvm_bytes = 0
+
+        for name in tensor_names:
+            size = self._get_tensor_memory(name)
+            if name in uvm_set:
+                uvm_bytes += size
+            else:
+                hbm_bytes += size
+
+        return hbm_bytes, uvm_bytes
+
+    def _report_hbm_breakdown(
+        self,
+        stats_reporter: TBEStatsReporter,
+        embeddings: int,
+        optimizer_states: int,
+        cache: int,
+        total_static_sparse: int,
+        ephemeral: int,
+        cache_weights: int = 0,
+        cache_aux: int = 0,
+    ) -> None:
+        """Report HBM memory breakdown to stats reporter."""
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.embeddings",
+            data_bytes=embeddings,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.optimizer_states",
+            data_bytes=optimizer_states,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache",
+            data_bytes=cache,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache_weights",
+            data_bytes=cache_weights,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache_aux",
+            data_bytes=cache_aux,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.total_static_sparse",
+            data_bytes=total_static_sparse,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.ephemeral",
+            data_bytes=ephemeral,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+
+    def _report_uvm_breakdown(
         self,
+        stats_reporter: TBEStatsReporter,
+        embeddings: int,
+        optimizer_states: int,
+        cache: int,
+        total_static_sparse: int,
+        ephemeral: int,
+        cache_weights: int = 0,
+        cache_aux: int = 0,
     ) -> None:
+        """Report UVM memory breakdown to stats reporter."""
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.embeddings",
+            data_bytes=embeddings,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.optimizer_states",
+            data_bytes=optimizer_states,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache",
+            data_bytes=cache,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache_weights",
+            data_bytes=cache_weights,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache_aux",
+            data_bytes=cache_aux,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.total_static_sparse",
+            data_bytes=total_static_sparse,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.ephemeral",
+            data_bytes=ephemeral,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+
+    @torch.jit.ignore
+    def _report_tbe_mem_usage(self) -> None:
         if self.stats_reporter is None:
             return

@@ -1691,22 +1914,24 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if not stats_reporter.should_report(self.step):
             return

+        # Calculate total memory from all parameters and buffers (always needed)
         total_mem_usage = sum(
-
-        ) + sum(
+            p.numel() * p.element_size() for p in self.parameters()
+        ) + sum(b.numel() * b.element_size() for b in self.buffers())
+
+        # Calculate total HBM and UVM usage (always needed)
         if self.use_cpu:
             total_hbm_usage = 0
             total_uvm_usage = total_mem_usage
         else:
-            # hbm usage is total usage minus uvm usage
             total_uvm_usage = sum(
-
-
-
-                if hasattr(self, tensor_name)
+                self._get_tensor_memory(name)
+                for name in self._uvm_tensors_log
+                if hasattr(self, name)
             )
             total_hbm_usage = total_mem_usage - total_uvm_usage

+        # Report total memory usage metrics (always reported for backward compatibility)
         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="tbe.total_hbm_usage",
@@ -1722,6 +1947,96 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             tbe_id=self.uuid,
         )

+        # Only report detailed breakdown for the first max_detailed_mem_breakdown_reports reportable
+        # steps since static sparse memory (weights, optimizer states, cache) is constant
+        if (
+            self.detailed_mem_breakdown_report_count
+            >= self.max_detailed_mem_breakdown_reports
+        ):
+            return
+        self.detailed_mem_breakdown_report_count += 1
+
+        # Tensor groups for sparse memory categorization
+        weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]
+        optimizer_tensors = [
+            "momentum1_dev",
+            "momentum1_host",
+            "momentum1_uvm",
+            "momentum2_dev",
+            "momentum2_host",
+            "momentum2_uvm",
+        ]
+        # Cache weights tensor (the actual cached embeddings in HBM)
+        cache_weight_tensors = [
+            "lxu_cache_weights",
+        ]
+        # Cache auxiliary state tensors (metadata for cache management, excluding weights)
+        # Sizes scale with hash_size or cache_slots (hash_size × clf)
+        # Excludes constant-size tensors: cache_hash_size_cumsum, cache_miss_counter, etc.
+        cache_aux_tensors = [
+            "cache_index_table_map",  # int32, 4B × hash_size
+            "lxu_cache_state",  # int64, 8B × cache_slots
+            "lxu_state",  # int64, 8B × cache_slots (LRU) or hash_size (LFU)
+            "lxu_cache_locking_counter",  # int32, 4B × cache_slots (only if prefetch_pipeline)
+        ]
+
+        # Calculate total memory for each component
+        weights_total = sum(self._get_tensor_memory(t) for t in weight_tensors)
+        optimizer_total = sum(self._get_tensor_memory(t) for t in optimizer_tensors)
+        cache_weights_total = sum(
+            self._get_tensor_memory(t) for t in cache_weight_tensors
+        )
+        cache_aux_total = sum(self._get_tensor_memory(t) for t in cache_aux_tensors)
+
+        # Categorize memory by location (HBM vs UVM)
+        if self.use_cpu:
+            weights_hbm, weights_uvm = 0, weights_total
+            opt_hbm, opt_uvm = 0, optimizer_total
+            cache_weights_hbm, cache_weights_uvm = 0, cache_weights_total
+            cache_aux_hbm, cache_aux_uvm = 0, cache_aux_total
+        else:
+            weights_hbm, weights_uvm = self._categorize_memory_by_location(
+                weight_tensors
+            )
+            opt_hbm, opt_uvm = self._categorize_memory_by_location(optimizer_tensors)
+            cache_weights_hbm, cache_weights_uvm = self._categorize_memory_by_location(
+                cache_weight_tensors
+            )
+            cache_aux_hbm, cache_aux_uvm = self._categorize_memory_by_location(
+                cache_aux_tensors
+            )
+
+        # Calculate ephemeral memory split between HBM and UVM
+        # Total cache = cache weights + cache auxiliary state
+        cache_hbm = cache_weights_hbm + cache_aux_hbm
+        cache_uvm = cache_weights_uvm + cache_aux_uvm
+        static_sparse_hbm = weights_hbm + opt_hbm + cache_hbm
+        static_sparse_uvm = weights_uvm + opt_uvm + cache_uvm
+        ephemeral_hbm = total_hbm_usage - static_sparse_hbm
+        ephemeral_uvm = total_uvm_usage - static_sparse_uvm
+
+        # Report granular memory breakdowns
+        self._report_hbm_breakdown(
+            stats_reporter,
+            weights_hbm,
+            opt_hbm,
+            cache_hbm,
+            static_sparse_hbm,
+            ephemeral_hbm,
+            cache_weights_hbm,
+            cache_aux_hbm,
+        )
+        self._report_uvm_breakdown(
+            stats_reporter,
+            weights_uvm,
+            opt_uvm,
+            cache_uvm,
+            static_sparse_uvm,
+            ephemeral_uvm,
+            cache_weights_uvm,
+            cache_aux_uvm,
+        )
+
     @torch.jit.ignore
     def _report_io_size_count(self, event: str, data: Tensor) -> Tensor:
         if self.stats_reporter is None:
@@ -1748,7 +2063,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     def _generate_vbe_metadata(
         self,
         offsets: Tensor,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]],
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
@@ -1771,6 +2088,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.pooling_mode,
             self.feature_dims,
             self.current_device,
+            vbe_output,
+            vbe_output_offsets,
         )

     @torch.jit.ignore
@@ -1779,40 +2098,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         # This allows models using this class to compile correctly
         return FeatureGate.is_enabled(feature)

-    def writeback_update_gradient(
-        self, indices: torch.Tensor, offsets: torch.Tensor, grad: Tensor
-    ) -> Tensor:
-        if indices.numel() == 0:
-            return grad[0]
-        num_of_tables = len(set(self.feature_table_map))
-        assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
-        batch_size = offsets.shape[0] // num_of_tables
-        max_indices = indices.max()
-        non_empty_index = (offsets[1:] - offsets[:-1]).nonzero().flatten()
-        # disable dedup across different table
-        indices = ((offsets[non_empty_index]) // batch_size) * (
-            1 + max_indices
-        ) + indices
-        grad = grad[0]
-        _, idx, counts = torch.unique(
-            indices, dim=0, sorted=True, return_inverse=True, return_counts=True
-        )
-        _, ind_sorted = torch.sort(idx, stable=True)
-        cum_sum = counts.cumsum(0)
-        cum_sum = torch.cat((torch.tensor([0]).to(indices.device), cum_sum[:-1]))
-        first_indicies = ind_sorted[cum_sum]
-        mask = torch.zeros_like(grad, device=grad.device)
-        original_index = non_empty_index[first_indicies]
-
-        mask[original_index] = grad[original_index]
-        return mask
-
     # pyre-fixme[2]: For 1st argument expected not ANY
-    def writeback_hook(self, module: Any, grad: Tensor) ->
+    def writeback_hook(self, module: Any, grad: Tensor) -> tuple[Tensor]:
         indices = self._indices
         offsets = self._offsets
-
-
+        return writeback_gradient(
+            grad,
+            indices,
+            offsets,
+            self.feature_table_map,
+            self._writeback_first_feature_only,
+        )

     def forward(  # noqa: C901
         self,
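The standalone sketch below is an illustration, not package code. The removed writeback_update_gradient above implemented a "first occurrence wins" gradient mask, and this diff moves that logic into fbgemm_gpu.utils.writeback_util.writeback_gradient; the snippet reproduces just the deduplication trick (torch.unique plus a stable sort to find the first position of each duplicate index) on toy data.

import torch

indices = torch.tensor([5, 3, 5, 7, 3])
grad = torch.arange(5, dtype=torch.float32).unsqueeze(1).repeat(1, 2)

# return_inverse maps every element to its unique id; a stable sort of that
# mapping groups duplicates while preserving their original order.
_, inverse, counts = torch.unique(
    indices, sorted=True, return_inverse=True, return_counts=True
)
_, ind_sorted = torch.sort(inverse, stable=True)
first_positions = ind_sorted[
    torch.cat((torch.tensor([0]), counts.cumsum(0)[:-1]))
]

# Keep only the gradient rows of each index's first occurrence; duplicates are zeroed.
mask = torch.zeros_like(grad)
mask[first_positions] = grad[first_positions]  # rows 0, 1, 3 survive here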
@@ -1820,8 +2116,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
         total_unique_indices: Optional[int] = None,
+        hash_zch_identities: Optional[Tensor] = None,
+        hash_zch_runtime_meta: Optional[Tensor] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> Tensor:
         """
         The forward pass function that
@@ -1874,7 +2174,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 be set when using `OptimType.NONE`. This is because TBE
                 requires this information for allocating the weight gradient
                 tensor in the backward pass.
-
+            hash_zch_identities (Optional[Tensor]): The original raw IDs before
+                remapping to ZCH (Zero-Collision Hash) table slots. This tensor is
+                populated when using Multi-Probe Zero Collision Hash (MPZCH) modules
+                and is required for Raw Embedding Streaming (RES) to maintain
+                consistency between training and inference.
+            vbe_output (Optional[Tensor]): An optional 2-D tensor of size that
+                contains output for TBE VBE. The shape of the tensor is
+                [1, total_vbe_output_size] where total_vbe_output_size is the
+                output size across all ranks and all embedding tables.
+                If this tensor is not None, the TBE VBE forward output is written
+                to this tensor at the locations specified by `vbe_output_offsets`.
+            vbe_output_offsets (Optional[Tensor]): An optional 2-D tensor that
+                contains VBE output offsets to `vbe_output`. The shape of the
+                tensor is [num_ranks, num_features].
+                vbe_output_offsets[r][f] represents the starting offset for rank `r`
+                and feature `f`.
         Returns:
             A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` =
             batch size and `total_D` = the sum of all embedding dimensions in the
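The pure-PyTorch snippet below is an editorial illustration of the vbe_output layout described in the docstring above, not FBGEMM code. It assumes element-sized offsets laid out rank-major and feature-minor, which the hunk does not state explicitly; only the parameter names and the [1, total_vbe_output_size] / [num_ranks, num_features] shapes come from the docstring.

import torch

num_ranks, num_features, D = 2, 2, 4
batch_sizes = [[3, 1], [2, 2]]  # assumed batch size per (rank, feature)

# One flat pre-allocated output buffer covering all ranks and all features.
total_vbe_output_size = sum(b * D for row in batch_sizes for b in row)
vbe_output = torch.zeros(1, total_vbe_output_size)

# Starting offset of every (rank, feature) slice inside the flat buffer.
offsets, cursor = [], 0
for r in range(num_ranks):
    row = []
    for f in range(num_features):
        row.append(cursor)
        cursor += batch_sizes[r][f] * D
    offsets.append(row)
vbe_output_offsets = torch.tensor(offsets)  # shape [num_ranks, num_features]

# A pooled lookup result for rank 1, feature 0 (shape [B, D]) lands at its offset.
pooled = torch.randn(batch_sizes[1][0], D)
start = int(vbe_output_offsets[1, 0])
vbe_output[0, start : start + pooled.numel()] = pooled.flatten()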
@@ -1948,11 +2263,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             batch_size_per_feature_per_rank,
             force_cast_input_types=True,
             prefetch_pipeline=False,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )

+        # Only enable VBE if batch_size_per_feature_per_rank is not None
+        assert not (
+            batch_size_per_feature_per_rank is not None
+            and self.use_writeback_bwd_prehook
+        ), "VBE is not supported with writeback_bwd_prehook"
+
         # Print input stats if enable (for debugging purpose only)
         self._debug_print_input_stats(indices, offsets, per_sample_weights)

+        # Extract and Write input stats if enable
+        if self._report_input_params is not None:
+            self._report_input_params(
+                feature_rows=self.rows_per_table,
+                feature_dims=self.feature_dims,
+                iteration=self.iter_cpu.item() if hasattr(self, "iter_cpu") else 0,
+                indices=indices,
+                offsets=offsets,
+                op_id=self.uuid,
+                per_sample_weights=per_sample_weights,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                embedding_specs=[(s[0], s[1]) for s in self.embedding_specs],
+                feature_table_map=self.feature_table_map,
+            )
+
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time

@@ -1980,7 +2318,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             # to be as fast as possible and memory usage doesn't matter (will be recycled
             # by dense fwd/bwd)
             self._prefetch(
-                indices,
+                indices,
+                offsets,
+                vbe_metadata,
+                multipass_prefetch_config=None,
+                hash_zch_identities=hash_zch_identities,
+                hash_zch_runtime_meta=hash_zch_runtime_meta,
             )

         if len(self.timesteps_prefetched) > 0:
@@ -2262,6 +2605,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     row_counter,
                     iter_int,
                     self.max_counter.item(),
+                    mixed_D=self.mixed_D,
                 ),
             )
         elif self._used_rowwise_adagrad_with_global_weight_decay:
@@ -2280,6 +2624,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     # `Optional[Tensor]` but got `Union[Module, Tensor]`.
                     prev_iter_dev=self.prev_iter_dev,
                     gwd_lower_bound=self.gwd_lower_bound,
+                    mixed_D=self.mixed_D,
                 ),
             )
         else:
@@ -2289,12 +2634,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     common_args,
                     self.optimizer_args,
                     momentum1,
+                    mixed_D=self.mixed_D,
                 ),
             )

         raise ValueError(f"Invalid OptimType: {self.optimizer}")

-    def ema_inplace(self, emainplace_mode:
+    def ema_inplace(self, emainplace_mode: dict[str, float]) -> None:
         """
         Perform ema operations on the full sparse embedding tables.
         We organize the sparse table, in the following way.
@@ -2324,7 +2670,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             emainplace_mode["step_ema_coef"],
         )

-    def ensemble_and_swap(self, ensemble_mode:
+    def ensemble_and_swap(self, ensemble_mode: dict[str, float]) -> None:
         """
         Perform ensemble and swap operations on the full sparse embedding tables.

@@ -2372,7 +2718,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
         return self.local_uvm_cache_stats if use_local_cache else self.uvm_cache_stats

-    def _get_uvm_cache_print_state(self, use_local_cache: bool = False) ->
+    def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> list[float]:
         snapshot = self.get_uvm_cache_stats(use_local_cache)
         if use_local_cache:
             return snapshot.tolist()
@@ -2385,7 +2731,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.ignore
     def print_uvm_cache_stats(self, use_local_cache: bool = False) -> None:
         # TODO: Create a separate reporter class to unify the stdlog reporting
-        uvm_cache_stats:
+        uvm_cache_stats: list[float] = self._get_uvm_cache_print_state(use_local_cache)
         N = max(1, uvm_cache_stats[0])
         m = {
             "N_called": uvm_cache_stats[UVMCacheStatsIndex.num_calls],
@@ -2429,14 +2775,14 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if not stats_reporter.should_report(self.step):
             return

-        uvm_cache_stats:
+        uvm_cache_stats: list[float] = self.get_uvm_cache_stats(
             use_local_cache=False
         ).tolist()
         self.last_reported_step = self.step

         if len(self.last_reported_uvm_stats) == 0:
             self.last_reported_uvm_stats = [0.0] * len(uvm_cache_stats)
-        uvm_cache_stats_delta:
+        uvm_cache_stats_delta: list[float] = [0.0] * len(uvm_cache_stats)
         for i in range(len(uvm_cache_stats)):
             uvm_cache_stats_delta[i] = (
                 uvm_cache_stats[i] - self.last_reported_uvm_stats[i]
@@ -2465,7 +2811,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         indices: Tensor,
         offsets: Tensor,
         forward_stream: Optional[torch.cuda.Stream] = None,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
     ) -> None:
         if self.prefetch_stream is None and forward_stream is not None:
             self.prefetch_stream = torch.cuda.current_stream()
@@ -2473,20 +2819,21 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 self.prefetch_stream != forward_stream
             ), "prefetch_stream and forward_stream should not be the same stream"

-        indices, offsets, _, vbe_metadata = self.prepare_inputs(
-            indices,
-            offsets,
-            per_sample_weights=None,
-            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
-            force_cast_input_types=False,
-            prefetch_pipeline=self.prefetch_pipeline,
-        )
-
         with self._recording_to_timer(
             self.prefetch_duration_timer,
             context=self.step,
             stream=torch.cuda.current_stream(),
         ):
+
+            indices, offsets, _, vbe_metadata = self.prepare_inputs(
+                indices,
+                offsets,
+                per_sample_weights=None,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                force_cast_input_types=False,
+                prefetch_pipeline=self.prefetch_pipeline,
+            )
+
             self._prefetch(
                 indices,
                 offsets,
@@ -2503,6 +2850,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         offsets: Tensor,
         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
+        hash_zch_identities: Optional[Tensor] = None,
+        hash_zch_runtime_meta: Optional[Tensor] = None,
     ) -> None:
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
@@ -2521,7 +2870,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.local_uvm_cache_stats.zero_()
         self._report_io_size_count("prefetch_input", indices)

+        # streaming before updating the cache
+        self.raw_embedding_stream()
+
         final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
+        linear_cache_indices_merged = torch.zeros(
+            0, dtype=indices.dtype, device=indices.device
+        )
         for (
             partial_indices,
             partial_lxu_cache_locations,
@@ -2537,6 +2892,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 vbe_metadata.max_B if vbe_metadata is not None else -1,
                 base_offset,
             )
+            linear_cache_indices_merged = torch.cat(
+                [linear_cache_indices_merged, linear_cache_indices]
+            )

             if (
                 self.record_cache_metrics.record_cache_miss_counter
@@ -2617,6 +2975,16 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if self.should_log():
             self.print_uvm_cache_stats(use_local_cache=False)

+        self._store_prefetched_tensors(
+            indices,
+            offsets,
+            vbe_metadata,
+            linear_cache_indices_merged,
+            final_lxu_cache_locations,
+            hash_zch_identities,
+            hash_zch_runtime_meta,
+        )
+
     def should_log(self) -> bool:
         """Determines if we should log for this step, using exponentially decreasing frequency.

@@ -2701,12 +3069,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 tmp_emb.uniform_(min_val, max_val)
                 tmp_emb_i8 = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(tmp_emb)
                 emb.data.copy_(tmp_emb_i8)
+        # Torch doesnt implement direct fp8 distribution functions, so we need to start in higher precision.
+        elif self.weights_precision == SparseType.NFP8:
+            assert (
+                self.current_device.type == "cuda"
+            ), "NFP8 is currently only supportd on GPU."
+            assert self.optimizer in [
+                OptimType.EXACT_ADAGRAD,
+                OptimType.ROWWISE_ADAGRAD,
+                OptimType.EXACT_ROWWISE_ADAGRAD,
+                OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+                OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
+            ], "NFP8 is currently only supportd with adagrad optimizers."
+            for param in splits:
+                tmp_param = torch.zeros(param.shape, device=self.current_device)
+                # Create initialized weights and cast to fp8.
+                fp8_dtype = (
+                    torch.float8_e4m3fnuz
+                    if torch.version.hip is not None
+                    else torch.float8_e4m3fn
+                )
+                tmp_param.uniform_(min_val, max_val).to(fp8_dtype)
+                param.data.copy_(tmp_param)
         else:
             for param in splits:
                 param.uniform_(min_val, max_val)

     @torch.jit.ignore
-    def split_embedding_weights(self) ->
+    def split_embedding_weights(self) -> list[Tensor]:
         """
         Returns a list of embedding weights (view), split by table

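The snippet below is a standalone illustration, not package code. It shows the FP8 initialization pattern the NFP8 branch above relies on: PyTorch has no uniform_ for float8 dtypes, so values are drawn in fp32 and then cast, with ROCm builds using the *fnuz variant. The min_val/max_val bounds here are arbitrary placeholders; note also that Tensor.to() returns a new tensor rather than casting in place, which is why this sketch assigns its result before copying.

import torch

fp8_dtype = (
    torch.float8_e4m3fnuz if torch.version.hip is not None else torch.float8_e4m3fn
)

# Draw the initial weights in fp32, then cast the result to fp8.
tmp = torch.zeros(4, 8)
fp8_weights = tmp.uniform_(-0.05, 0.05).to(fp8_dtype)
print(fp8_weights.dtype, fp8_weights.float().abs().max())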
@@ -2748,7 +3138,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             raise ValueError(f"Optimizer buffer {state} not found")

     @torch.jit.export
-    def get_optimizer_state(self) ->
+    def get_optimizer_state(self) -> list[dict[str, torch.Tensor]]:
         r"""
         Get the optimizer state dict that matches the OSS Pytorch optims
         TODO: populate the supported list of optimizers
@@ -2832,7 +3222,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.ignore
     def split_optimizer_states(
         self,
-    ) ->
+    ) -> list[list[torch.Tensor]]:
         """
         Returns a list of optimizer states (view), split by table

@@ -2880,7 +3270,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             state_offsets: Tensor,
             state_placements: Tensor,
             rowwise: bool,
-        ) ->
+        ) -> list[torch.Tensor]:
             splits = []
             for t, (rows, dim, _, _) in enumerate(self.embedding_specs):
                 offset = state_offsets[t]
@@ -2899,7 +3289,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 splits.append(state.detach()[offset : offset + rows].view(rows))
             return splits

-        states:
+        states: list[list[torch.Tensor]] = []
         if self.optimizer not in (OptimType.EXACT_SGD,):
             states.append(
                 get_optimizer_states(
@@ -3025,7 +3415,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         return self.learning_rate_tensor.item()

     @torch.jit.ignore
-    def update_hyper_parameters(self, params_dict:
+    def update_hyper_parameters(self, params_dict: dict[str, float]) -> None:
         """
         Sets hyper-parameters from external control flow.

@@ -3101,10 +3491,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self,
         split: SplitState,
         prefix: str,
-        dtype:
+        dtype: type[torch.dtype],
         enforce_hbm: bool = False,
         make_dev_param: bool = False,
-        dev_reshape: Optional[
+        dev_reshape: Optional[tuple[int, ...]] = None,
         uvm_host_mapped: bool = False,
     ) -> None:
         apply_split_helper(
@@ -3154,6 +3544,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
3154
3544
|
dtype = torch.float32
|
|
3155
3545
|
elif cache_precision == SparseType.FP16:
|
|
3156
3546
|
dtype = torch.float16
|
|
3547
|
+
elif cache_precision == SparseType.NFP8:
|
|
3548
|
+
# NFP8 weights use floating point cache.
|
|
3549
|
+
dtype = torch.float16
|
|
3157
3550
|
else:
|
|
3158
3551
|
dtype = torch.float32 # not relevant, but setting it to keep linter happy
|
|
3159
3552
|
if not self.use_cpu > 0:
|
|
@@ -3347,7 +3740,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
3347
3740
|
def _update_cache_counter_and_locations(
|
|
3348
3741
|
self,
|
|
3349
3742
|
module: nn.Module,
|
|
3350
|
-
grad_input: Union[
|
|
3743
|
+
grad_input: Union[tuple[Tensor, ...], Tensor],
|
|
3351
3744
|
) -> None:
|
|
3352
3745
|
"""
|
|
3353
3746
|
Backward prehook function when prefetch_pipeline is enabled.
|
|
@@ -3543,10 +3936,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
3543
3936
|
indices: Tensor,
|
|
3544
3937
|
offsets: Tensor,
|
|
3545
3938
|
per_sample_weights: Optional[Tensor] = None,
|
|
3546
|
-
batch_size_per_feature_per_rank: Optional[
|
|
3939
|
+
batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
|
|
3547
3940
|
force_cast_input_types: bool = True,
|
|
3548
3941
|
prefetch_pipeline: bool = False,
|
|
3549
|
-
|
|
3942
|
+
vbe_output: Optional[Tensor] = None,
|
|
3943
|
+
vbe_output_offsets: Optional[Tensor] = None,
|
|
3944
|
+
) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
|
|
3550
3945
|
"""
|
|
3551
3946
|
Prepare TBE inputs as follows:
|
|
3552
3947
|
|
|
@@ -3572,9 +3967,20 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
3572
3967
|
metadata
|
|
3573
3968
|
"""
|
|
3574
3969
|
|
|
3970
|
+
if vbe_output is not None or vbe_output_offsets is not None:
|
|
3971
|
+
assert (
|
|
3972
|
+
not self.use_cpu
|
|
3973
|
+
), "[TBE API v2] Using pre-allocated vbe_output is not supported on CPU"
|
|
3974
|
+
check_allocated_vbe_output(
|
|
3975
|
+
self.output_dtype,
|
|
3976
|
+
batch_size_per_feature_per_rank,
|
|
3977
|
+
vbe_output,
|
|
3978
|
+
vbe_output_offsets,
|
|
3979
|
+
)
|
|
3980
|
+
|
|
3575
3981
|
# Generate VBE metadata
|
|
3576
3982
|
vbe_metadata = self._generate_vbe_metadata(
|
|
3577
|
-
offsets, batch_size_per_feature_per_rank
|
|
3983
|
+
offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
|
|
3578
3984
|
)
|
|
3579
3985
|
|
|
3580
3986
|
vbe = vbe_metadata.B_offsets is not None
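Per the assert above, the pre-allocated vbe_output path is GPU-only and is validated by check_allocated_vbe_output. A hedged sketch of how a caller might size such a buffer, assuming the VBE output is a flattened 1-D tensor with one element per (feature, rank, sample, dim) entry (the feature dims, batch sizes, dtype, and device below are all made up):

import torch

batch_size_per_feature_per_rank = [[2, 3], [4, 1]]  # 2 features x 2 ranks (made up)
feature_dims = [8, 16]                               # embedding dim per feature (made up)

total_elems = sum(
    sum(per_rank) * dim
    for per_rank, dim in zip(batch_size_per_feature_per_rank, feature_dims)
)
vbe_output = torch.empty(total_elems, dtype=torch.float16, device="cuda")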
@@ -3647,7 +4053,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3647 4053                   self.is_nobag,
3648 4054                   vbe_metadata.max_B_feature_rank,
3649 4055                   self.info_B_num_bits,
3650      -                 offsets.numel() - 1,  # total_B
     4056 +                 offsets.numel() - 1,  # total_B,
     4057 +                 vbe_output_offsets,
3651 4058               )
3652 4059           else:
3653 4060               b_t_map = None
@@ -3736,7 +4143,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3736 4143           # Counts of indices that segment lengths > 1024
3737 4144           counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024]
3738 4145   
3739      -         def compute_numel_and_avg(counts: Tensor) ->
     4146 +         def compute_numel_and_avg(counts: Tensor) -> tuple[int, float]:
3740 4147               numel = counts.numel()
3741 4148               avg = (counts.sum().item() / numel) if numel != 0 else -1.0
3742 4149               return numel, avg
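The retyped helper above is small enough to show standalone; the tensors below are made-up inputs:

import torch

def compute_numel_and_avg(counts: torch.Tensor) -> tuple[int, float]:
    numel = counts.numel()
    avg = (counts.sum().item() / numel) if numel != 0 else -1.0
    return numel, avg

print(compute_numel_and_avg(torch.tensor([3, 5, 1030])))          # (3, 346.0)
print(compute_numel_and_avg(torch.tensor([], dtype=torch.long)))  # (0, -1.0)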
@@ -3804,6 +4211,240 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3804 4211               return _debug_print_input_stats_factory_impl
3805 4212           return _debug_print_input_stats_factory_null
3806 4213   
     4214 +     @torch.jit.ignore
     4215 +     def raw_embedding_stream(self) -> None:
     4216 +         if not self.enable_raw_embedding_streaming:
     4217 +             return None
     4218 +         # when pipelining is enabled
     4219 +         # prefetch in iter i happens before the backward sparse in iter i - 1
     4220 +         # so embeddings for iter i - 1's changed ids are not updated.
     4221 +         # so we can only fetch the indices from the iter i - 2
     4222 +         # when pipelining is disabled
     4223 +         # prefetch in iter i happens before forward iter i
     4224 +         # so we can get the iter i - 1's changed ids safely.
     4225 +         target_prev_iter = 1
     4226 +         if self.prefetch_pipeline:
     4227 +             target_prev_iter = 2
     4228 +         if not len(self.prefetched_info_list) > (target_prev_iter - 1):
     4229 +             return None
     4230 +         with record_function(
     4231 +             "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
     4232 +         ):
     4233 +             prefetched_info = self.prefetched_info_list.pop(0)
     4234 +             updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
     4235 +                 prefetched_info.linear_unique_cache_indices,
     4236 +                 self.lxu_cache_state,
     4237 +                 self.total_cache_hash_size,
     4238 +                 gather_cache_stats=False,  # not collecting cache stats
     4239 +                 num_uniq_cache_indices=prefetched_info.linear_unique_indices_length,
     4240 +             )
     4241 +             updated_weights = torch.empty(
     4242 +                 [
     4243 +                     prefetched_info.linear_unique_cache_indices.size()[0],
     4244 +                     self.max_D_cache,
     4245 +                 ],
     4246 +                 # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
     4247 +                 dtype=self.lxu_cache_weights.dtype,
     4248 +                 # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
     4249 +                 device=self.lxu_cache_weights.device,
     4250 +             )
     4251 +             torch.ops.fbgemm.masked_index_select(
     4252 +                 updated_weights,
     4253 +                 updated_locations,
     4254 +                 self.lxu_cache_weights,
     4255 +                 prefetched_info.linear_unique_indices_length,
     4256 +             )
     4257 +             # TODO: this statement triggers a sync
     4258 +             # added here to make this diff self-contained
     4259 +             # will remove in later change
     4260 +             cache_hit_mask_index = (
     4261 +                 updated_locations.narrow(
     4262 +                     0, 0, prefetched_info.linear_unique_indices_length.item()
     4263 +                 )
     4264 +                 .not_equal(-1)
     4265 +                 .nonzero()
     4266 +                 .flatten()
     4267 +             )
     4268 +             # stream weights
     4269 +             self._raw_embedding_streamer.stream(
     4270 +                 prefetched_info.linear_unique_indices.index_select(
     4271 +                     dim=0, index=cache_hit_mask_index
     4272 +                 ).to(device=torch.device("cpu")),
     4273 +                 updated_weights.index_select(dim=0, index=cache_hit_mask_index).to(
     4274 +                     device=torch.device("cpu")
     4275 +                 ),
     4276 +                 (
     4277 +                     prefetched_info.hash_zch_identities.index_select(
     4278 +                         dim=0, index=cache_hit_mask_index
     4279 +                     ).to(device=torch.device("cpu"))
     4280 +                     if prefetched_info.hash_zch_identities is not None
     4281 +                     else None
     4282 +                 ),
     4283 +                 (
     4284 +                     prefetched_info.hash_zch_runtime_meta.index_select(
     4285 +                         dim=0, index=cache_hit_mask_index
     4286 +                     ).to(device=torch.device("cpu"))
     4287 +                     if prefetched_info.hash_zch_runtime_meta is not None
     4288 +                     else None
     4289 +                 ),
     4290 +                 prefetched_info.linear_unique_indices_length.to(
     4291 +                     device=torch.device("cpu")
     4292 +                 ),
     4293 +                 False,  # require_tensor_copy
     4294 +                 False,  # blocking_tensor_copy
     4295 +             )
     4296 +
     4297 +     @staticmethod
     4298 +     @torch.jit.ignore
     4299 +     def _get_prefetched_info(
     4300 +         linear_indices: torch.Tensor,
     4301 +         linear_cache_indices_merged: torch.Tensor,
     4302 +         total_cache_hash_size: int,
     4303 +         hash_zch_identities: Optional[torch.Tensor],
     4304 +         hash_zch_runtime_meta: Optional[torch.Tensor],
     4305 +         max_indices_length: int,
     4306 +     ) -> PrefetchedInfo:
     4307 +         (
     4308 +             linear_unique_cache_indices,
     4309 +             linear_unique_cache_indices_length,
     4310 +             linear_unique_cache_indices_cnt,
     4311 +             linear_unique_cache_inverse_indices,
     4312 +         ) = torch.ops.fbgemm.get_unique_indices_with_inverse(
     4313 +             linear_cache_indices_merged,
     4314 +             total_cache_hash_size,
     4315 +             compute_count=True,
     4316 +             compute_inverse_indices=True,
     4317 +         )
     4318 +         # pure cpu op, no need to sync; keeps the indices from exceeding the weights buffer
     4319 +         max_len = min(
     4320 +             max_indices_length,
     4321 +             linear_unique_cache_indices.size(0),
     4322 +         )
     4323 +         if max_len < linear_unique_cache_indices.size(0):
     4324 +             linear_unique_cache_indices_length.clamp_(max=max_len)
     4325 +         # linear_unique_indices is the result after deduplication and sorting
     4326 +         linear_unique_cache_indices = linear_unique_cache_indices.narrow(
     4327 +             0, 0, max_len
     4328 +         )
     4329 +         # Compute cumulative sum as indices for selecting unique elements to
     4330 +         # map hash_zch_identities and hash_zch_runtime_meta to linear_unique_indices
     4331 +         count_cum_sum = torch.ops.fbgemm.asynchronous_complete_cumsum(
     4332 +             linear_unique_cache_indices_cnt
     4333 +         )
     4334 +         # count_cum_sum will be one more element than linear_unique_cache_indices_cnt
     4335 +         count_cum_sum = count_cum_sum.narrow(0, 0, max_len)
     4336 +         # clamp the uninitialized elements to avoid out of bound access
     4337 +         # the uninitialized elements will be sliced out by linear_unique_cache_indices_length
     4338 +         # directly using linear_unique_cache_indices_length requires a sync
     4339 +         count_cum_sum.clamp_(min=0, max=linear_unique_cache_inverse_indices.size(0) - 1)
     4340 +
     4341 +         # Select indices corresponding to first occurrence of each unique element
     4342 +         linear_unique_inverse_indices = (
     4343 +             linear_unique_cache_inverse_indices.index_select(dim=0, index=count_cum_sum)
     4344 +         )
     4345 +         # same as above clamp
     4346 +         linear_unique_inverse_indices.clamp_(min=0, max=linear_indices.size(0) - 1)
     4347 +         linear_unique_indices = linear_indices.index_select(
     4348 +             dim=0, index=linear_unique_inverse_indices
     4349 +         )
     4350 +         if hash_zch_identities is not None:
     4351 +             # Map hash_zch_identities to unique indices
     4352 +             hash_zch_identities = hash_zch_identities.index_select(
     4353 +                 dim=0, index=linear_unique_inverse_indices
     4354 +             )
     4355 +
     4356 +         if hash_zch_runtime_meta is not None:
     4357 +             # Map hash_zch_runtime_meta to unique indices
     4358 +             hash_zch_runtime_meta = hash_zch_runtime_meta.index_select(
     4359 +                 dim=0, index=linear_unique_inverse_indices
     4360 +             )
     4361 +
     4362 +         return PrefetchedInfo(
     4363 +             linear_unique_indices,
     4364 +             linear_unique_cache_indices,
     4365 +             linear_unique_cache_indices_length,
     4366 +             hash_zch_identities,
     4367 +             hash_zch_runtime_meta,
     4368 +         )
     4369 +
     4370 +     @torch.jit.ignore
     4371 +     def _store_prefetched_tensors(
     4372 +         self,
     4373 +         indices: torch.Tensor,
     4374 +         offsets: torch.Tensor,
     4375 +         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata],
     4376 +         linear_cache_indices_merged: torch.Tensor,
     4377 +         final_lxu_cache_locations: torch.Tensor,
     4378 +         hash_zch_identities: Optional[torch.Tensor],
     4379 +         hash_zch_runtime_meta: Optional[torch.Tensor],
     4380 +     ) -> None:
     4381 +         """
     4382 +         NOTE: this needs to be a method with jit.ignore as the identities tensor is conditional.
     4383 +         This function stores the prefetched tensors for the raw embedding streaming.
     4384 +         """
     4385 +         if not self.enable_raw_embedding_streaming:
     4386 +             return
     4387 +
     4388 +         with record_function(
     4389 +             "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
     4390 +         ):
     4391 +             found_in_cache_mask = final_lxu_cache_locations != -1
     4392 +             # only process the indices that are found in the cache
     4393 +             # this will filter out the indices from tables that don't have UVM_CACHE enabled
     4394 +             linear_cache_indices_merged_masked = torch.where(
     4395 +                 found_in_cache_mask,
     4396 +                 linear_cache_indices_merged,
     4397 +                 self.total_cache_hash_size,
     4398 +             )
     4399 +             linearize_indices = torch.ops.fbgemm.linearize_cache_indices(
     4400 +                 self.hash_size_cumsum,
     4401 +                 indices,
     4402 +                 offsets,
     4403 +                 vbe_metadata.B_offsets if vbe_metadata is not None else None,
     4404 +                 vbe_metadata.max_B if vbe_metadata is not None else -1,
     4405 +             )
     4406 +             # -1 indices are ignored in raw_embedding_streamer.
     4407 +             linearize_indices_masked = torch.where(
     4408 +                 found_in_cache_mask,
     4409 +                 linearize_indices,
     4410 +                 -1,
     4411 +             )
     4412 +             # Process hash_zch_identities using helper function
     4413 +             prefetched_info = self._get_prefetched_info(
     4414 +                 linearize_indices_masked,
     4415 +                 linear_cache_indices_merged_masked,
     4416 +                 self.total_cache_hash_size,
     4417 +                 hash_zch_identities,
     4418 +                 hash_zch_runtime_meta,
     4419 +                 self.lxu_cache_weights.size(0),
     4420 +             )
     4421 +
     4422 +             self.prefetched_info_list.append(prefetched_info)
     4423 +
     4424 +     @torch.jit.ignore
     4425 +     def __report_input_params_factory(
     4426 +         self,
     4427 +     ) -> Optional[Callable[..., None]]:
     4428 +         """
     4429 +         This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.
     4430 +
     4431 +         If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that:
     4432 +         - Reports input parameters (TBEDataConfig).
     4433 +         - Writes the output as a JSON file.
     4434 +
     4435 +         If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy function pointer that performs no action.
     4436 +         """
     4437 +         try:
     4438 +             if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
     4439 +                 from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
     4440 +
     4441 +                 reporter = TBEBenchmarkParamsReporter.create()
     4442 +                 return reporter.report_stats
     4443 +         except Exception:
     4444 +             return None
     4445 +
     4446 +         return None
     4447 +
3807 4448   
3808 4449   class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3809 4450       """
@@ -3817,12 +4458,12 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3817 4458       max_D: int
3818 4459       hash_size_cumsum: Tensor
3819 4460       total_hash_size_bits: int
3820      -     embedding_specs:
     4461 +     embedding_specs: list[tuple[int, int]]
3821 4462   
3822 4463       def __init__(
3823 4464           self,
3824      -         embedding_specs:
3825      -         feature_table_map: Optional[
     4465 +         embedding_specs: list[tuple[int, int]],  # tuple of (rows, dims)
     4466 +         feature_table_map: Optional[list[int]] = None,  # [T]
3826 4467           weights_precision: SparseType = SparseType.FP32,
3827 4468           pooling_mode: PoolingMode = PoolingMode.SUM,
3828 4469           use_cpu: bool = False,
@@ -3865,7 +4506,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3865 4506           )
3866 4507   
3867 4508           self.embedding_specs = embedding_specs
3868      -
     4509 +         rows, dims = zip(*embedding_specs)
3869 4510           T_ = len(self.embedding_specs)
3870 4511           assert T_ > 0
3871 4512   
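As the annotation above notes, each embedding_specs entry for the dense module is a (rows, dims) pair; a small illustrative value (the numbers are made up):

embedding_specs = [(1000, 64), (500, 32)]  # two tables: (num_rows, embedding_dim)
rows, dims = zip(*embedding_specs)         # rows == (1000, 500), dims == (64, 32)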
@@ -3935,7 +4576,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3935 4576                   row for (row, _) in embedding_specs[:t]
3936 4577               )
3937 4578   
3938      -         self.weights_physical_offsets:
     4579 +         self.weights_physical_offsets: list[int] = weights_offsets
3939 4580           weights_offsets = [weights_offsets[t] for t in feature_table_map]
3940 4581           self.register_buffer(
3941 4582               "weights_offsets",
@@ -3962,7 +4603,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3962 4603       def _generate_vbe_metadata(
3963 4604           self,
3964 4605           offsets: Tensor,
3965      -         batch_size_per_feature_per_rank: Optional[
     4606 +         batch_size_per_feature_per_rank: Optional[list[list[int]]],
3966 4607       ) -> invokers.lookup_args.VBEMetadata:
3967 4608           # Blocking D2H copy, but only runs at first call
3968 4609           self.feature_dims = self.feature_dims.cpu()
@@ -3980,7 +4621,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3980 4621           offsets: Tensor,
3981 4622           per_sample_weights: Optional[Tensor] = None,
3982 4623           feature_requires_grad: Optional[Tensor] = None,
3983      -         batch_size_per_feature_per_rank: Optional[
     4624 +         batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
3984 4625       ) -> Tensor:
3985 4626           # Generate VBE metadata
3986 4627           vbe_metadata = self._generate_vbe_metadata(
@@ -4019,7 +4660,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
4019 4660           )
4020 4661   
4021 4662       @torch.jit.export
4022      -     def split_embedding_weights(self) ->
     4663 +     def split_embedding_weights(self) -> list[Tensor]:
4023 4664           """
4024 4665           Returns a list of weights, split by table
4025 4666           """
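A hypothetical usage sketch for the retyped accessor; dense_emb stands for an already constructed DenseTableBatchedEmbeddingBagsCodegen instance and is an assumption:

for table_idx, table_weights in enumerate(dense_emb.split_embedding_weights()):
    print(table_idx, tuple(table_weights.shape))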