fbgemm_gpu_genai_nightly-2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -0,0 +1,4600 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+ # pyre-ignore-all-errors[56]
10
+
11
+ import contextlib
12
+ import enum
13
+ import functools
14
+ import logging
15
+ import math
16
+ import os
17
+ import uuid
18
+ from dataclasses import dataclass, field
19
+ from itertools import accumulate
20
+ from math import log2
21
+ from typing import Any, Callable, Optional, Union
22
+
23
+ import torch # usort:skip
24
+ from torch import nn, Tensor # usort:skip
25
+ from torch.autograd.profiler import record_function # usort:skip
26
+
27
+ # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
28
+ import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
29
+
30
+ from fbgemm_gpu.config import FeatureGate, FeatureGateName
31
+ from fbgemm_gpu.runtime_monitor import (
32
+ AsyncSeriesTimer,
33
+ TBEStatsReporter,
34
+ TBEStatsReporterConfig,
35
+ )
36
+ from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
37
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
38
+ BoundsCheckMode,
39
+ CacheAlgorithm,
40
+ CacheState,
41
+ ComputeDevice,
42
+ construct_cache_state,
43
+ EmbeddingLocation,
44
+ get_bounds_check_version_for_platform,
45
+ MAX_PREFETCH_DEPTH,
46
+ MultiPassPrefetchConfig,
47
+ PoolingMode,
48
+ RecordCacheMetrics,
49
+ SplitState,
50
+ )
51
+ from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
52
+ generate_vbe_metadata,
53
+ is_torchdynamo_compiling,
54
+ )
55
+ from fbgemm_gpu.tbe_input_multiplexer import (
56
+ TBEInfo,
57
+ TBEInputInfo,
58
+ TBEInputMultiplexer,
59
+ TBEInputMultiplexerConfig,
60
+ )
61
+
62
+ from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
63
+
64
+ try:
65
+ load_torch_module(
66
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_gpu",
67
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training",
68
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training",
69
+ )
70
+ load_torch_module_bc(
71
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_cpu",
72
+ "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training",
73
+ )
74
+ except Exception:
75
+ pass
76
+
77
+
78
+ DEFAULT_ASSOC = 32 if torch.version.hip is None else 64
79
+ INT8_EMB_ROW_DIM_OFFSET = 8
80
+
81
+
82
+ class DoesNotHavePrefix(Exception):
83
+ pass
84
+
85
+
86
+ class WeightDecayMode(enum.IntEnum):
87
+ NONE = 0
88
+ L2 = 1
89
+ DECOUPLE = 2
90
+ COUNTER = 3
91
+ COWCLIP = 4
92
+ DECOUPLE_GLOBAL = 5
93
+
94
+
95
+ class CounterWeightDecayMode(enum.IntEnum):
96
+ NONE = 0
97
+ L2 = 1
98
+ DECOUPLE = 2
99
+ ADAGRADW = 3
100
+
101
+
102
+ class StepMode(enum.IntEnum):
103
+ NONE = 0
104
+ USE_COUNTER = 1
105
+ USE_ITER = 2
106
+
107
+
108
+ class LearningRateMode(enum.IntEnum):
109
+ EQUAL = -1
110
+ TAIL_ID_LR_INCREASE = 0
111
+ TAIL_ID_LR_DECREASE = 1
112
+ COUNTER_SGD = 2
113
+
114
+
115
+ class GradSumDecay(enum.IntEnum):
116
+ NO_DECAY = -1
117
+ CTR_DECAY = 0
118
+
119
+
120
+ @dataclass(frozen=True)
121
+ class TailIdThreshold:
122
+ val: float = 0
123
+ is_ratio: bool = False
124
+
125
+
126
+ @dataclass(frozen=True)
127
+ class CounterBasedRegularizationDefinition:
128
+ counter_weight_decay_mode: CounterWeightDecayMode = CounterWeightDecayMode.NONE
129
+ counter_halflife: int = -1
130
+ adjustment_iter: int = -1
131
+ adjustment_ub: float = 1.0
132
+ learning_rate_mode: LearningRateMode = LearningRateMode.EQUAL
133
+ grad_sum_decay: GradSumDecay = GradSumDecay.NO_DECAY
134
+ tail_id_threshold: TailIdThreshold = field(default_factory=TailIdThreshold)
135
+ max_counter_update_freq: int = 1000
136
+
137
+
138
+ @dataclass(frozen=True)
139
+ class CowClipDefinition:
140
+ counter_weight_decay_mode: CounterWeightDecayMode = CounterWeightDecayMode.NONE
141
+ counter_halflife: int = -1
142
+ weight_norm_coefficient: float = 0.0
143
+ lower_bound: float = 0.0
144
+
145
+
146
+ @dataclass(frozen=True)
147
+ class GlobalWeightDecayDefinition:
148
+ start_iter: int = 0
149
+ lower_bound: float = 0.0
150
+
151
+
152
+ @dataclass(frozen=True)
153
+ class UserEnabledConfigDefinition:
154
+ """
155
+ This class is used to configure whether certain modes are to be enabled
156
+ """
157
+
158
+ # This is used in Adam to perform rowwise bias correction using `row_counter`
159
+ # More details can be found in D64848802.
160
+ use_rowwise_bias_correction: bool = False
161
+ use_writeback_bwd_prehook: bool = False
162
+
163
+
164
+ @dataclass(frozen=True)
165
+ class EnsembleModeDefinition:
166
+ step_ema: float = 10000
167
+ step_swap: float = 10000
168
+ step_start: float = 0
169
+ step_ema_coef: float = 0.6
170
+ step_mode: StepMode = StepMode.USE_ITER
171
+
172
+
173
+ @dataclass(frozen=True)
174
+ class EmainplaceModeDefinition:
175
+ step_ema: float = 10
176
+ step_start: float = 0
177
+ step_ema_coef: float = 0.6
178
+
179
+
180
+ # Keep in sync with fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
181
+ class UVMCacheStatsIndex(enum.IntEnum):
182
+ num_calls = 0
183
+ num_requested_indices = 1
184
+ num_unique_indices = 2
185
+ num_unique_misses = 3
186
+ num_conflict_unique_misses = 4
187
+ num_conflict_misses = 5
188
+
189
+
190
+ @dataclass
191
+ class RESParams:
192
+ res_server_port: int = 0 # the port of the res server
193
+ res_store_shards: int = 1 # the number of shards to store the raw embeddings
194
+ table_names: list[str] = field(default_factory=list) # table names the TBE holds
195
+ table_offsets: list[int] = field(
196
+ default_factory=list
197
+ ) # table offsets for the global rows the TBE holds
198
+ table_sizes: list[int] = field(
199
+ default_factory=list
200
+ ) # table sizes for the global rows the TBE holds
201
+
202
+
203
+ class PrefetchedInfo:
204
+ """
205
+ Container for prefetched cache information.
206
+
207
+ This class is explicitly defined (not using @dataclass) to be compatible with
208
+ TorchScript's inspect.getsource() requirements.
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ linear_unique_indices: torch.Tensor,
214
+ linear_unique_cache_indices: torch.Tensor,
215
+ linear_unique_indices_length: torch.Tensor,
216
+ hash_zch_identities: Optional[torch.Tensor],
217
+ hash_zch_runtime_meta: Optional[torch.Tensor],
218
+ ) -> None:
219
+ self.linear_unique_indices = linear_unique_indices
220
+ self.linear_unique_cache_indices = linear_unique_cache_indices
221
+ self.linear_unique_indices_length = linear_unique_indices_length
222
+ self.hash_zch_identities = hash_zch_identities
223
+ self.hash_zch_runtime_meta = hash_zch_runtime_meta
224
+
225
+
226
+ def construct_split_state(
227
+ embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]],
228
+ rowwise: bool,
229
+ cacheable: bool,
230
+ precision: SparseType = SparseType.FP32,
231
+ int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
232
+ placement: Optional[EmbeddingLocation] = None,
233
+ ) -> SplitState:
234
+ placements: list[EmbeddingLocation] = []
235
+ offsets: list[int] = []
236
+ dev_size: int = 0
237
+ host_size: int = 0
238
+ uvm_size: int = 0
239
+ for num_embeddings, embedding_dim, location, _ in embedding_specs:
240
+ assert (
241
+ embedding_dim % 4 == 0
242
+ ), f"embedding_dim must be a multiple of 4, but got {embedding_dim}"
243
+ if precision == SparseType.INT8:
244
+ embedding_dim += int8_emb_row_dim_offset
245
+ state_size = num_embeddings * embedding_dim if not rowwise else num_embeddings
246
+ location = placement if placement is not None else location
247
+ if location == EmbeddingLocation.HOST:
248
+ placements.append(EmbeddingLocation.HOST)
249
+ offsets.append(host_size)
250
+ host_size += state_size
251
+ # If table is on device, then optimizer is on device.
252
+ # If table is managed, then if optimizer state is rowwise, optimizer is on device, otherwise optimizer is managed.
253
+ elif location == EmbeddingLocation.DEVICE or rowwise:
254
+ placements.append(EmbeddingLocation.DEVICE)
255
+ offsets.append(dev_size)
256
+ dev_size += state_size
257
+ else:
258
+ if cacheable and location == EmbeddingLocation.MANAGED_CACHING:
259
+ placements.append(EmbeddingLocation.MANAGED_CACHING)
260
+ else:
261
+ placements.append(EmbeddingLocation.MANAGED)
262
+ offsets.append(uvm_size)
263
+ uvm_size += state_size
264
+ assert len(placements) == len(offsets)
265
+ return SplitState(
266
+ dev_size=dev_size,
267
+ host_size=host_size,
268
+ uvm_size=uvm_size,
269
+ placements=placements,
270
+ offsets=offsets,
271
+ )
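Editorial note, not part of the packaged file: a minimal sketch of the placement logic above, assuming this file is fbgemm_gpu/split_table_batched_embeddings_ops_training.py (per the +4600 entry in the listing) and using only names imported at the top of the module. Values are illustrative.

from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    ComputeDevice,
    EmbeddingLocation,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import construct_split_state

specs = [
    # (num_embeddings, embedding_dim, location, compute_device)
    (100, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    (200, 8, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
]
state = construct_split_state(specs, rowwise=False, cacheable=True)
# Per the branches above: the DEVICE table lands in dev_size (100 * 8 elements),
# the MANAGED_CACHING table lands in uvm_size (200 * 8), host_size stays 0, and
# offsets are [0, 0]. With rowwise=True each table would contribute only
# num_embeddings elements and be forced onto the device.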
272
+
273
+
274
+ def apply_split_helper(
275
+ persistent_state_fn: Callable[[str, Tensor], None],
276
+ set_attr_fn: Callable[
277
+ [str, Union[Tensor, list[int], list[EmbeddingLocation]]], None
278
+ ],
279
+ current_device: torch.device,
280
+ use_cpu: bool,
281
+ feature_table_map: list[int],
282
+ split: SplitState,
283
+ prefix: str,
284
+ dtype: type[torch.dtype],
285
+ enforce_hbm: bool = False,
286
+ make_dev_param: bool = False,
287
+ dev_reshape: Optional[tuple[int, ...]] = None,
288
+ uvm_tensors_log: Optional[list[str]] = None,
289
+ uvm_host_mapped: bool = False,
290
+ ) -> None:
291
+ set_attr_fn(f"{prefix}_physical_placements", split.placements)
292
+ set_attr_fn(f"{prefix}_physical_offsets", split.offsets)
293
+
294
+ offsets = [split.offsets[t] for t in feature_table_map]
295
+ placements = [split.placements[t] for t in feature_table_map]
296
+ persistent_state_fn(
297
+ f"{prefix}_offsets",
298
+ torch.tensor(offsets, device=current_device, dtype=torch.int64),
299
+ )
300
+ persistent_state_fn(
301
+ f"{prefix}_placements",
302
+ torch.tensor(placements, device=current_device, dtype=torch.int32),
303
+ )
304
+ if split.dev_size > 0:
305
+ dev_buffer = torch.zeros(
306
+ split.dev_size,
307
+ device=current_device,
308
+ # pyre-fixme[6]
309
+ dtype=dtype,
310
+ )
311
+ dev_buffer = (
312
+ dev_buffer.view(*dev_reshape) if dev_reshape is not None else dev_buffer
313
+ )
314
+ else:
315
+ # pyre-fixme[6]
316
+ dev_buffer = torch.empty(0, device=current_device, dtype=dtype)
317
+ if make_dev_param:
318
+ set_attr_fn(f"{prefix}_dev", nn.Parameter(dev_buffer))
319
+ else:
320
+ persistent_state_fn(f"{prefix}_dev", dev_buffer)
321
+ if split.host_size > 0:
322
+ if dtype == torch.uint8:
323
+ persistent_state_fn(
324
+ f"{prefix}_host",
325
+ torch.zeros(
326
+ split.host_size,
327
+ device=current_device,
328
+ # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]` for
329
+ # 3rd param but got `Type[Type[torch._dtype]]`.
330
+ dtype=dtype,
331
+ ),
332
+ )
333
+ else:
334
+ set_attr_fn(
335
+ f"{prefix}_host",
336
+ nn.Parameter(
337
+ torch.zeros(
338
+ split.host_size,
339
+ device=current_device,
340
+ # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]`
341
+ # for 3rd param but got `Type[Type[torch._dtype]]`.
342
+ dtype=dtype,
343
+ )
344
+ ),
345
+ )
346
+ if uvm_tensors_log is not None:
347
+ uvm_tensors_log.append(f"{prefix}_host")
348
+ else:
349
+ persistent_state_fn(
350
+ f"{prefix}_host",
351
+ # pyre-fixme[6]: For 3rd param expected `dtype` but got `Type[dtype]`.
352
+ torch.empty(0, device=current_device, dtype=dtype),
353
+ )
354
+ if split.uvm_size > 0:
355
+ assert not use_cpu
356
+ if enforce_hbm:
357
+ logging.info("Enforce hbm for the cache location")
358
+ persistent_state_fn(
359
+ f"{prefix}_uvm",
360
+ torch.zeros(
361
+ split.uvm_size,
362
+ device=current_device,
363
+ # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]` for
364
+ # 3rd param but got `Type[Type[torch._dtype]]`.
365
+ dtype=dtype,
366
+ ),
367
+ )
368
+ else:
369
+ persistent_state_fn(
370
+ f"{prefix}_uvm",
371
+ torch.zeros(
372
+ split.uvm_size,
373
+ device=current_device,
374
+ out=torch.ops.fbgemm.new_unified_tensor(
375
+ # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]`
376
+ # for 3rd param but got `Type[Type[torch._dtype]]`.
377
+ torch.zeros(1, device=current_device, dtype=dtype),
378
+ [split.uvm_size],
379
+ is_host_mapped=uvm_host_mapped,
380
+ ),
381
+ ),
382
+ )
383
+ if uvm_tensors_log is not None:
384
+ uvm_tensors_log.append(f"{prefix}_uvm")
385
+ else:
386
+ persistent_state_fn(
387
+ f"{prefix}_uvm",
388
+ # pyre-fixme[6]: For 3rd param expected `dtype` but got `Type[dtype]`.
389
+ torch.empty(0, device=current_device, dtype=dtype),
390
+ )
391
+
392
+
393
+ def get_available_compute_device() -> ComputeDevice:
394
+ if torch.cuda.is_available():
395
+ return ComputeDevice.CUDA
396
+ elif torch.mtia.is_available():
397
+ return ComputeDevice.MTIA
398
+ else:
399
+ return ComputeDevice.CPU
400
+
401
+
402
+ # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
403
+ # pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
404
+ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
405
+ """
406
+ Table Batched Embedding (TBE) operator. Looks up one or more embedding
407
+ tables. The module is intended for training use. The backward operator is
408
+ fused with the optimizer; thus, the embedding tables are updated during
409
+ backward.
410
+
411
+ Args:
412
+ embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]):
413
+ A list of embedding specifications. Each spec describes a
414
+ specification of a physical embedding table. Each one is a tuple of
415
+ number of embedding rows, embedding dimension (must be a multiple of
416
+ 4), table placement (`EmbeddingLocation`), and compute device
417
+ (`ComputeDevice`).
418
+
419
+ Available `EmbeddingLocation` options are
420
+
421
+ (1) `DEVICE` = placing an embedding table in the GPU global memory
422
+ (HBM)
423
+
424
+ (2) `MANAGED` = placing an embedding in the unified virtual memory
425
+ (accessible from both GPU and CPU)
426
+
427
+ (3) `MANAGED_CACHING` = placing an embedding table in the unified
428
+ virtual memory and using the GPU global memory (HBM) as a cache
429
+
430
+ (4) `HOST` = placing an embedding table in the CPU memory (DRAM)
431
+
432
+ (5) `MTIA` = placing an embedding table in the MTIA memory
433
+
434
+ Available `ComputeDevice` options are
435
+
436
+ (1) `CPU` = performing table lookup on CPU
437
+
438
+ (2) `CUDA` = performing table lookup on GPU
439
+
440
+ (3) `MTIA` = performing table lookup on MTIA
441
+
442
+ feature_table_map (Optional[List[int]] = None): An optional list that
443
+ specifies feature-table mapping. feature_table_map[i] indicates the
444
+ physical embedding table that feature i maps to.
445
+
446
+ cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU): The cache
447
+ algorithm (used when `EmbeddingLocation` is set to
448
+ `MANAGED_CACHING`). Options are
449
+
450
+ (1) `LRU` = least recently used
451
+
452
+ (2) `LFU` = least frequently used
453
+
454
+ cache_load_factor (float = 0.2): A factor used for determining the
455
+ cache capacity when `EmbeddingLocation.MANAGED_CACHING` is used.
456
+ The cache capacity is `cache_load_factor` * the total number of
457
+ rows in all embedding tables
458
+
459
+ cache_sets (int = 0): The number of cache sets (used when
460
+ `EmbeddingLocation` is set to `MANAGED_CACHING`)
461
+
462
+ cache_reserved_memory (float = 0.0): The amount of memory reserved in
463
+ HBM for non-cache purpose (used when `EmbeddingLocation` is set to
464
+ `MANAGED_CACHING`).
465
+
466
+ cache_precision (SparseType = SparseType.FP32): The data type of the
467
+ cache (used when `EmbeddingLocation` is set to `MANAGED_CACHING`).
468
+ Options are `SparseType.FP32` and `SparseType.FP16`
469
+
470
+ weights_precision (SparseType = SparseType.FP32): The data type of
471
+ embedding tables (also known as weights). Options are
472
+ `SparseType.FP32` and `SparseType.FP16`
473
+
474
+ output_dtype (SparseType = SparseType.FP32): The data type of an output
475
+ tensor. Options are `SparseType.FP32` and `SparseType.FP16`
476
+
477
+ enforce_hbm (bool = False): If True, place all weights/momentums in HBM
478
+ when using `EmbeddingLocation.MANAGED_CACHING`
479
+
480
+ optimizer (OptimType = OptimType.EXACT_SGD): An optimizer to use for
481
+ embedding table update in the backward pass. Available `OptimType`
482
+ options are
483
+
484
+ (1) `ADAM` = Adam
485
+
486
+ (2) `EXACT_ADAGRAD` = Adagrad
487
+
488
+ (3) `EXACT_ROWWISE_ADAGRAD` = Rowwise-Adagrad
489
+
490
+ (4) `EXACT_SGD` = SGD
491
+
492
+ (5) `LAMB` = Lamb
493
+
494
+ (6) `LARS_SGD` = LARS-SGD
495
+
496
+ (7) `PARTIAL_ROWWISE_ADAM` = Partial rowwise-Adam
497
+
498
+ (8) `PARTIAL_ROWWISE_LAMB` = Partial rowwise-Lamb
499
+
500
+ (9) `ENSEMBLE_ROWWISE_ADAGRAD` = Ensemble rowwise-Adagrad
501
+
502
+ (10) `EMAINPLACE_ROWWISE_ADAGRAD` = EMA in-place rowwise-Adagrad
503
+
504
+ (11) `NONE` = Not applying an optimizer update in the backward pass
505
+ and outputting a sparse weight gradient
506
+
507
+ record_cache_metrics (Optional[RecordCacheMetrics] = None): Record
508
+ the number of cache hits, requests, etc. if
509
+ `RecordCacheMetrics.record_cache_miss_counter` is True, and record
510
+ similar metrics per table if
511
+ `RecordCacheMetrics.record_tablewise_cache_miss` is True
512
+
513
+ gather_uvm_cache_stats (Optional[bool] = False): If True, collect the
514
+ cache statistics when `EmbeddingLocation` is set to
515
+ `MANAGED_CACHING`
516
+
517
+ stochastic_rounding (bool = True): If True, apply stochastic rounding
518
+ for weight type that is not `SparseType.FP32`
519
+
520
+ gradient_clipping (bool = False): If True, apply gradient clipping
521
+
522
+ max_gradient (float = 1.0): The value for gradient clipping
523
+
524
+ max_norm (float = 0.0): The max norm value
525
+
526
+ learning_rate (float = 0.01): The learning rate
527
+
528
+ eps (float = 1.0e-8): The epsilon value used by Adagrad, LAMB, and
529
+ Adam. Note that this default differs from the torch.optim.Adagrad
530
+ default of 1e-10
531
+
532
+ momentum (float = 0.9): Momentum used by LARS-SGD
533
+
534
+ weight_decay (float = 0.0): Weight decay used by LARS-SGD, LAMB, ADAM,
535
+ and rowwise-Adagrad.
536
+
537
+ (1) EXACT_ADAGRAD, SGD, EXACT_SGD do not support weight decay
538
+
539
+ (2) LAMB, ADAM, PARTIAL_ROWWISE_ADAM, PARTIAL_ROWWISE_LAMB, LARS_SGD
540
+ support decoupled weight decay
541
+
542
+ (3) EXACT_ROWWISE_ADAGRAD supports both L2 and decoupled weight decay
543
+ (via weight_decay_mode)
544
+
545
+ weight_decay_mode (WeightDecayMode = WeightDecayMode.NONE): Weight decay
546
+ mode. Options are `WeightDecayMode.NONE`, `WeightDecayMode.L2`,
547
+ and `WeightDecayMode.DECOUPLE`
548
+
549
+ eta (float = 0.001): The eta value used by LARS-SGD
550
+
551
+ beta1 (float = 0.9): The beta1 value used by LAMB and ADAM
552
+
553
+ beta2 (float = 0.999): The beta2 value used by LAMB and ADAM
554
+
555
+ ensemble_mode (Optional[EnsembleModeDefinition] = None):
556
+ Used by Ensemble Rowwise Adagrad
557
+
558
+ emainplace_mode (Optional[EmainplaceModeDefinition] = None):
559
+ Used by EMA in-place Rowwise Adagrad
560
+
561
+ counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None):
562
+ Used by Rowwise Adagrad
563
+
564
+ cowclip_regularization (Optional[CowClipDefinition] = None): Used by
565
+ Rowwise Adagrad
566
+
567
+ pooling_mode (PoolingMode = PoolingMode.SUM): Pooling mode. Available
568
+ `PoolingMode` options are
569
+
570
+ (1) `SUM` = Sum pooling
571
+
572
+ (2) `MEAN` = Mean pooling
573
+
574
+ (3) `NONE` = No pooling (sequence embedding)
575
+
576
+ device (Optional[Union[str, int, torch.device]] = None): The current
577
+ device to place tensors on
578
+
579
+ bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING): Input
580
+ checking mode. Available `BoundsCheckMode` options are
581
+
582
+ (1) `NONE` = skip bounds check
583
+
584
+ (2) `FATAL` = throw an error when encountering an invalid
585
+ index/offset
586
+
587
+ (3) `WARNING` = print a warning message when encountering an
588
+ invalid index/offset and fix it (setting an invalid index to
589
+ zero and adjusting an invalid offset to be within the bound)
590
+
591
+ (4) `IGNORE` = silently fix an invalid index/offset (setting an
592
+ invalid index to zero and adjusting an invalid offset to be
593
+ within the bound)
594
+
595
+ uvm_non_rowwise_momentum (bool = False): If True, place non-rowwise
596
+ momentum on the unified virtual memory
597
+
598
+ use_experimental_tbe (bool = False): If True, use an optimized TBE
599
+ implementation (TBE v2). Note that this is supported only on NVIDIA
600
+ GPUs.
601
+
602
+ prefetch_pipeline (bool = False): If True, enable cache prefetch
603
+ pipeline when using `EmbeddingLocation.MANAGED_CACHING`. Currently
604
+ only supports the LRU cache policy. If a separate stream is used
605
+ for prefetch, the optional `forward_stream` arg of prefetch
606
+ function must be set.
607
+
608
+ stats_reporter_config (Optional[TBEStatsReporterConfig] = None):
609
+ A config for TBE stats reporter
610
+
611
+ table_names (Optional[List[str]] = None): A list of embedding table
612
+ names in this TBE
613
+
614
+ optimizer_state_dtypes (Optional[Dict[str, SparseType]] = None): A
615
+ dict of optimizer state data types. Keys are the optimizer state names
616
+ and values are their corresponding types
617
+
618
+ multipass_prefetch_config (Optional[MultiPassPrefetchConfig] = None):
619
+ A config for multipass cache prefetching (when
620
+ `EmbeddingLocation.MANAGED_CACHING` is used)
621
+
622
+ global_weight_decay (Optional[GlobalWeightDecayDefinition] = None):
623
+ A config for global weight decay
624
+
625
+ uvm_host_mapped (bool = False): If True, allocate every UVM tensor
626
+ using `malloc` + `cudaHostRegister`. Otherwise use
627
+ `cudaMallocManaged`
628
+
629
+ extra_optimizer_config (Optional[UserEnabledConfigDefinition] = None):
630
+ An extra config to enable certain modes for optimizer. These modes
631
+ are not enabled by default.
632
+ - `use_rowwise_bias_correction` is used in Adam to enable rowwise
633
+ bias correction computation
634
+
635
+ embedding_table_index_type (torch.dtype = torch.int64): The data type of
636
+ the embedding table index tensor. Options are `torch.int32` and
637
+ `torch.int64`
638
+
639
+ embedding_table_offset_type (torch.dtype = torch.int64): The data type of
640
+ the embedding table offset tensor. Options are `torch.int32` and
641
+ `torch.int64`
642
+
643
+ embedding_shard_info (Optional[List[Tuple[int, int, int, int]]] = None): the
644
+ information about shard position and pre-sharded table size. If not set,
645
+ the table is not sharded.
646
+ (preshard_table_height, preshard_table_dim, height_offset, dim_offset)
647
+ """
648
+
649
+ embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]]
650
+ optimizer_args: invokers.lookup_args.OptimizerArgs
651
+ lxu_cache_locations_list: list[Tensor]
652
+ lxu_cache_locations_empty: Tensor
653
+ timesteps_prefetched: list[int]
654
+ prefetched_info_list: list[PrefetchedInfo]
655
+ record_cache_metrics: RecordCacheMetrics
656
+ # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
657
+ uvm_cache_stats: torch.Tensor
658
+ # pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
659
+ local_uvm_cache_stats: torch.Tensor
660
+ uuid: str
661
+ # pyre-fixme[13]: Attribute `last_uvm_cache_print_state` is never initialized.
662
+ last_uvm_cache_print_state: torch.Tensor
663
+ _vbe_B_offsets: Optional[torch.Tensor]
664
+ _vbe_max_B: int
665
+
666
+ def __init__( # noqa C901
667
+ self,
668
+ embedding_specs: list[
669
+ tuple[int, int, EmbeddingLocation, ComputeDevice]
670
+ ], # tuple of (rows, dims, placements, compute_devices)
671
+ feature_table_map: Optional[list[int]] = None, # [T]
672
+ cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
673
+ cache_load_factor: float = 0.2,
674
+ cache_sets: int = 0,
675
+ cache_reserved_memory: float = 0.0,
676
+ cache_precision: Optional[SparseType] = None,
677
+ weights_precision: SparseType = SparseType.FP32,
678
+ output_dtype: SparseType = SparseType.FP32,
679
+ enforce_hbm: bool = False,
680
+ optimizer: OptimType = OptimType.EXACT_SGD,
681
+ record_cache_metrics: Optional[RecordCacheMetrics] = None,
682
+ gather_uvm_cache_stats: Optional[bool] = False,
683
+ # General Optimizer args
684
+ stochastic_rounding: bool = True,
685
+ gradient_clipping: bool = False,
686
+ max_gradient: float = 1.0,
687
+ max_norm: float = 0.0,
688
+ learning_rate: float = 0.01,
689
+ eps: float = 1.0e-8,
690
+ momentum: float = 0.9,
691
+ weight_decay: float = 0.0,
692
+ weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE,
693
+ eta: float = 0.001,
694
+ beta1: float = 0.9,
695
+ beta2: float = 0.999,
696
+ ensemble_mode: Optional[EnsembleModeDefinition] = None,
697
+ emainplace_mode: Optional[EmainplaceModeDefinition] = None,
698
+ counter_based_regularization: Optional[
699
+ CounterBasedRegularizationDefinition
700
+ ] = None,
701
+ cowclip_regularization: Optional[CowClipDefinition] = None,
702
+ pooling_mode: PoolingMode = PoolingMode.SUM,
703
+ device: Optional[Union[str, int, torch.device]] = None,
704
+ bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
705
+ uvm_non_rowwise_momentum: bool = False,
706
+ use_experimental_tbe: bool = False,
707
+ prefetch_pipeline: bool = False,
708
+ stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
709
+ table_names: Optional[list[str]] = None,
710
+ optimizer_state_dtypes: Optional[dict[str, SparseType]] = None,
711
+ multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
712
+ global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
713
+ uvm_host_mapped: bool = False,
714
+ extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
715
+ tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
716
+ embedding_table_index_type: torch.dtype = torch.int64,
717
+ embedding_table_offset_type: torch.dtype = torch.int64,
718
+ embedding_shard_info: Optional[list[tuple[int, int, int, int]]] = None,
719
+ enable_raw_embedding_streaming: bool = False,
720
+ res_params: Optional[RESParams] = None,
721
+ ) -> None:
722
+ super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
723
+ self.uuid = str(uuid.uuid4())
724
+ self.log("SplitTableBatchedEmbeddingBagsCodegen API: V2")
725
+ self.log(f"SplitTableBatchedEmbeddingBagsCodegen Arguments: {locals()}")
726
+ self.log(
727
+ f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
728
+ )
729
+
730
+ self.table_names: Optional[list[str]] = table_names
731
+ self.logging_table_name: str = self.get_table_name_for_logging(table_names)
732
+ self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
733
+ self.pooling_mode = pooling_mode
734
+ self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
735
+
736
+ # If the environment variable is set, it overrides the bounds_check_mode argument.
737
+ self.bounds_check_version: int = (
738
+ 2
739
+ if self._feature_is_enabled(FeatureGateName.BOUNDS_CHECK_INDICES_V2)
740
+ else get_bounds_check_version_for_platform()
741
+ )
742
+ self.bounds_check_mode_int: int = int(
743
+ os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
744
+ )
745
+ # Check if bounds_check_indices_v2 is enabled via the feature gate
746
+ bounds_check_mode = BoundsCheckMode(self.bounds_check_mode_int)
747
+ if bounds_check_mode.name.startswith("V2_"):
748
+ self.bounds_check_version = 2
749
+ if bounds_check_mode == BoundsCheckMode.V2_IGNORE:
750
+ bounds_check_mode = BoundsCheckMode.IGNORE
751
+ elif bounds_check_mode == BoundsCheckMode.V2_WARNING:
752
+ bounds_check_mode = BoundsCheckMode.WARNING
753
+ elif bounds_check_mode == BoundsCheckMode.V2_FATAL:
754
+ bounds_check_mode = BoundsCheckMode.FATAL
755
+
756
+ if bounds_check_mode not in (
757
+ BoundsCheckMode.IGNORE,
758
+ BoundsCheckMode.WARNING,
759
+ BoundsCheckMode.FATAL,
760
+ BoundsCheckMode.NONE,
761
+ ):
762
+ raise NotImplementedError(
763
+ f"SplitTableBatchedEmbeddingBagsCodegen bounds_check_mode={bounds_check_mode} is not supported"
764
+ )
765
+
766
+ self.bounds_check_mode_int = bounds_check_mode.value
767
+
768
+ self.log(
769
+ f"SplitTableBatchedEmbeddingBagsCodegen bounds_check_mode={bounds_check_mode} bounds_check_version={self.bounds_check_version}"
770
+ )
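Editorial note, not part of the packaged file: as the environment-variable handling above shows, the bounds-check behaviour can be overridden process-wide before the module is constructed. A hedged sketch, assuming BoundsCheckMode comes from split_table_batched_embeddings_ops_common as in this module's imports:

import os
from fbgemm_gpu.split_table_batched_embeddings_ops_common import BoundsCheckMode

# Request the V2 bounds-check kernel with WARNING semantics; the constructor reads
# FBGEMM_TBE_BOUNDS_CHECK_MODE, maps the V2_* value back to WARNING, and sets
# bounds_check_version = 2.
os.environ["FBGEMM_TBE_BOUNDS_CHECK_MODE"] = str(BoundsCheckMode.V2_WARNING.value)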
771
+
772
+ self.weights_precision = weights_precision
773
+
774
+ if torch.cuda.is_available() and torch.version.hip:
775
+ # NOTE: It was discovered that FP16 cache precision caused a 500x
776
+ # slowdown in performance of split_embedding_nobag_backward_codegen_rowwise_adagrad_unweighted_kernel_warp_per_row_1
777
+ # kernel on ROCm, so to work around this, we fix cache precision to
778
+ # be FP32 always for the ROCm environment case.
779
+ #
780
+ # See:
781
+ # https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/
782
+ cache_precision = SparseType.FP32
783
+ self.log("Override cache_precision=SparseType.FP32 on ROCm")
784
+ else:
785
+ # NOTE: The changes from D65865527 are retained here until we can
786
+ # test that the hack also works for non-ROCm environments.
787
+ cache_precision = (
788
+ weights_precision if cache_precision is None else cache_precision
789
+ )
790
+
791
+ self.output_dtype: int = output_dtype.as_int()
792
+ assert (
793
+ not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU
794
+ ), "Only LRU cache policy supports prefetch_pipeline."
795
+ self.prefetch_pipeline: bool = prefetch_pipeline
796
+ self.lock_cache_line: bool = self.prefetch_pipeline
797
+ self.use_uniq_cache_locations_bwd: bool = self.prefetch_pipeline
798
+ self.multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = (
799
+ multipass_prefetch_config
800
+ )
801
+
802
+ if record_cache_metrics is not None:
803
+ self.record_cache_metrics = record_cache_metrics
804
+ else:
805
+ self.record_cache_metrics = RecordCacheMetrics(False, False)
806
+
807
+ if multipass_prefetch_config:
808
+ assert (
809
+ prefetch_pipeline
810
+ ), "Multipass prefetch makes no sense in non-prefetch mode."
811
+ assert (
812
+ cache_algorithm == CacheAlgorithm.LRU
813
+ ), "Multipass prefetch is only supported in LRU cache."
814
+ assert (
815
+ multipass_prefetch_config.num_passes > 0
816
+ ), f"num_passes must be positive, get {multipass_prefetch_config.num_passes}"
817
+ assert (
818
+ multipass_prefetch_config.min_splitable_pass_size > 0
819
+ ), f"min_splitable_pass_size must be positive, get {multipass_prefetch_config.min_splitable_pass_size}"
820
+ assert (
821
+ not self.record_cache_metrics.record_cache_miss_counter
822
+ and not self.record_cache_metrics.record_tablewise_cache_miss
823
+ ), "Unique cache miss counters are not accurate in multipass prefetch and therefore not supported"
824
+
825
+ self.embedding_specs = embedding_specs
826
+ (rows, dims, locations, compute_devices) = zip(*embedding_specs)
827
+ T_ = len(self.embedding_specs)
828
+ self.dims: list[int] = dims
829
+ assert T_ > 0
830
+ # mixed D is not supported by the no-bag kernels
831
+ mixed_D = False
832
+ D = self.dims[0]
833
+ for d in self.dims:
834
+ if d != D:
835
+ mixed_D = True
836
+ break
837
+ if mixed_D:
838
+ assert (
839
+ self.pooling_mode != PoolingMode.NONE
840
+ ), "Mixed dimension tables only supported for pooling tables."
841
+ self.mixed_D: bool = mixed_D
842
+ assert all(
843
+ cd == compute_devices[0] for cd in compute_devices
844
+ ), "Heterogenous compute_devices are NOT supported!"
845
+ # Split TBE has different function schemas for CUDA and CPU.
846
+ # For MTIA device type, it uses the CPU one.
847
+ self.use_cpu: bool = (
848
+ compute_devices[0] == ComputeDevice.CPU
849
+ or compute_devices[0] == ComputeDevice.MTIA
850
+ )
851
+
852
+ assert not self.use_cpu or all(
853
+ loc == EmbeddingLocation.HOST for loc in locations
854
+ ), "ComputeDevice.CPU is only for EmbeddingLocation.HOST!"
855
+ assert self.use_cpu or all(
856
+ loc != EmbeddingLocation.HOST for loc in locations
857
+ ), "EmbeddingLocation.HOST doesn't work for CUDA device!"
858
+ if self.use_cpu or self.pooling_mode == PoolingMode.NONE:
859
+ assert output_dtype in [
860
+ SparseType.FP32,
861
+ SparseType.FP16,
862
+ SparseType.BF16,
863
+ ], "Fused pooled embedding quantization only supported for cuda."
864
+
865
+ if optimizer == OptimType.NONE:
866
+ assert all(
867
+ loc == EmbeddingLocation.DEVICE for loc in locations
868
+ ), "OptimType.NONE supports only EmbeddingLocation.DEVICE"
869
+ assert all(
870
+ cd == ComputeDevice.CUDA for cd in compute_devices
871
+ ), "OptimType.NONE supports only ComputeDevice.CUDA"
872
+ assert (
873
+ not mixed_D
874
+ ), "OptimType.NONE does not support mixed embedding dimension"
875
+
876
+ if device is None:
877
+ self.current_device: torch.device = (
878
+ torch.device("cpu")
879
+ if self.use_cpu
880
+ else torch.device(torch.cuda.current_device())
881
+ )
882
+ elif isinstance(device, torch.device):
883
+ self.current_device = device
884
+ else:
885
+ self.current_device = torch.device(device)
886
+
887
+ # add placeholder requires_grad param tensor to enable autograd with int8 weights
888
+ self.placeholder_autograd_tensor = nn.Parameter(
889
+ torch.zeros(0, device=self.current_device, dtype=torch.float)
890
+ )
891
+
892
+ self.gather_uvm_cache_stats = gather_uvm_cache_stats
893
+ # Define the size of uvm cache stats as class variable
894
+ # to make it work with torch jit script.
895
+ self.uvm_cache_stats_size = 6
896
+ # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
897
+ # 4: N_conflict_unique_misses, 5: N_conflict_misses
898
+
899
+ # Reporter to collect runtime performance stats bottom-up. Reporter may
900
+ # do aggregation across TBEs and publish results per training batch.
901
+ # Example of stats include UVM cache hit rate, table I/O size, etc.
902
+ self.stats_reporter: Optional[TBEStatsReporter] = (
903
+ stats_reporter_config.create_reporter() if stats_reporter_config else None
904
+ )
905
+ self._uvm_tensors_log: list[str] = []
906
+
907
+ self.bwd_wait_prefetch_timer: Optional[AsyncSeriesTimer] = None
908
+ self.prefetch_duration_timer: Optional[AsyncSeriesTimer] = None
909
+ if self.stats_reporter:
910
+ # When stats_reporter is present, we set up async series timer to
911
+ # measure the GPU time per tracked event accordingly. Each of them
912
+ # is attached to custom callback report function to report collected
913
+ # duration with the corresponding event name.
914
+ self.bwd_wait_prefetch_timer = AsyncSeriesTimer(
915
+ functools.partial(
916
+ SplitTableBatchedEmbeddingBagsCodegen._report_duration,
917
+ self,
918
+ event_name="bwd_wait_for_prefetch",
919
+ )
920
+ )
921
+
922
+ self.prefetch_duration_timer = AsyncSeriesTimer(
923
+ functools.partial(
924
+ SplitTableBatchedEmbeddingBagsCodegen._report_duration,
925
+ self,
926
+ event_name="total_prefetch_duration",
927
+ )
928
+ )
929
+
930
+ self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET
931
+
932
+ self.feature_table_map: list[int] = (
933
+ feature_table_map if feature_table_map is not None else list(range(T_))
934
+ )
935
+
936
+ if embedding_shard_info:
937
+ (full_table_heights, full_table_dims, row_offset, col_offset) = zip(
938
+ *embedding_shard_info
939
+ )
940
+ else:
941
+ # Just assume the table is unsharded
942
+ full_table_heights = rows
943
+ full_table_dims = dims
944
+ row_offset = [0] * len(rows)
945
+ col_offset = [0] * len(rows)
946
+ self.tbe_input_multiplexer: Optional[TBEInputMultiplexer] = (
947
+ tbe_input_multiplexer_config.create_tbe_input_multiplexer(
948
+ tbe_info=TBEInfo(
949
+ table_names=(
950
+ table_names
951
+ if table_names
952
+ else [f"table-{i}" for i in range(len(embedding_specs))]
953
+ ),
954
+ table_heights=rows,
955
+ tbe_uuid=self.uuid,
956
+ feature_table_map=self.feature_table_map,
957
+ table_dims=dims,
958
+ full_table_heights=full_table_heights,
959
+ full_table_dims=full_table_dims,
960
+ row_offset=row_offset,
961
+ col_offset=col_offset,
962
+ )
963
+ )
964
+ if tbe_input_multiplexer_config is not None
965
+ else None
966
+ )
967
+ T = len(self.feature_table_map)
968
+ assert T_ <= T
969
+ table_has_feature = [False] * T_
970
+ for t in self.feature_table_map:
971
+ table_has_feature[t] = True
972
+ assert all(table_has_feature), "Each table must have at least one feature!"
973
+
974
+ feature_dims = [dims[t] for t in self.feature_table_map]
975
+ D_offsets = [0] + list(accumulate(feature_dims))
976
+ self.total_D: int = D_offsets[-1]
977
+ self.max_D: int = max(dims)
978
+ cached_dims = [
979
+ embedding_spec[1]
980
+ for embedding_spec in embedding_specs
981
+ if embedding_spec[2] == EmbeddingLocation.MANAGED_CACHING
982
+ ]
983
+ self.max_D_cache: int = max(cached_dims) if len(cached_dims) > 0 else 0
984
+
985
+ self.register_buffer(
986
+ "D_offsets",
987
+ torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
988
+ )
989
+ hash_size_cumsum = [0] + list(accumulate(rows))
990
+ self.total_hash_size: int = int(hash_size_cumsum[-1])
991
+ if self.total_hash_size == 0:
992
+ self.total_hash_size_bits: int = 0
993
+ else:
994
+ self.total_hash_size_bits: int = int(log2(float(self.total_hash_size)) + 1)
995
+ # The last element is to easily access # of rows of each table by
996
+ # hash_size_cumsum[t + 1] - hash_size_cumsum[t]
997
+ hash_size_cumsum = [hash_size_cumsum[t] for t in self.feature_table_map] + [
998
+ self.total_hash_size
999
+ ]
1000
+ self.register_buffer(
1001
+ "hash_size_cumsum",
1002
+ torch.tensor(
1003
+ hash_size_cumsum, device=self.current_device, dtype=torch.int64
1004
+ ),
1005
+ )
1006
+
1007
+ self.register_buffer(
1008
+ "rows_per_table",
1009
+ torch.tensor(
1010
+ [rows[t] for t in self.feature_table_map],
1011
+ device=self.current_device,
1012
+ dtype=torch.int64,
1013
+ ),
1014
+ )
1015
+ self.register_buffer(
1016
+ "bounds_check_warning",
1017
+ torch.tensor([0], device=self.current_device, dtype=torch.int64),
1018
+ )
1019
+ # Required for VBE
1020
+ self.register_buffer(
1021
+ "feature_dims",
1022
+ torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
1023
+ )
1024
+ (_info_B_num_bits, _info_B_mask) = torch.ops.fbgemm.get_infos_metadata(
1025
+ self.D_offsets, # unused tensor
1026
+ 1, # max_B
1027
+ T, # T
1028
+ )
1029
+ self.info_B_num_bits: int = _info_B_num_bits
1030
+ self.info_B_mask: int = _info_B_mask
1031
+
1032
+ # A flag for indicating whether all embedding tables are placed in the
1033
+ # same locations
1034
+ self.use_homogeneous_placements: bool = all(
1035
+ loc == locations[0] for loc in locations
1036
+ )
1037
+
1038
+ self.uvm_host_mapped = uvm_host_mapped
1039
+
1040
+ weight_split = construct_split_state(
1041
+ embedding_specs,
1042
+ rowwise=False,
1043
+ cacheable=True,
1044
+ precision=weights_precision,
1045
+ )
1046
+ table_embedding_dtype = weights_precision.as_dtype()
1047
+
1048
+ self._apply_split(
1049
+ weight_split,
1050
+ prefix="weights",
1051
+ # pyre-fixme[6]: For 3rd param expected `Type[Type[_dtype]]` but got
1052
+ # `Type[_dtype]`.
1053
+ dtype=table_embedding_dtype,
1054
+ enforce_hbm=enforce_hbm,
1055
+ make_dev_param=optimizer == OptimType.NONE,
1056
+ dev_reshape=(-1, self.max_D) if optimizer == OptimType.NONE else None,
1057
+ uvm_host_mapped=self.uvm_host_mapped,
1058
+ )
1059
+
1060
+ assert optimizer not in (
1061
+ OptimType.SGD,
1062
+ OptimType.ROWWISE_ADAGRAD,
1063
+ ), f"Optimizer {optimizer} is deprecated in the CPU + GPU modes."
1064
+
1065
+ if self.use_cpu:
1066
+ # Construct optimizer states
1067
+ assert optimizer in (
1068
+ OptimType.EXACT_ADAGRAD,
1069
+ OptimType.EXACT_ROWWISE_ADAGRAD,
1070
+ OptimType.EXACT_SGD,
1071
+ OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
1072
+ ), f"Optimizer {optimizer} is not supported in CPU mode."
1073
+ else:
1074
+ assert optimizer in (
1075
+ OptimType.ADAM,
1076
+ OptimType.EXACT_ADAGRAD,
1077
+ OptimType.EXACT_ROWWISE_ADAGRAD,
1078
+ OptimType.EXACT_SGD,
1079
+ OptimType.LAMB,
1080
+ OptimType.LARS_SGD,
1081
+ OptimType.PARTIAL_ROWWISE_ADAM,
1082
+ OptimType.PARTIAL_ROWWISE_LAMB,
1083
+ OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
1084
+ OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
1085
+ OptimType.NONE,
1086
+ ), f"Optimizer {optimizer} is not supported."
1087
+
1088
+ self.stochastic_rounding = stochastic_rounding
1089
+ self.optimizer = optimizer
1090
+
1091
+ self.weight_decay_mode = weight_decay_mode
1092
+ if (weight_decay_mode == WeightDecayMode.COUNTER) != (
1093
+ counter_based_regularization is not None
1094
+ ):
1095
+ raise AssertionError(
1096
+ "Need to set weight_decay_mode=WeightDecayMode.COUNTER together with valid counter_based_regularization"
1097
+ )
1098
+ if (weight_decay_mode == WeightDecayMode.COWCLIP) != (
1099
+ cowclip_regularization is not None
1100
+ ):
1101
+ raise AssertionError(
1102
+ "Need to set weight_decay_mode=WeightDecayMode.COWCLIP together with valid cowclip_regularization"
1103
+ )
1104
+
1105
+ self._used_rowwise_adagrad_with_counter: bool = (
1106
+ optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
1107
+ and (
1108
+ weight_decay_mode in (WeightDecayMode.COUNTER, WeightDecayMode.COWCLIP)
1109
+ )
1110
+ )
1111
+
1112
+ if weight_decay_mode == WeightDecayMode.DECOUPLE_GLOBAL and (
1113
+ not optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
1114
+ or global_weight_decay is None
1115
+ ):
1116
+ raise AssertionError(
1117
+ """weight_decay_mode=WeightDecayMode.DECOUPLE_GLOBAL is supported for
1118
+ optimizer=OptimType.EXACT_ROWWISE_ADAGRAD and global_weight_decay cannot be None.
1119
+ """
1120
+ )
1121
+
1122
+ self._used_rowwise_adagrad_with_global_weight_decay: bool = (
1123
+ optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
1124
+ and (weight_decay_mode == WeightDecayMode.DECOUPLE_GLOBAL)
1125
+ )
1126
+ self.log(
1127
+ f"Using global weight decay = {self._used_rowwise_adagrad_with_global_weight_decay}"
1128
+ )
1129
+ # Declare GWD params here to avoid torch.jit.script error
1130
+ if global_weight_decay is None:
1131
+ global_weight_decay = GlobalWeightDecayDefinition()
1132
+
1133
+ self.gwd_start_iter: int = global_weight_decay.start_iter
1134
+ self.gwd_lower_bound: float = global_weight_decay.lower_bound
1135
+
1136
+ if ensemble_mode is None:
1137
+ ensemble_mode = EnsembleModeDefinition()
1138
+ self._ensemble_mode: dict[str, float] = {
1139
+ key: float(fval) for key, fval in ensemble_mode.__dict__.items()
1140
+ }
1141
+
1142
+ if emainplace_mode is None:
1143
+ emainplace_mode = EmainplaceModeDefinition()
1144
+ self._emainplace_mode: dict[str, float] = {
1145
+ key: float(fval) for key, fval in emainplace_mode.__dict__.items()
1146
+ }
1147
+
1148
+ if counter_based_regularization is None:
1149
+ counter_based_regularization = CounterBasedRegularizationDefinition()
1150
+ if cowclip_regularization is None:
1151
+ cowclip_regularization = CowClipDefinition()
1152
+ self._max_counter_update_freq: int = -1
1153
+ # Extract parameters from CounterBasedRegularizationDefinition or CowClipDefinition
1154
+ # which are passed as entries for OptimizerArgs
1155
+ if self._used_rowwise_adagrad_with_counter:
1156
+ if self.weight_decay_mode == WeightDecayMode.COUNTER:
1157
+ self._max_counter_update_freq = (
1158
+ counter_based_regularization.max_counter_update_freq
1159
+ )
1160
+ opt_arg_weight_decay_mode = (
1161
+ counter_based_regularization.counter_weight_decay_mode
1162
+ )
1163
+ counter_halflife = counter_based_regularization.counter_halflife
1164
+ else:
1165
+ opt_arg_weight_decay_mode = (
1166
+ cowclip_regularization.counter_weight_decay_mode
1167
+ )
1168
+ counter_halflife = cowclip_regularization.counter_halflife
1169
+ else:
1170
+ opt_arg_weight_decay_mode = weight_decay_mode
1171
+ # Default: -1, no decay applied, as a placeholder for OptimizerArgs
1172
+ # which should not be effective when CounterBasedRegularizationDefinition
1173
+ # and CowClipDefinition are not used
1174
+ counter_halflife = -1
1175
+
1176
+ if extra_optimizer_config is None:
1177
+ extra_optimizer_config = UserEnabledConfigDefinition()
1178
+ self.use_rowwise_bias_correction: bool = (
1179
+ extra_optimizer_config.use_rowwise_bias_correction
1180
+ )
1181
+ self.use_writeback_bwd_prehook: bool = (
1182
+ extra_optimizer_config.use_writeback_bwd_prehook
1183
+ )
1184
+ self.log(f"self.extra_optimizer_config is {extra_optimizer_config}")
1185
+ if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM:
1186
+ raise AssertionError(
1187
+ "`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
1188
+ )
1189
+ if self.use_writeback_bwd_prehook and not self.optimizer == OptimType.EXACT_SGD:
1190
+ raise AssertionError(
1191
+ "`use_writeback_bwd_prehook` is only supported for OptimType.EXACT_SGD",
1192
+ )
1193
+
1194
+ self.learning_rate_tensor: torch.Tensor = torch.tensor(
1195
+ learning_rate, device=torch.device("cpu"), dtype=torch.float32
1196
+ )
1197
+
1198
+ self.optimizer_args = invokers.lookup_args.OptimizerArgs(
1199
+ stochastic_rounding=stochastic_rounding,
1200
+ gradient_clipping=gradient_clipping,
1201
+ max_gradient=max_gradient,
1202
+ max_norm=max_norm,
1203
+ eps=eps,
1204
+ beta1=beta1,
1205
+ beta2=beta2,
1206
+ weight_decay=weight_decay,
1207
+ weight_decay_mode=opt_arg_weight_decay_mode.value,
1208
+ eta=eta,
1209
+ momentum=momentum,
1210
+ counter_halflife=counter_halflife,
1211
+ adjustment_iter=counter_based_regularization.adjustment_iter,
1212
+ adjustment_ub=counter_based_regularization.adjustment_ub,
1213
+ learning_rate_mode=counter_based_regularization.learning_rate_mode.value,
1214
+ grad_sum_decay=counter_based_regularization.grad_sum_decay.value,
1215
+ tail_id_threshold=counter_based_regularization.tail_id_threshold.val,
1216
+ is_tail_id_thresh_ratio=int(
1217
+ counter_based_regularization.tail_id_threshold.is_ratio
1218
+ ),
1219
+ total_hash_size=self.total_hash_size,
1220
+ weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
1221
+ lower_bound=cowclip_regularization.lower_bound,
1222
+ regularization_mode=weight_decay_mode.value,
1223
+ use_rowwise_bias_correction=self.use_rowwise_bias_correction,
1224
+ )
1225
+
1226
+ if optimizer != OptimType.NONE:
1227
+ assert (
1228
+ optimizer
1229
+ in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ENSEMBLE_ROWWISE_ADAGRAD)
1230
+ or optimizer_state_dtypes is None
1231
+ ), "optimizer_state_dtypes option is only supported for OptimType.PARTIAL_ROWWISE_ADAM and OptimType.ENSEMBLE_ROWWISE_ADAGRAD"
1232
+ if optimizer in (OptimType.EXACT_SGD,):
1233
+ # NOTE: make TorchScript work!
1234
+ self._register_nonpersistent_buffers("momentum1")
1235
+ else:
1236
+ momentum1_dtype = (
1237
+ torch.float32
1238
+ if (
1239
+ optimizer_state_dtypes is None
1240
+ or "momentum1" not in optimizer_state_dtypes
1241
+ or optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD
1242
+ )
1243
+ else optimizer_state_dtypes["momentum1"].as_dtype()
1244
+ )
1245
+ rowwise = optimizer in [
1246
+ OptimType.EXACT_ROWWISE_ADAGRAD,
1247
+ OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
1248
+ OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
1249
+ ]
1250
+ self._apply_split(
1251
+ construct_split_state(
1252
+ embedding_specs,
1253
+ rowwise=rowwise,
1254
+ cacheable=False,
1255
+ placement=(
1256
+ EmbeddingLocation.MANAGED
1257
+ if ((not rowwise) and uvm_non_rowwise_momentum)
1258
+ else None
1259
+ ),
1260
+ ),
1261
+ prefix="momentum1",
1262
+ # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
1263
+ # but got `Type[torch.float32]`.
1264
+ dtype=momentum1_dtype,
1265
+ enforce_hbm=enforce_hbm,
1266
+ uvm_host_mapped=self.uvm_host_mapped,
1267
+ )
1268
+ if optimizer in (
1269
+ OptimType.ADAM,
1270
+ OptimType.PARTIAL_ROWWISE_ADAM,
1271
+ OptimType.LAMB,
1272
+ OptimType.PARTIAL_ROWWISE_LAMB,
1273
+ OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
1274
+ ):
1275
+ rowwise = optimizer in (
1276
+ OptimType.PARTIAL_ROWWISE_ADAM,
1277
+ OptimType.PARTIAL_ROWWISE_LAMB,
1278
+ )
1279
+ momentum2_dtype = (
1280
+ torch.float32
1281
+ if (
1282
+ optimizer_state_dtypes is None
1283
+ or "momentum2" not in optimizer_state_dtypes
1284
+ )
1285
+ else optimizer_state_dtypes["momentum2"].as_dtype()
1286
+ )
1287
+ self._apply_split(
1288
+ construct_split_state(
1289
+ embedding_specs,
1290
+ rowwise=rowwise,
1291
+ cacheable=False,
1292
+ placement=(
1293
+ EmbeddingLocation.MANAGED
1294
+ if ((not rowwise) and uvm_non_rowwise_momentum)
1295
+ else None
1296
+ ),
1297
+ ),
1298
+ prefix="momentum2",
1299
+ # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
1300
+ # but got `Type[torch.float32]`.
1301
+ dtype=momentum2_dtype,
1302
+ uvm_host_mapped=self.uvm_host_mapped,
1303
+ )
1304
+ else:
1305
+ # NOTE: make TorchScript work!
1306
+ self._register_nonpersistent_buffers("momentum2")
1307
+ if self._used_rowwise_adagrad_with_counter:
1308
+ self._apply_split(
1309
+ construct_split_state(
1310
+ embedding_specs,
1311
+ rowwise=True,
1312
+ cacheable=False,
1313
+ ),
1314
+ prefix="prev_iter",
1315
+ # TODO: ideally we should use int64 to track iter but it failed to compile.
1316
+ # It may be related to low precision training code. Currently using float32
1317
+ # as a workaround while investigating the issue.
1318
+ # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
1319
+ # but got `Type[torch.float32]`.
1320
+ dtype=torch.float32,
1321
+ uvm_host_mapped=self.uvm_host_mapped,
1322
+ )
1323
+ self._apply_split(
1324
+ construct_split_state(
1325
+ embedding_specs,
1326
+ rowwise=True,
1327
+ cacheable=False,
1328
+ ),
1329
+ prefix="row_counter",
1330
+ # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
1331
+ # but got `Type[torch.float32]`.
1332
+ dtype=torch.float32,
1333
+ uvm_host_mapped=self.uvm_host_mapped,
1334
+ )
1335
+ self.register_buffer(
1336
+ "max_counter", torch.tensor([1], dtype=torch.float32)
1337
+ )
1338
+ elif self._used_rowwise_adagrad_with_global_weight_decay:
1339
+ self._apply_split(
1340
+ construct_split_state(
1341
+ embedding_specs,
1342
+ rowwise=True,
1343
+ cacheable=False,
1344
+ ),
1345
+ prefix="prev_iter",
1346
+ # TODO: ideally we should use int64 to track iter but it failed to compile.
1347
+ # It may be related to low precision training code. Currently using float32
1348
+ # as a workaround while investigating the issue.
1349
+ # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
1350
+ # but got `Type[torch.float32]`.
1351
+ dtype=torch.float32,
1352
+ uvm_host_mapped=self.uvm_host_mapped,
1353
+ )
1354
+ self._register_nonpersistent_buffers("row_counter")
1355
+ self.register_buffer(
1356
+ "max_counter",
1357
+ torch.ones(1, dtype=torch.float32, device=self.current_device),
1358
+ persistent=False,
1359
+ )
1360
+ elif optimizer == OptimType.ADAM and self.use_rowwise_bias_correction:
1361
+ self._apply_split(
1362
+ construct_split_state(
1363
+ embedding_specs,
1364
+ rowwise=True,
1365
+ cacheable=False,
1366
+ ),
1367
+ prefix="row_counter",
1368
+ # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
1369
+ # but got `Type[torch.float32]`.
1370
+ dtype=torch.float32,
1371
+ uvm_host_mapped=self.uvm_host_mapped,
1372
+ )
1373
+ else:
1374
+ self._register_nonpersistent_buffers("prev_iter")
1375
+ self._register_nonpersistent_buffers("row_counter")
1376
+ self.register_buffer(
1377
+ "max_counter",
1378
+ torch.ones(1, dtype=torch.float32, device=self.current_device),
1379
+ persistent=False,
1380
+ )
1381
+ if (
1382
+ optimizer
1383
+ in (
1384
+ OptimType.ADAM,
1385
+ OptimType.LAMB,
1386
+ OptimType.PARTIAL_ROWWISE_ADAM,
1387
+ OptimType.PARTIAL_ROWWISE_LAMB,
1388
+ OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
1389
+ OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
1390
+ )
1391
+ or self._used_rowwise_adagrad_with_global_weight_decay
1392
+ ):
1393
+ self.register_buffer(
1394
+ "iter",
1395
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1396
+ )
1397
+ else:
1398
+ self.register_buffer(
1399
+ "iter",
1400
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1401
+ persistent=False,
1402
+ )
1403
+
1404
+ self.iter_cpu: torch.Tensor = torch.zeros(1, dtype=torch.int64, device="cpu")
1405
+
1406
+ cache_state = construct_cache_state(rows, locations, self.feature_table_map)
1407
+
1408
+ # Add table-wise cache miss counter
1409
+ if self.record_cache_metrics.record_tablewise_cache_miss:
1410
+ num_tables = len(cache_state.cache_hash_size_cumsum) - 1
1411
+ self.register_buffer(
1412
+ "table_wise_cache_miss",
1413
+ torch.zeros(
1414
+ num_tables,
1415
+ device=self.current_device,
1416
+ dtype=torch.int64,
1417
+ ),
1418
+ )
1419
+ # NOTE: make TorchScript work!
1420
+ else:
1421
+ self.register_buffer(
1422
+ "table_wise_cache_miss",
1423
+ torch.zeros(
1424
+ 0,
1425
+ device=self.current_device,
1426
+ dtype=torch.int64,
1427
+ ),
1428
+ )
1429
+
1430
+ self._apply_cache_state(
1431
+ cache_state,
1432
+ cache_algorithm,
1433
+ cache_load_factor,
1434
+ cache_sets,
1435
+ cache_reserved_memory,
1436
+ cache_precision,
1437
+ )
1438
+
1439
+ self.log(f"Contents: {table_names}")
1440
+ self.log(
1441
+ f"Using fused {optimizer} with optimizer_args={self.optimizer_args if optimizer != OptimType.NONE else None}"
1442
+ )
1443
+ self.log(
1444
+ f"Using rowwise_adagrad_with_counter={self._used_rowwise_adagrad_with_counter}"
1445
+ )
1446
+
1447
+ self.step = 0
1448
+ self.last_reported_step = 0
1449
+ self.last_reported_uvm_stats: list[float] = []
1450
+ # Track number of times detailed memory breakdown has been reported
1451
+ self.detailed_mem_breakdown_report_count = 0
1452
+ # Set max number of reports for detailed memory breakdown
1453
+ self.max_detailed_mem_breakdown_reports = 10
1454
+
1455
+ # Check whether to use TBE v2
1456
+ is_experimental = False
1457
+ if use_experimental_tbe:
1458
+ is_experimental = True
1459
+ self.log("use_experimental_tbe is set to True; Using experimental TBE")
1460
+
1461
+ elif int(os.environ.get("FBGEMM_EXPERIMENTAL_TBE", "0")) == 1:
1462
+ # Keep the old feature enablement mechanism to ensure no negative impact on models that have already adopted TBE v2
1463
+ is_experimental = True
1464
+ self.log("FBGEMM_EXPERIMENTAL_TBE is set to True; Using experimental TBE")
1465
+
1466
+ # NOTE: Keep this disabled for now until the backend lands into Pyper
1467
+ # elif FeatureGateName.TBE_V2.is_enabled():
1468
+ # is_experimental = True
1469
+ # self.log("TBE_V2 Knob is set to True; Using experimental TBE")
1470
+
1471
+ self.is_experimental: bool = is_experimental
1472
+
1473
+ # Get a debug function pointer
1474
+ self._debug_print_input_stats: Callable[..., None] = (
1475
+ self._debug_print_input_stats_factory()
1476
+ )
1477
+
1478
+ # Get a reporter function pointer
1479
+ self._report_input_params: Callable[..., None] = (
1480
+ self.__report_input_params_factory()
1481
+ )
1482
+
1483
+ if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
1484
+ # Register writeback hook for Exact_SGD optimizer
1485
+ self.log(
1486
+ "SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled."
1487
+ )
1488
+ # pyre-fixme[6]: Expected `typing.Callable[[Module, Union[Tensor, typing.Tuple[Tensor, ...]]], Union[None, Tensor, typing.Tuple[Tensor, ...]]]`
1489
+ self.register_full_backward_pre_hook(self.writeback_hook)
1490
+
1491
+ if embedding_table_index_type not in [torch.int32, torch.int64]:
1492
+ raise ValueError(
1493
+ f"embedding_table_index_type must be torch.int32 or torch.int64, but got {embedding_table_index_type}"
1494
+ )
1495
+ self.embedding_table_index_type: torch.dtype = embedding_table_index_type
1496
+ if embedding_table_offset_type not in [torch.int32, torch.int64]:
1497
+ raise ValueError(
1498
+ f"embedding_table_offset_type must be torch.int32 or torch.int64, but got {embedding_table_offset_type}"
1499
+ )
1500
+ self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type
1501
+
1502
+ self.prefetched_info_list: list[PrefetchedInfo] = torch.jit.annotate(
1503
+ list[PrefetchedInfo], []
1504
+ )
1505
+ if self.enable_raw_embedding_streaming:
1506
+ self.res_params: RESParams = res_params or RESParams()
1507
+ self.res_params.table_sizes = [0] + list(accumulate(rows))
1508
+ res_port_from_env = os.getenv("LOCAL_RES_PORT")
1509
+ self.res_params.res_server_port = (
1510
+ int(res_port_from_env) if res_port_from_env else 0
1511
+ )
1512
+ # pyre-fixme[4]: Attribute must be annotated.
1513
+ self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
1514
+ self.uuid,
1515
+ self.enable_raw_embedding_streaming,
1516
+ self.res_params.res_store_shards,
1517
+ self.res_params.res_server_port,
1518
+ self.res_params.table_names,
1519
+ self.res_params.table_offsets,
1520
+ self.res_params.table_sizes,
1521
+ )
1522
+ logging.info(
1523
+ f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
1524
+ )
1525
+
1526
+ @torch.jit.ignore
1527
+ def log(self, msg: str) -> None:
1528
+ """
1529
+ Log with TBE id prefix to distinguish between multiple TBE instances
1530
+ per process
1531
+
1532
+ Args:
1533
+ msg (str): The message to print
1534
+
1535
+ Returns:
1536
+ None
1537
+ """
1538
+ logging.info(f"[TBE={self.uuid}] {msg}")
1539
+
1540
+ def _register_nonpersistent_buffers(self, prefix: str) -> None:
1541
+ # NOTE: make TorchScript work!
1542
+ self.register_buffer(
1543
+ f"{prefix}_dev",
1544
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1545
+ persistent=False,
1546
+ )
1547
+ self.register_buffer(
1548
+ f"{prefix}_host",
1549
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1550
+ persistent=False,
1551
+ )
1552
+ self.register_buffer(
1553
+ f"{prefix}_uvm",
1554
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1555
+ persistent=False,
1556
+ )
1557
+ self.register_buffer(
1558
+ f"{prefix}_placements",
1559
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1560
+ persistent=False,
1561
+ )
1562
+ self.register_buffer(
1563
+ f"{prefix}_offsets",
1564
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
1565
+ persistent=False,
1566
+ )
1567
+
1568
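+ # A minimal sketch of what the helper above produces, assuming it is called
+ # with prefix="momentum1" (buffer names follow the registrations above):
+ #
+ #   self._register_nonpersistent_buffers("momentum1")
+ #   # afterwards self.momentum1_dev, self.momentum1_host, self.momentum1_uvm,
+ #   # self.momentum1_placements, and self.momentum1_offsets all exist as
+ #   # 1-element int64 placeholder buffers; persistent=False keeps them out of
+ #   # state_dict(), they exist only so TorchScript can see the attributes.
+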
+ @staticmethod
1569
+ def get_table_name_for_logging(table_names: Optional[list[str]]) -> str:
1570
+ """
1571
+ Given a list of all table names in the TBE, generate a string to
1572
+ represent them in logging. If there is more than one table, this method
1573
+ will count them rather than list them.
1574
+
1575
+ Args:
1576
+ table_names (Optional[List[str]]): A list of table names in the TBE
1577
+
1578
+ Returns:
1579
+ A string that represents tables in logging
1580
+ """
1581
+ if table_names is None:
1582
+ return "<Unknown>"
1583
+ # Do this because sometimes multiple shards of the same table could appear
1584
+ # in one TBE.
1585
+ table_name_set = sorted(set(table_names))
1586
+ if len(table_name_set) == 1:
1587
+ return next(iter(table_name_set))
1588
+ return f"<{len(table_name_set)} tables>: {table_name_set}"
1589
+
1590
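+ # Illustrative behavior of get_table_name_for_logging (a minimal sketch of
+ # the logic above, shown doctest-style; the table names are made up):
+ #
+ #   >>> SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging(None)
+ #   '<Unknown>'
+ #   >>> SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging(["t1", "t1"])
+ #   't1'
+ #   >>> SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging(["t1", "t2", "t1"])
+ #   "<2 tables>: ['t1', 't2']"
+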
+ @staticmethod
1591
+ def get_prefetch_passes(
1592
+ multipass_prefetch_config: Optional[MultiPassPrefetchConfig],
1593
+ input_tensor: Tensor,
1594
+ output_tensor: Tensor,
1595
+ ) -> list[tuple[Tensor, Tensor, int]]:
1596
+ """
1597
+ Given inputs (the indices to forward), partition the input and output
1598
+ into smaller chunks and return them as a list of tuples
1599
+ (input[start_idx:end_idx], output[start_idx:end_idx], start_idx).
1600
+
1601
+ The caller must guarantee that input and output have non-zero dimension
1602
+ 0. The returned segments are guaranteed to completely and
1603
+ non-overlappingly cover the input tensor.
1604
+
1605
+ In non-multipass-prefetch mode, it returns the input/output tensor
1606
+ itself.
1607
+
1608
+ Args:
1609
+ multipass_prefetch_config (Optional[MultiPassPrefetchConfig]):
1610
+ A config for multi-pass cache prefetch. If None, multi-pass
1611
+ prefetch is not used.
1612
+
1613
+ input_tensor (Tensor): The input tensor to be partitioned
1614
+
1615
+ output_tensor (Tensor): The output tensor to be partitioned
1616
+
1617
+ Returns:
1618
+ A list of partitioned inputs and outputs (List[Tuple[Tensor,
1619
+ Tensor, int]])
1620
+ """
1621
+ if multipass_prefetch_config is None:
1622
+ return [(input_tensor, output_tensor, 0)]
1623
+ mpp_config: MultiPassPrefetchConfig = multipass_prefetch_config
1624
+
1625
+ N = input_tensor.size(0)
1626
+ if N <= mpp_config.num_passes or mpp_config.num_passes == 1:
1627
+ # Fewer rows than passes, or only a single pass: don't split
1628
+ return [(input_tensor, output_tensor, 0)]
1629
+
1630
+ pass_size: int = max(
1631
+ (N + mpp_config.num_passes - 1) // mpp_config.num_passes,
1632
+ mpp_config.min_splitable_pass_size,
1633
+ )
1634
+
1635
+ return list(
1636
+ zip(
1637
+ torch.split(input_tensor, pass_size),
1638
+ torch.split(output_tensor, pass_size),
1639
+ range(0, N, pass_size),
1640
+ )
1641
+ )
1642
+
1643
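+ # A minimal sketch of the pass partitioning above (num_passes=3 and
+ # min_splitable_pass_size=1 are assumed example values):
+ #
+ #   N = 10
+ #   pass_size = max((N + 3 - 1) // 3, 1)   # ceil(10 / 3) -> 4
+ #   list(range(0, N, pass_size))           # start offsets -> [0, 4, 8]
+ #   # torch.split(input_tensor, 4) then yields chunks of 4, 4, and 2 rows,
+ #   # so three (input, output, start_idx) passes are returned.
+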
+ def get_states(self, prefix: str) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
1644
+ """
1645
+ Get a state of a given tensor (`prefix`)
1646
+
1647
+ Args:
1648
+ prefix (str): A prefix of the state to obtain
1649
+
1650
+ Returns:
1651
+ A tuple of tensors corresponding to the obtained state containing
1652
+
1653
+ (1) A GPU state tensor
1654
+
1655
+ (2) A CPU state tensor
1656
+
1657
+ (3) A UVM state tensor
1658
+
1659
+ (4) A placement tensor - containing placements of embedding tables
1660
+ (torch.int32 tensor). (0 = DEVICE, 1 = MANAGED, 2 =
1661
+ MANAGED_CACHING, 3 = HOST, 4 = MTIA)
1662
+
1663
+ (5) An offset tensor - containing the relative positions of
1664
+ embedding tables in the corresponding state tensor (GPU, CPU,
1665
+ or UVM state tensor)
1666
+ """
1667
+ if not hasattr(self, f"{prefix}_physical_placements"):
1668
+ raise DoesNotHavePrefix()
1669
+ dev_param = getattr(self, f"{prefix}_dev")
1670
+ host_param = getattr(self, f"{prefix}_host")
1671
+ uvm_param = getattr(self, f"{prefix}_uvm")
1672
+ placements = getattr(self, f"{prefix}_physical_placements")
1673
+ offsets = getattr(self, f"{prefix}_physical_offsets")
1674
+ return (
1675
+ dev_param,
1676
+ host_param,
1677
+ uvm_param,
1678
+ torch.tensor(placements, dtype=torch.int32),
1679
+ torch.tensor(offsets, dtype=torch.int64),
1680
+ )
1681
+
1682
+ def get_all_states(self) -> list[tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
1683
+ """
1684
+ Get all states in the TBE (`weights`, `momentum1`, `momentum2`,
1685
+ `prev_iter`, and `row_counter`)
1686
+
1687
+ Returns:
1688
+ A list of states. Each state is a tuple of tensors (GPU state
1689
+ tensor, CPU state tensor, UVM state tensor, placement tensor and
1690
+ offset tensor)
1691
+ """
1692
+ all_states = []
1693
+ for prefix in ["weights", "momentum1", "momentum2", "prev_iter", "row_counter"]:
1694
+ try:
1695
+ all_states.append(self.get_states(prefix))
1696
+ except DoesNotHavePrefix:
1697
+ pass
1698
+ return all_states
1699
+
1700
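+ # Example use of get_all_states (a sketch; `tbe` is a hypothetical instance
+ # of this module):
+ #
+ #   total_bytes = sum(
+ #       t.numel() * t.element_size()
+ #       for dev, host, uvm, _placements, _offsets in tbe.get_all_states()
+ #       for t in (dev, host, uvm)
+ #   )
+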
+ @torch.jit.export
1701
+ def get_cache_miss_counter(self) -> Tensor:
1702
+ """
1703
+ Get the cache miss counter. `cache_miss_counter` contains two items:
1704
+
1705
+ (1) `cache_miss_forward_count` which records the total number of
1706
+ forward passes that have at least one cache miss
1707
+
1708
+ (2) `unique_cache_miss_count` which records the total number of unique
1709
+ (dedup) cache misses
1710
+
1711
+ Returns:
1712
+ The cache miss counter
1713
+ """
1714
+ # pyre-fixme[7]: Expected `Tensor` but got `Union[Module, Tensor]`.
1715
+ return self.cache_miss_counter
1716
+
1717
+ @torch.jit.export
1718
+ def get_table_wise_cache_miss(self) -> Tensor:
1719
+ """
1720
+ Get the table-wise cache miss tensor. `table_wise_cache_miss` contains
1721
+ the cache miss counts for each table in this embedding table object.
1722
+
1723
+ Returns:
1724
+ The table-wise cache miss tensor
1725
+ """
1726
+ return self.table_wise_cache_miss
1727
+
1728
+ # The callback function for AsyncTimer to record durations for different events
1729
+ def _report_duration(
1730
+ self,
1731
+ it_step: int,
1732
+ dur_ms: float,
1733
+ event_name: str,
1734
+ ) -> None:
1735
+ assert (
1736
+ self.stats_reporter
1737
+ ), "We should not be here. AsyncTimer only happens with reporter present."
1738
+ self.stats_reporter.report_duration(
1739
+ iteration_step=it_step,
1740
+ event_name=event_name,
1741
+ duration_ms=dur_ms,
1742
+ embedding_id=self.logging_table_name,
1743
+ tbe_id=self.uuid,
1744
+ )
1745
+
1746
+ def _get_tensor_memory(self, tensor_name: str) -> int:
1747
+ """Get memory usage of a tensor in bytes."""
1748
+ if not hasattr(self, tensor_name):
1749
+ self.log(f"Tensor '{tensor_name}' not found, using 0 bytes")
1750
+ return 0
1751
+ tensor = getattr(self, tensor_name)
1752
+ return tensor.numel() * tensor.element_size()
1753
+
1754
+ def _categorize_memory_by_location(
1755
+ self, tensor_names: list[str]
1756
+ ) -> tuple[int, int]:
1757
+ """Categorize memory into HBM and UVM for given tensors.
1758
+
1759
+ Returns:
1760
+ (hbm_bytes, uvm_bytes)
1761
+ """
1762
+ uvm_set = set(self._uvm_tensors_log)
1763
+ hbm_bytes = 0
1764
+ uvm_bytes = 0
1765
+
1766
+ for name in tensor_names:
1767
+ size = self._get_tensor_memory(name)
1768
+ if name in uvm_set:
1769
+ uvm_bytes += size
1770
+ else:
1771
+ hbm_bytes += size
1772
+
1773
+ return hbm_bytes, uvm_bytes
1774
+
1775
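+ # A minimal sketch of the HBM/UVM split above (tensor names and sizes are
+ # made-up example values; real sizes come from _get_tensor_memory):
+ #
+ #   uvm_set = {"weights_uvm"}
+ #   sizes = {"weights_dev": 4096, "weights_uvm": 65536}
+ #   hbm = sum(s for n, s in sizes.items() if n not in uvm_set)   # -> 4096
+ #   uvm = sum(s for n, s in sizes.items() if n in uvm_set)       # -> 65536
+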
+ def _report_hbm_breakdown(
1776
+ self,
1777
+ stats_reporter: TBEStatsReporter,
1778
+ embeddings: int,
1779
+ optimizer_states: int,
1780
+ cache: int,
1781
+ total_static_sparse: int,
1782
+ ephemeral: int,
1783
+ ) -> None:
1784
+ """Report HBM memory breakdown to stats reporter."""
1785
+ stats_reporter.report_data_amount(
1786
+ iteration_step=self.step,
1787
+ event_name="tbe.hbm.embeddings",
1788
+ data_bytes=embeddings,
1789
+ embedding_id=self.logging_table_name,
1790
+ tbe_id=self.uuid,
1791
+ )
1792
+ stats_reporter.report_data_amount(
1793
+ iteration_step=self.step,
1794
+ event_name="tbe.hbm.optimizer_states",
1795
+ data_bytes=optimizer_states,
1796
+ embedding_id=self.logging_table_name,
1797
+ tbe_id=self.uuid,
1798
+ )
1799
+ stats_reporter.report_data_amount(
1800
+ iteration_step=self.step,
1801
+ event_name="tbe.hbm.cache",
1802
+ data_bytes=cache,
1803
+ embedding_id=self.logging_table_name,
1804
+ tbe_id=self.uuid,
1805
+ )
1806
+ stats_reporter.report_data_amount(
1807
+ iteration_step=self.step,
1808
+ event_name="tbe.hbm.total_static_sparse",
1809
+ data_bytes=total_static_sparse,
1810
+ embedding_id=self.logging_table_name,
1811
+ tbe_id=self.uuid,
1812
+ )
1813
+ stats_reporter.report_data_amount(
1814
+ iteration_step=self.step,
1815
+ event_name="tbe.hbm.ephemeral",
1816
+ data_bytes=ephemeral,
1817
+ embedding_id=self.logging_table_name,
1818
+ tbe_id=self.uuid,
1819
+ )
1820
+
1821
+ def _report_uvm_breakdown(
1822
+ self,
1823
+ stats_reporter: TBEStatsReporter,
1824
+ embeddings: int,
1825
+ optimizer_states: int,
1826
+ cache: int,
1827
+ total_static_sparse: int,
1828
+ ephemeral: int,
1829
+ ) -> None:
1830
+ """Report UVM memory breakdown to stats reporter."""
1831
+ stats_reporter.report_data_amount(
1832
+ iteration_step=self.step,
1833
+ event_name="tbe.uvm.embeddings",
1834
+ data_bytes=embeddings,
1835
+ embedding_id=self.logging_table_name,
1836
+ tbe_id=self.uuid,
1837
+ )
1838
+ stats_reporter.report_data_amount(
1839
+ iteration_step=self.step,
1840
+ event_name="tbe.uvm.optimizer_states",
1841
+ data_bytes=optimizer_states,
1842
+ embedding_id=self.logging_table_name,
1843
+ tbe_id=self.uuid,
1844
+ )
1845
+ stats_reporter.report_data_amount(
1846
+ iteration_step=self.step,
1847
+ event_name="tbe.uvm.cache",
1848
+ data_bytes=cache,
1849
+ embedding_id=self.logging_table_name,
1850
+ tbe_id=self.uuid,
1851
+ )
1852
+ stats_reporter.report_data_amount(
1853
+ iteration_step=self.step,
1854
+ event_name="tbe.uvm.total_static_sparse",
1855
+ data_bytes=total_static_sparse,
1856
+ embedding_id=self.logging_table_name,
1857
+ tbe_id=self.uuid,
1858
+ )
1859
+ stats_reporter.report_data_amount(
1860
+ iteration_step=self.step,
1861
+ event_name="tbe.uvm.ephemeral",
1862
+ data_bytes=ephemeral,
1863
+ embedding_id=self.logging_table_name,
1864
+ tbe_id=self.uuid,
1865
+ )
1866
+
1867
+ @torch.jit.ignore
1868
+ def _report_tbe_mem_usage(self) -> None:
1869
+ if self.stats_reporter is None:
1870
+ return
1871
+
1872
+ stats_reporter: TBEStatsReporter = self.stats_reporter
1873
+ if not stats_reporter.should_report(self.step):
1874
+ return
1875
+
1876
+ # Calculate total memory from all parameters and buffers (always needed)
1877
+ total_mem_usage = sum(
1878
+ p.numel() * p.element_size() for p in self.parameters()
1879
+ ) + sum(b.numel() * b.element_size() for b in self.buffers())
1880
+
1881
+ # Calculate total HBM and UVM usage (always needed)
1882
+ if self.use_cpu:
1883
+ total_hbm_usage = 0
1884
+ total_uvm_usage = total_mem_usage
1885
+ else:
1886
+ total_uvm_usage = sum(
1887
+ self._get_tensor_memory(name)
1888
+ for name in self._uvm_tensors_log
1889
+ if hasattr(self, name)
1890
+ )
1891
+ total_hbm_usage = total_mem_usage - total_uvm_usage
1892
+
1893
+ # Report total memory usage metrics (always reported for backward compatibility)
1894
+ stats_reporter.report_data_amount(
1895
+ iteration_step=self.step,
1896
+ event_name="tbe.total_hbm_usage",
1897
+ data_bytes=total_hbm_usage,
1898
+ embedding_id=self.logging_table_name,
1899
+ tbe_id=self.uuid,
1900
+ )
1901
+ stats_reporter.report_data_amount(
1902
+ iteration_step=self.step,
1903
+ event_name="tbe.total_uvm_usage",
1904
+ data_bytes=total_uvm_usage,
1905
+ embedding_id=self.logging_table_name,
1906
+ tbe_id=self.uuid,
1907
+ )
1908
+
1909
+ # Only report detailed breakdown for the first max_detailed_mem_breakdown_reports reportable
1910
+ # steps since static sparse memory (weights, optimizer states, cache) is constant
1911
+ if (
1912
+ self.detailed_mem_breakdown_report_count
1913
+ >= self.max_detailed_mem_breakdown_reports
1914
+ ):
1915
+ return
1916
+ self.detailed_mem_breakdown_report_count += 1
1917
+
1918
+ # Tensor groups for sparse memory categorization
1919
+ weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]
1920
+ optimizer_tensors = [
1921
+ "momentum1_dev",
1922
+ "momentum1_host",
1923
+ "momentum1_uvm",
1924
+ "momentum2_dev",
1925
+ "momentum2_host",
1926
+ "momentum2_uvm",
1927
+ ]
1928
+ cache_tensors = [
1929
+ "lxu_cache_weights",
1930
+ "lxu_cache_state",
1931
+ "lxu_state",
1932
+ "cache_hash_size_cumsum",
1933
+ "cache_index_table_map",
1934
+ "cache_miss_counter",
1935
+ "lxu_cache_locking_counter",
1936
+ ]
1937
+
1938
+ # Calculate total memory for each component
1939
+ weights_total = sum(self._get_tensor_memory(t) for t in weight_tensors)
1940
+ optimizer_total = sum(self._get_tensor_memory(t) for t in optimizer_tensors)
1941
+ cache_total = sum(self._get_tensor_memory(t) for t in cache_tensors)
1942
+
1943
+ # Categorize memory by location (HBM vs UVM)
1944
+ if self.use_cpu:
1945
+ weights_hbm, weights_uvm = 0, weights_total
1946
+ opt_hbm, opt_uvm = 0, optimizer_total
1947
+ cache_hbm, cache_uvm = 0, cache_total
1948
+ else:
1949
+ weights_hbm, weights_uvm = self._categorize_memory_by_location(
1950
+ weight_tensors
1951
+ )
1952
+ opt_hbm, opt_uvm = self._categorize_memory_by_location(optimizer_tensors)
1953
+ cache_hbm, cache_uvm = self._categorize_memory_by_location(cache_tensors)
1954
+
1955
+ # Calculate ephemeral memory split between HBM and UVM
1956
+ static_sparse_hbm = weights_hbm + opt_hbm + cache_hbm
1957
+ static_sparse_uvm = weights_uvm + opt_uvm + cache_uvm
1958
+ ephemeral_hbm = total_hbm_usage - static_sparse_hbm
1959
+ ephemeral_uvm = total_uvm_usage - static_sparse_uvm
1960
+
1961
+ # Report granular memory breakdowns
1962
+ self._report_hbm_breakdown(
1963
+ stats_reporter,
1964
+ weights_hbm,
1965
+ opt_hbm,
1966
+ cache_hbm,
1967
+ static_sparse_hbm,
1968
+ ephemeral_hbm,
1969
+ )
1970
+ self._report_uvm_breakdown(
1971
+ stats_reporter,
1972
+ weights_uvm,
1973
+ opt_uvm,
1974
+ cache_uvm,
1975
+ static_sparse_uvm,
1976
+ ephemeral_uvm,
1977
+ )
1978
+
1979
+ @torch.jit.ignore
1980
+ def _report_io_size_count(self, event: str, data: Tensor) -> Tensor:
1981
+ if self.stats_reporter is None:
1982
+ return data
1983
+ stats_reporter: TBEStatsReporter = self.stats_reporter
1984
+ if stats_reporter.should_report(self.step):
1985
+ stats_reporter.report_data_amount(
1986
+ iteration_step=self.step,
1987
+ event_name=f"tbe.{event}_size",
1988
+ data_bytes=data.element_size() * data.numel(),
1989
+ embedding_id=self.logging_table_name,
1990
+ tbe_id=self.uuid,
1991
+ )
1992
+ stats_reporter.report_data_amount(
1993
+ iteration_step=self.step,
1994
+ event_name=f"tbe.{event}_count",
1995
+ data_bytes=data.numel(),
1996
+ embedding_id=self.logging_table_name,
1997
+ tbe_id=self.uuid,
1998
+ )
1999
+ return data
2000
+
2001
+ @torch.jit.ignore
2002
+ def _generate_vbe_metadata(
2003
+ self,
2004
+ offsets: Tensor,
2005
+ batch_size_per_feature_per_rank: Optional[list[list[int]]],
2006
+ ) -> invokers.lookup_args.VBEMetadata:
2007
+ # Blocking D2H copy, but only runs at first call
2008
+ self.feature_dims = self.feature_dims.cpu()
2009
+ if batch_size_per_feature_per_rank is not None:
2010
+ assert self.optimizer in (
2011
+ OptimType.EXACT_ROWWISE_ADAGRAD,
2012
+ OptimType.EXACT_SGD,
2013
+ OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
2014
+ OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
2015
+ OptimType.NONE,
2016
+ OptimType.ADAM,
2017
+ ), (
2018
+ "Variable batch size TBE support is enabled for "
2019
+ "OptimType.EXACT_ROWWISE_ADAGRAD,EXACT_SGD, "
2020
+ "ENSEMBLE_ROWWISE_ADAGRAD, NONE, and ADAM only"
2021
+ )
2022
+ return generate_vbe_metadata(
2023
+ offsets,
2024
+ batch_size_per_feature_per_rank,
2025
+ self.pooling_mode,
2026
+ self.feature_dims,
2027
+ self.current_device,
2028
+ )
2029
+
2030
+ @torch.jit.ignore
2031
+ def _feature_is_enabled(self, feature: FeatureGateName) -> bool:
2032
+ # Define proxy method so that it can be marked with @torch.jit.ignore
2033
+ # This allows models using this class to compile correctly
2034
+ return FeatureGate.is_enabled(feature)
2035
+
2036
+ def writeback_update_gradient(
2037
+ self, indices: torch.Tensor, offsets: torch.Tensor, grad: Tensor
2038
+ ) -> Tensor:
2039
+ if indices.numel() == 0:
2040
+ return grad[0]
2041
+ num_of_tables = len(set(self.feature_table_map))
2042
+ assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
2043
+ batch_size = offsets.shape[0] // num_of_tables
2044
+ max_indices = indices.max()
2045
+ non_empty_index = (offsets[1:] - offsets[:-1]).nonzero().flatten()
2046
+ # Disable dedup across different tables
2047
+ indices = ((offsets[non_empty_index]) // batch_size) * (
2048
+ 1 + max_indices
2049
+ ) + indices
2050
+ grad = grad[0]
2051
+ _, idx, counts = torch.unique(
2052
+ indices, dim=0, sorted=True, return_inverse=True, return_counts=True
2053
+ )
2054
+ _, ind_sorted = torch.sort(idx, stable=True)
2055
+ cum_sum = counts.cumsum(0)
2056
+ cum_sum = torch.cat((torch.tensor([0]).to(indices.device), cum_sum[:-1]))
2057
+ first_indices = ind_sorted[cum_sum]
2058
+ mask = torch.zeros_like(grad, device=grad.device)
2059
+ original_index = non_empty_index[first_indices]
2060
+
2061
+ mask[original_index] = grad[original_index]
2062
+ return mask
2063
+
2064
+ # pyre-fixme[2]: For 1st argument expected not ANY
2065
+ def writeback_hook(self, module: Any, grad: Tensor) -> tuple[Tensor]:
2066
+ indices = self._indices
2067
+ offsets = self._offsets
2068
+
2069
+ return (self.writeback_update_gradient(indices, offsets, grad),)
2070
+
2071
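+ # A minimal sketch of the first-occurrence selection used by
+ # writeback_update_gradient above (values are made-up example indices):
+ #
+ #   x = torch.tensor([5, 3, 5, 7])
+ #   _, inv, counts = torch.unique(
+ #       x, sorted=True, return_inverse=True, return_counts=True
+ #   )
+ #   _, ind_sorted = torch.sort(inv, stable=True)
+ #   cum_sum = torch.cat((torch.tensor([0]), counts.cumsum(0)[:-1]))
+ #   ind_sorted[cum_sum]   # -> tensor([1, 0, 3])
+ #   # i.e. the positions of the first occurrence of 3, 5, and 7; only those
+ #   # rows of the incoming gradient survive in the returned mask.
+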
+ def forward( # noqa: C901
2072
+ self,
2073
+ indices: Tensor,
2074
+ offsets: Tensor,
2075
+ per_sample_weights: Optional[Tensor] = None,
2076
+ feature_requires_grad: Optional[Tensor] = None,
2077
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
2078
+ total_unique_indices: Optional[int] = None,
2079
+ hash_zch_identities: Optional[Tensor] = None,
2080
+ hash_zch_runtime_meta: Optional[Tensor] = None,
2081
+ ) -> Tensor:
2082
+ """
2083
+ The forward pass function that
2084
+
2085
+ (1) Performs input bound checking
2086
+
2087
+ (2) Generates necessary variable batch size embedding (VBE) metadata (if
2088
+ VBE is used)
2089
+
2090
+ (3) Prefetches data from UVM to cache (if
2091
+ `EmbeddingLocation.MANAGED_CACHING` is used and the user has not
2092
+ explicitly prefetched data)
2093
+
2094
+ (4) Performs the embedding table lookup by invoking a corresponding
2095
+ Autograd function (based on the chosen optimizer)
2096
+
2097
+ Args:
2098
+ indices (Tensor): A 1D-tensor that contains indices to be looked up
2099
+ from all embedding tables
2100
+
2101
+ offsets (Tensor): A 1D-tensor that contains offsets of indices.
2102
+ Shape `(B * T + 1)` where `B` = batch size and `T` = the number
2103
+ of features. `offsets[t * B + b + 1] - offsets[t * B + b]` is
2104
+ the length of bag `b` of feature `t`
2105
+
2106
+ per_sample_weights (Optional[Tensor]): An optional 1D-float-tensor that
2107
+ contains per sample weights. If None, **unweighted** embedding
2108
+ lookup will be performed. Otherwise, a **weighted** lookup will be used. The
2109
+ length of this tensor must be the same as the length of the
2110
+ `indices` tensor. The value of `per_sample_weights[i]` will be
2111
+ used to multiply with every element in the looked up row
2112
+ `indices[i]`, where `0 <= i < len(per_sample_weights)`.
2113
+
2114
+ feature_requires_grad (Optional[Tensor]): An optional 1D-tensor for
2115
+ indicating if `per_sample_weights` requires gradient. The
2116
+ length of the tensor must be equal to the number of features
2117
+
2118
+ batch_size_per_feature_per_rank (Optional[List[List[int]]]): An
2119
+ optional nested list that contains batch sizes for every rank and
2120
+ every feature. If None, TBE assumes that **every feature has the
2121
+ same batch size** and computes the batch size from the `offsets`
2122
+ shape. Otherwise, TBE assumes that different features can have
2123
+ different batch sizes and uses the **variable batch size
2124
+ embedding look up mode (VBE)**. Shape (number of features,
2125
+ number of ranks). `batch_size_per_feature_per_rank[f][r]`
2126
+ represents the batch size of feature `f` and rank `r`
2127
+
2128
+ total_unique_indices (Optional[int]): An optional integer that
2129
+ represents the total number of unique indices. This value must
2130
+ be set when using `OptimType.NONE`. This is because TBE
2131
+ requires this information for allocating the weight gradient
2132
+ tensor in the backward pass.
2133
+
2134
+ hash_zch_identities (Optional[Tensor]): The original raw IDs before
2135
+ remapping to ZCH (Zero-Collision Hash) table slots. This tensor is
2136
+ populated when using Multi-Probe Zero Collision Hash (MPZCH) modules
2137
+ and is required for Raw Embedding Streaming (RES) to maintain
2138
+ consistency between training and inference.
2139
+
2140
+ Returns:
2141
+ A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` =
2142
+ batch size and `total_D` = the sum of all embedding dimensions in the
2143
+ tables
2144
+
2145
+ Example:
2146
+
2147
+ >>> import torch
2148
+ >>>
2149
+ >>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
2150
+ >>> EmbeddingLocation,
2151
+ >>> )
2152
+ >>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
2153
+ >>> SplitTableBatchedEmbeddingBagsCodegen,
2154
+ >>> ComputeDevice,
2155
+ >>> )
2156
+ >>>
2157
+ >>> # Two tables
2158
+ >>> embedding_specs = [
2159
+ >>> (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
2160
+ >>> (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA)
2161
+ >>> ]
2162
+ >>>
2163
+ >>> tbe = SplitTableBatchedEmbeddingBagsCodegen(embedding_specs)
2164
+ >>> tbe.init_embedding_weights_uniform(-1, 1)
2165
+ >>>
2166
+ >>> print(tbe.split_embedding_weights())
2167
+ [tensor([[-0.9426, 0.7046, 0.4214, -0.0419, 0.1331, -0.7856, -0.8124, -0.2021],
2168
+ [-0.5771, 0.5911, -0.7792, -0.1068, -0.6203, 0.4813, -0.1677, 0.4790],
2169
+ [-0.5587, -0.0941, 0.5754, 0.3475, -0.8952, -0.1964, 0.0810, -0.4174]],
2170
+ device='cuda:0'), tensor([[-0.2513, -0.4039, -0.3775, 0.3273],
2171
+ [-0.5399, -0.0229, -0.1455, -0.8770],
2172
+ [-0.9520, 0.4593, -0.7169, 0.6307],
2173
+ [-0.1765, 0.8757, 0.8614, 0.2051],
2174
+ [-0.0603, -0.9980, -0.7958, -0.5826]], device='cuda:0')]
2175
+
2176
+
2177
+ >>> # Batch size = 3
2178
+ >>> indices = torch.tensor([0, 1, 2, 0, 1, 2, 0, 3, 1, 4, 2, 0, 0],
2179
+ >>> device="cuda",
2180
+ >>> dtype=torch.long)
2181
+ >>> offsets = torch.tensor([0, 2, 5, 7, 9, 12, 13],
2182
+ >>> device="cuda",
2183
+ >>> dtype=torch.long)
2184
+ >>>
2185
+ >>> output = tbe(indices, offsets)
2186
+ >>>
2187
+ >>> # Batch size = 3, total embedding dimension = 12
2188
+ >>> print(output.shape)
2189
+ torch.Size([3, 12])
2190
+
2191
+ >>> print(output)
2192
+ tensor([[-1.5197, 1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801, 0.2769,
2193
+ -0.7164, 0.8528, 0.7159, -0.6719],
2194
+ [-2.0784, 1.2016, 0.2176, 0.1988, -1.3825, -0.5008, -0.8991, -0.1405,
2195
+ -1.2637, -0.9427, -1.8902, 0.3754],
2196
+ [-1.5013, 0.6105, 0.9968, 0.3057, -0.7621, -0.9821, -0.7314, -0.6195,
2197
+ -0.2513, -0.4039, -0.3775, 0.3273]], device='cuda:0',
2198
+ grad_fn=<CppNode<SplitLookupFunction_sgd_Op>>)
2199
+
2200
+ """
2201
+ (
2202
+ indices,
2203
+ offsets,
2204
+ per_sample_weights,
2205
+ vbe_metadata,
2206
+ ) = self.prepare_inputs(
2207
+ indices,
2208
+ offsets,
2209
+ per_sample_weights,
2210
+ batch_size_per_feature_per_rank,
2211
+ force_cast_input_types=True,
2212
+ prefetch_pipeline=False,
2213
+ )
2214
+
2215
+ # Print input stats if enabled (for debugging purposes only)
2216
+ self._debug_print_input_stats(indices, offsets, per_sample_weights)
2217
+
2218
+ # Extract and write input stats if enabled
2219
+ if self._report_input_params is not None:
2220
+ self._report_input_params(
2221
+ feature_rows=self.rows_per_table,
2222
+ feature_dims=self.feature_dims,
2223
+ iteration=self.iter_cpu.item() if hasattr(self, "iter_cpu") else 0,
2224
+ indices=indices,
2225
+ offsets=offsets,
2226
+ op_id=self.uuid,
2227
+ per_sample_weights=per_sample_weights,
2228
+ batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
2229
+ )
2230
+
2231
+ if not is_torchdynamo_compiling():
2232
+ # Mutating an nn.Module attr forces a dynamo restart of Analysis, which increases compilation time
2233
+
2234
+ # Storing tensors for linear_cache_indices recomputation
2235
+ self._indices = indices
2236
+ self._offsets = offsets
2237
+ self._vbe_B_offsets = vbe_metadata.B_offsets
2238
+ self._vbe_max_B = vbe_metadata.max_B
2239
+
2240
+ self.step += 1
2241
+ self._report_io_size_count("fwd_input", indices)
2242
+ self._report_tbe_mem_usage()
2243
+
2244
+ if self.tbe_input_multiplexer is not None:
2245
+ tbe_input_multiplexer: TBEInputMultiplexer = self.tbe_input_multiplexer
2246
+ if tbe_input_multiplexer.should_run(self.step):
2247
+ tbe_input_multiplexer.run(
2248
+ tbe_input_info=TBEInputInfo(
2249
+ indices, offsets, batch_size_per_feature_per_rank
2250
+ )
2251
+ )
2252
+
2253
+ if len(self.timesteps_prefetched) == 0:
2254
+ # In forward, we don't enable multi-pass prefetch as we want the process
2255
+ # to be as fast as possible and memory usage doesn't matter (will be recycled
2256
+ # by dense fwd/bwd)
2257
+ self._prefetch(
2258
+ indices,
2259
+ offsets,
2260
+ vbe_metadata,
2261
+ multipass_prefetch_config=None,
2262
+ hash_zch_identities=hash_zch_identities,
2263
+ hash_zch_runtime_meta=hash_zch_runtime_meta,
2264
+ )
2265
+
2266
+ if len(self.timesteps_prefetched) > 0:
2267
+ self.timesteps_prefetched.pop(0)
2268
+
2269
+ self.lxu_cache_locations = (
2270
+ self.lxu_cache_locations_empty
2271
+ if len(self.lxu_cache_locations_list) == 0
2272
+ else self.lxu_cache_locations_list.pop(0)
2273
+ )
2274
+ common_args = invokers.lookup_args.CommonArgs(
2275
+ placeholder_autograd_tensor=self.placeholder_autograd_tensor,
2276
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
2277
+ # `Union[Module, Tensor]`.
2278
+ dev_weights=self.weights_dev,
2279
+ # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
2280
+ # `Union[Module, Tensor]`.
2281
+ host_weights=self.weights_host,
2282
+ # pyre-fixme[6]: For 4th argument expected `Tensor` but got
2283
+ # `Union[Module, Tensor]`.
2284
+ uvm_weights=self.weights_uvm,
2285
+ # pyre-fixme[6]: For 5th argument expected `Tensor` but got
2286
+ # `Union[Module, Tensor]`.
2287
+ lxu_cache_weights=self.lxu_cache_weights,
2288
+ # pyre-fixme[6]: For 6th argument expected `Tensor` but got
2289
+ # `Union[Module, Tensor]`.
2290
+ weights_placements=self.weights_placements,
2291
+ # pyre-fixme[6]: For 7th argument expected `Tensor` but got
2292
+ # `Union[Module, Tensor]`.
2293
+ weights_offsets=self.weights_offsets,
2294
+ D_offsets=self.D_offsets,
2295
+ total_D=self.total_D,
2296
+ max_D=self.max_D,
2297
+ hash_size_cumsum=self.hash_size_cumsum,
2298
+ total_hash_size_bits=self.total_hash_size_bits,
2299
+ indices=indices,
2300
+ offsets=offsets,
2301
+ pooling_mode=self.pooling_mode,
2302
+ indice_weights=per_sample_weights,
2303
+ feature_requires_grad=feature_requires_grad,
2304
+ lxu_cache_locations=self.lxu_cache_locations,
2305
+ # Pass the local_uvm_cache_stats because only that information is
2306
+ # relevant for the current iteration
2307
+ uvm_cache_stats=(
2308
+ self.local_uvm_cache_stats
2309
+ if (
2310
+ self.gather_uvm_cache_stats
2311
+ # Unique conflict misses are only collected when using CacheAlgorithm.LRU
2312
+ and self.cache_algorithm == CacheAlgorithm.LRU
2313
+ )
2314
+ else None
2315
+ ),
2316
+ output_dtype=self.output_dtype,
2317
+ vbe_metadata=vbe_metadata,
2318
+ is_experimental=self.is_experimental,
2319
+ use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
2320
+ use_homogeneous_placements=self.use_homogeneous_placements,
2321
+ learning_rate_tensor=self.learning_rate_tensor,
2322
+ info_B_num_bits=self.info_B_num_bits,
2323
+ info_B_mask=self.info_B_mask,
2324
+ )
2325
+
2326
+ if self.optimizer == OptimType.NONE:
2327
+ assert (
2328
+ total_unique_indices is not None
2329
+ and total_unique_indices <= indices.numel()
2330
+ ), f"OptimType.NONE requires total_unique_indices. Please pass it or check the value (total_unique_indices = {total_unique_indices})"
2331
+ return self._report_io_size_count(
2332
+ "fwd_output",
2333
+ invokers.lookup_none.invoke(
2334
+ common_args, self.optimizer_args, total_unique_indices
2335
+ ),
2336
+ )
2337
+ elif self.optimizer == OptimType.EXACT_SGD:
2338
+ return self._report_io_size_count(
2339
+ "fwd_output",
2340
+ invokers.lookup_sgd.invoke(common_args, self.optimizer_args),
2341
+ )
2342
+
2343
+ momentum1 = invokers.lookup_args.Momentum(
2344
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
2345
+ # `Union[Module, Tensor]`.
2346
+ dev=self.momentum1_dev,
2347
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
2348
+ # `Union[Module, Tensor]`.
2349
+ host=self.momentum1_host,
2350
+ # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
2351
+ # `Union[Module, Tensor]`.
2352
+ uvm=self.momentum1_uvm,
2353
+ # pyre-fixme[6]: For 4th argument expected `Tensor` but got
2354
+ # `Union[Module, Tensor]`.
2355
+ offsets=self.momentum1_offsets,
2356
+ # pyre-fixme[6]: For 5th argument expected `Tensor` but got
2357
+ # `Union[Module, Tensor]`.
2358
+ placements=self.momentum1_placements,
2359
+ )
2360
+
2361
+ if self.optimizer == OptimType.LARS_SGD:
2362
+ return self._report_io_size_count(
2363
+ "fwd_output",
2364
+ invokers.lookup_lars_sgd.invoke(
2365
+ common_args, self.optimizer_args, momentum1
2366
+ ),
2367
+ )
2368
+ if self.optimizer == OptimType.EXACT_ADAGRAD:
2369
+ return self._report_io_size_count(
2370
+ "fwd_output",
2371
+ invokers.lookup_adagrad.invoke(
2372
+ common_args, self.optimizer_args, momentum1
2373
+ ),
2374
+ )
2375
+
2376
+ momentum2 = invokers.lookup_args.Momentum(
2377
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
2378
+ # `Union[Module, Tensor]`.
2379
+ dev=self.momentum2_dev,
2380
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
2381
+ # `Union[Module, Tensor]`.
2382
+ host=self.momentum2_host,
2383
+ # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
2384
+ # `Union[Module, Tensor]`.
2385
+ uvm=self.momentum2_uvm,
2386
+ # pyre-fixme[6]: For 4th argument expected `Tensor` but got
2387
+ # `Union[Module, Tensor]`.
2388
+ offsets=self.momentum2_offsets,
2389
+ # pyre-fixme[6]: For 5th argument expected `Tensor` but got
2390
+ # `Union[Module, Tensor]`.
2391
+ placements=self.momentum2_placements,
2392
+ )
2393
+
2394
+ # Although self.iter_cpu is created on CPU, it might be transferred to the
2395
+ # GPU by the user, so we need to transfer it back to CPU explicitly. This
2396
+ # should be done only once.
2397
+ self.iter_cpu = self.iter_cpu.cpu()
2398
+
2399
+ # Sync with loaded state
2400
+ if (
2401
+ not is_torchdynamo_compiling()
2402
+ ): # wrap to make it compatible with PT2 compile
2403
+ if self.iter_cpu.item() == 0:
2404
+ self.iter_cpu.fill_(self.iter.cpu().item())
2405
+ # Increment the iteration counter
2406
+ iter_int = int(self.iter_cpu.add_(1).item()) # used for local computation
2407
+ self.iter.add_(1) # used for checkpointing
2408
+
2409
+ row_counter = invokers.lookup_args.Momentum(
2410
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
2411
+ # `Union[Module, Tensor]`.
2412
+ dev=self.row_counter_dev,
2413
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
2414
+ # `Union[Module, Tensor]`.
2415
+ host=self.row_counter_host,
2416
+ # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
2417
+ # `Union[Module, Tensor]`.
2418
+ uvm=self.row_counter_uvm,
2419
+ # pyre-fixme[6]: For 4th argument expected `Tensor` but got
2420
+ # `Union[Module, Tensor]`.
2421
+ offsets=self.row_counter_offsets,
2422
+ # pyre-fixme[6]: For 5th argument expected `Tensor` but got
2423
+ # `Union[Module, Tensor]`.
2424
+ placements=self.row_counter_placements,
2425
+ )
2426
+
2427
+ if self.optimizer == OptimType.ADAM:
2428
+ return self._report_io_size_count(
2429
+ "fwd_output",
2430
+ invokers.lookup_adam.invoke(
2431
+ common_args,
2432
+ self.optimizer_args,
2433
+ momentum1,
2434
+ momentum2,
2435
+ iter_int,
2436
+ row_counter=(
2437
+ row_counter if self.use_rowwise_bias_correction else None
2438
+ ),
2439
+ ),
2440
+ )
2441
+ if self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
2442
+ return self._report_io_size_count(
2443
+ "fwd_output",
2444
+ invokers.lookup_partial_rowwise_adam.invoke(
2445
+ common_args,
2446
+ self.optimizer_args,
2447
+ momentum1,
2448
+ momentum2,
2449
+ iter_int,
2450
+ ),
2451
+ )
2452
+ if self.optimizer == OptimType.LAMB:
2453
+ return self._report_io_size_count(
2454
+ "fwd_output",
2455
+ invokers.lookup_lamb.invoke(
2456
+ common_args,
2457
+ self.optimizer_args,
2458
+ momentum1,
2459
+ momentum2,
2460
+ iter_int,
2461
+ ),
2462
+ )
2463
+ if self.optimizer == OptimType.PARTIAL_ROWWISE_LAMB:
2464
+ return self._report_io_size_count(
2465
+ "fwd_output",
2466
+ invokers.lookup_partial_rowwise_lamb.invoke(
2467
+ common_args,
2468
+ self.optimizer_args,
2469
+ momentum1,
2470
+ momentum2,
2471
+ iter_int,
2472
+ ),
2473
+ )
2474
+
2475
+ prev_iter = invokers.lookup_args.Momentum(
2476
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
2477
+ # `Union[Module, Tensor]`.
2478
+ dev=self.prev_iter_dev,
2479
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
2480
+ # `Union[Module, Tensor]`.
2481
+ host=self.prev_iter_host,
2482
+ # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
2483
+ # `Union[Module, Tensor]`.
2484
+ uvm=self.prev_iter_uvm,
2485
+ # pyre-fixme[6]: For 4th argument expected `Tensor` but got
2486
+ # `Union[Module, Tensor]`.
2487
+ offsets=self.prev_iter_offsets,
2488
+ # pyre-fixme[6]: For 5th argument expected `Tensor` but got
2489
+ # `Union[Module, Tensor]`.
2490
+ placements=self.prev_iter_placements,
2491
+ )
2492
+
2493
+ if self.optimizer == OptimType.EMAINPLACE_ROWWISE_ADAGRAD:
2494
+ with torch.no_grad():
2495
+ if self.training:
2496
+ self.ema_inplace(self._emainplace_mode)
2497
+ return self._report_io_size_count(
2498
+ "fwd_output",
2499
+ invokers.lookup_rowwise_adagrad.invoke(
2500
+ common_args,
2501
+ self.optimizer_args,
2502
+ momentum1,
2503
+ ),
2504
+ )
2505
+
2506
+ if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
2507
+ assert self._feature_is_enabled(
2508
+ FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD
2509
+ ), "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!"
2510
+ with torch.no_grad():
2511
+ if self.training:
2512
+ self.ensemble_and_swap(self._ensemble_mode)
2513
+ return self._report_io_size_count(
2514
+ "fwd_output",
2515
+ invokers.lookup_rowwise_adagrad.invoke(
2516
+ common_args,
2517
+ self.optimizer_args,
2518
+ momentum1,
2519
+ ),
2520
+ )
2521
+
2522
+ if self._used_rowwise_adagrad_with_counter:
2523
+ if (
2524
+ self._max_counter_update_freq > 0
2525
+ and iter_int % self._max_counter_update_freq == 0
2526
+ ):
2527
+ row_counter_dev = self.row_counter_dev.detach()
2528
+ if row_counter_dev.numel() > 0:
2529
+ self.max_counter[0] = torch.max(row_counter_dev).cpu().item() + 1
2530
+ else:
2531
+ self.max_counter[0] = 1
2532
+
2533
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
2534
+ if self._used_rowwise_adagrad_with_counter:
2535
+ return self._report_io_size_count(
2536
+ "fwd_output",
2537
+ invokers.lookup_rowwise_adagrad_with_counter.invoke(
2538
+ common_args,
2539
+ self.optimizer_args,
2540
+ momentum1,
2541
+ prev_iter,
2542
+ row_counter,
2543
+ iter_int,
2544
+ self.max_counter.item(),
2545
+ mixed_D=self.mixed_D,
2546
+ ),
2547
+ )
2548
+ elif self._used_rowwise_adagrad_with_global_weight_decay:
2549
+ apply_global_weight_decay = (
2550
+ iter_int >= self.gwd_start_iter and self.training
2551
+ )
2552
+ return self._report_io_size_count(
2553
+ "fwd_output",
2554
+ invokers.lookup_rowwise_adagrad.invoke(
2555
+ common_args,
2556
+ self.optimizer_args,
2557
+ momentum1,
2558
+ iter=iter_int,
2559
+ apply_global_weight_decay=apply_global_weight_decay,
2560
+ # pyre-fixme[6]: For 6th argument expected
2561
+ # `Optional[Tensor]` but got `Union[Module, Tensor]`.
2562
+ prev_iter_dev=self.prev_iter_dev,
2563
+ gwd_lower_bound=self.gwd_lower_bound,
2564
+ mixed_D=self.mixed_D,
2565
+ ),
2566
+ )
2567
+ else:
2568
+ return self._report_io_size_count(
2569
+ "fwd_output",
2570
+ invokers.lookup_rowwise_adagrad.invoke(
2571
+ common_args,
2572
+ self.optimizer_args,
2573
+ momentum1,
2574
+ mixed_D=self.mixed_D,
2575
+ ),
2576
+ )
2577
+
2578
+ raise ValueError(f"Invalid OptimType: {self.optimizer}")
2579
+
2580
+ def ema_inplace(self, emainplace_mode: dict[str, float]) -> None:
2581
+ """
2582
+ Perform EMA operations on the full sparse embedding tables.
2583
+ We organize the sparse table in the following way.
2584
+
2585
+ Emb_table:
2586
+ -------------------------------------------------
2587
+ - -- -
2588
+ - Fast part -- Slow part -
2589
+ - (RL) main part -- target part -
2590
+ - -- -
2591
+ -------------------------------------------------
2592
+
2593
+ In every "step_ema" step, we perform
2594
+ slow_part += coef_ema * (fast_part - slow_part)
2595
+ """
2596
+ iter_int = int(self.iter_cpu.item())
2597
+ if iter_int % int(emainplace_mode["step_ema"]) == 0 and iter_int >= int(
2598
+ emainplace_mode["step_start"]
2599
+ ):
2600
+ weights = self.split_embedding_weights()
2601
+ for table_i, (_, dim, _, _) in enumerate(self.embedding_specs):
2602
+ assert (
2603
+ dim & 1 == 0
2604
+ ), f"table dimension {dim} is odd, not supported for ema_inplace_rowwise_adagrad" # make sure that the dimension is even
2605
+ weights[table_i][:, dim // 2 :].data.lerp_(
2606
+ weights[table_i][:, : dim // 2].data,
2607
+ emainplace_mode["step_ema_coef"],
2608
+ )
2609
+
2610
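+ # The update "slow_part += coef_ema * (fast_part - slow_part)" above is
+ # exactly Tensor.lerp_; a minimal sketch with made-up values:
+ #
+ #   slow = torch.zeros(2)
+ #   fast = torch.ones(2)
+ #   slow.lerp_(fast, 0.1)   # -> tensor([0.1000, 0.1000])
+ #
+ # In the table layout above, columns [dim // 2 :] hold the slow/target half
+ # and columns [: dim // 2] hold the fast/main half.
+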
+ def ensemble_and_swap(self, ensemble_mode: dict[str, float]) -> None:
2611
+ """
2612
+ Perform ensemble and swap operations on the full sparse embedding tables.
2613
+
2614
+ Returns:
2615
+ Sparse embedding weights and optimizer states will be updated in-place.
2616
+ """
2617
+ iter_int = int(self.iter_cpu.item())
2618
+ should_ema = iter_int % int(ensemble_mode["step_ema"]) == 0
2619
+ should_swap = iter_int % int(ensemble_mode["step_swap"]) == 0
2620
+ if should_ema or should_swap:
2621
+ weights = self.split_embedding_weights()
2622
+ states = self.split_optimizer_states()
2623
+ coef_ema = (
2624
+ 0.0
2625
+ if iter_int <= int(ensemble_mode["step_start"])
2626
+ else ensemble_mode["step_ema_coef"]
2627
+ )
2628
+ for i in range(len(self.embedding_specs)):
2629
+ # 0) copying weights from gpu to cpu
2630
+ weights_cpu = weights[i].to(
2631
+ dtype=states[i][1].dtype, device=states[i][1].device
2632
+ )
2633
+ # 1) ema step
2634
+ if should_ema:
2635
+ states[i][1].lerp_(weights_cpu, 1.0 - coef_ema)
2636
+ # 2) swap step
2637
+ if should_swap:
2638
+ weights[i].copy_(states[i][1], non_blocking=True)
2639
+ # 3) post-processing step
2640
+ if should_ema:
2641
+ if int(ensemble_mode["step_mode"]) == 0: # embedding scaling
2642
+ states[i][1].mul_(0.0)
2643
+ # elif int(ensemble_mode["step_mode"]) == 2: pure ema
2644
+
2645
+ def reset_uvm_cache_stats(self) -> None:
2646
+ assert (
2647
+ self.gather_uvm_cache_stats
2648
+ ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
2649
+ self.uvm_cache_stats.zero_()
2650
+ self.local_uvm_cache_stats.zero_()
2651
+
2652
+ def get_uvm_cache_stats(self, use_local_cache: bool = False) -> Tensor:
2653
+ assert (
2654
+ self.gather_uvm_cache_stats
2655
+ ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
2656
+ return self.local_uvm_cache_stats if use_local_cache else self.uvm_cache_stats
2657
+
2658
+ def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> list[float]:
2659
+ snapshot = self.get_uvm_cache_stats(use_local_cache)
2660
+ if use_local_cache:
2661
+ return snapshot.tolist()
2662
+
2663
+ # Stats are accumulated over multiple steps. Compute delta, and update state.
2664
+ delta = snapshot - self.last_uvm_cache_print_state
2665
+ self.last_uvm_cache_print_state = snapshot.clone()
2666
+ return delta.tolist()
2667
+
2668
+ @torch.jit.ignore
2669
+ def print_uvm_cache_stats(self, use_local_cache: bool = False) -> None:
2670
+ # TODO: Create a separate reporter class to unify the stdlog reporting
2671
+ uvm_cache_stats: list[float] = self._get_uvm_cache_print_state(use_local_cache)
2672
+ N = max(1, uvm_cache_stats[0])
2673
+ m = {
2674
+ "N_called": uvm_cache_stats[UVMCacheStatsIndex.num_calls],
2675
+ "requested_indices": uvm_cache_stats[
2676
+ UVMCacheStatsIndex.num_requested_indices
2677
+ ]
2678
+ / N,
2679
+ "unique_indices": uvm_cache_stats[UVMCacheStatsIndex.num_unique_indices]
2680
+ / N,
2681
+ "unique_misses": uvm_cache_stats[UVMCacheStatsIndex.num_unique_misses] / N,
2682
+ "conflict_unique_misses": uvm_cache_stats[
2683
+ UVMCacheStatsIndex.num_conflict_unique_misses
2684
+ ]
2685
+ / N,
2686
+ "conflict_misses": uvm_cache_stats[UVMCacheStatsIndex.num_conflict_misses]
2687
+ / N,
2688
+ }
2689
+ if uvm_cache_stats[1]:
2690
+ m.update(
2691
+ {
2692
+ "unique indices / requested indices": uvm_cache_stats[
2693
+ UVMCacheStatsIndex.num_unique_indices
2694
+ ]
2695
+ / uvm_cache_stats[UVMCacheStatsIndex.num_requested_indices],
2696
+ "unique misses / requested indices": uvm_cache_stats[
2697
+ UVMCacheStatsIndex.num_unique_misses
2698
+ ]
2699
+ / uvm_cache_stats[UVMCacheStatsIndex.num_requested_indices],
2700
+ }
2701
+ )
2702
+ self.log(f"uvm_cache_stats={m}")
2703
+
2704
+ @torch.jit.ignore
2705
+ def _report_uvm_cache_stats(self) -> None:
2706
+ if self.stats_reporter is None:
2707
+ return
2708
+ stats_reporter: TBEStatsReporter = self.stats_reporter
2709
+ passed_steps = self.step - self.last_reported_step
2710
+ if passed_steps == 0:
2711
+ return
2712
+ if not stats_reporter.should_report(self.step):
2713
+ return
2714
+
2715
+ uvm_cache_stats: list[float] = self.get_uvm_cache_stats(
2716
+ use_local_cache=False
2717
+ ).tolist()
2718
+ self.last_reported_step = self.step
2719
+
2720
+ if len(self.last_reported_uvm_stats) == 0:
2721
+ self.last_reported_uvm_stats = [0.0] * len(uvm_cache_stats)
2722
+ uvm_cache_stats_delta: list[float] = [0.0] * len(uvm_cache_stats)
2723
+ for i in range(len(uvm_cache_stats)):
2724
+ uvm_cache_stats_delta[i] = (
2725
+ uvm_cache_stats[i] - self.last_reported_uvm_stats[i]
2726
+ )
2727
+ self.last_reported_uvm_stats = uvm_cache_stats
2728
+
2729
+ # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
2730
+ # a function.
2731
+ element_size = self.lxu_cache_weights.element_size()
2732
+ for stat_index in UVMCacheStatsIndex:
2733
+ stats_reporter.report_data_amount(
2734
+ iteration_step=self.step,
2735
+ event_name=f"tbe.prefetch.cache_stats_by_data_size.{stat_index.name.lower()}",
2736
+ data_bytes=int(
2737
+ uvm_cache_stats_delta[stat_index.value]
2738
+ * element_size
2739
+ * self.max_D_cache
2740
+ / passed_steps
2741
+ ),
2742
+ embedding_id=self.logging_table_name,
2743
+ tbe_id=self.uuid,
2744
+ )
2745
+
2746
+ def prefetch(
2747
+ self,
2748
+ indices: Tensor,
2749
+ offsets: Tensor,
2750
+ forward_stream: Optional[torch.cuda.Stream] = None,
2751
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
2752
+ ) -> None:
2753
+ if self.prefetch_stream is None and forward_stream is not None:
2754
+ self.prefetch_stream = torch.cuda.current_stream()
2755
+ assert (
2756
+ self.prefetch_stream != forward_stream
2757
+ ), "prefetch_stream and forward_stream should not be the same stream"
2758
+
2759
+ indices, offsets, _, vbe_metadata = self.prepare_inputs(
2760
+ indices,
2761
+ offsets,
2762
+ per_sample_weights=None,
2763
+ batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
2764
+ force_cast_input_types=False,
2765
+ prefetch_pipeline=self.prefetch_pipeline,
2766
+ )
2767
+
2768
+ with self._recording_to_timer(
2769
+ self.prefetch_duration_timer,
2770
+ context=self.step,
2771
+ stream=torch.cuda.current_stream(),
2772
+ ):
2773
+ self._prefetch(
2774
+ indices,
2775
+ offsets,
2776
+ vbe_metadata,
2777
+ multipass_prefetch_config=self.multipass_prefetch_config,
2778
+ )
2779
+
2780
+ if forward_stream is not None:
2781
+ self._prefetch_tensors_record_stream(forward_stream)
2782
+
2783
+ def _prefetch(
2784
+ self,
2785
+ indices: Tensor,
2786
+ offsets: Tensor,
2787
+ vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
2788
+ multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
2789
+ hash_zch_identities: Optional[Tensor] = None,
2790
+ hash_zch_runtime_meta: Optional[Tensor] = None,
2791
+ ) -> None:
2792
+ if not is_torchdynamo_compiling():
2793
+ # Mutating an nn.Module attr forces a dynamo restart of Analysis, which increases compilation time
2794
+ self.timestep += 1
2795
+ self.timesteps_prefetched.append(self.timestep)
2796
+
2797
+ # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
2798
+ # a function.
2799
+ if not self.lxu_cache_weights.numel():
2800
+ return
2801
+
2802
+ # Clear the local_uvm_cache_stats before the prefetch instead of after
2803
+ # the prefetch step, since it will be used in the CommonArgs in the
2804
+ # forward step
2805
+ if self.gather_uvm_cache_stats:
2806
+ self.local_uvm_cache_stats.zero_()
2807
+ self._report_io_size_count("prefetch_input", indices)
2808
+
2809
+ # streaming before updating the cache
2810
+ self.raw_embedding_stream()
2811
+
2812
+ final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
2813
+ linear_cache_indices_merged = torch.zeros(
2814
+ 0, dtype=indices.dtype, device=indices.device
2815
+ )
2816
+ for (
2817
+ partial_indices,
2818
+ partial_lxu_cache_locations,
2819
+ base_offset,
2820
+ ) in self.get_prefetch_passes(
2821
+ multipass_prefetch_config, indices, final_lxu_cache_locations
2822
+ ):
2823
+ linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
2824
+ self.cache_hash_size_cumsum,
2825
+ partial_indices,
2826
+ offsets,
2827
+ vbe_metadata.B_offsets if vbe_metadata is not None else None,
2828
+ vbe_metadata.max_B if vbe_metadata is not None else -1,
2829
+ base_offset,
2830
+ )
2831
+ linear_cache_indices_merged = torch.cat(
2832
+ [linear_cache_indices_merged, linear_cache_indices]
2833
+ )
2834
+
2835
+ if (
2836
+ self.record_cache_metrics.record_cache_miss_counter
2837
+ or self.record_cache_metrics.record_tablewise_cache_miss
2838
+ ):
2839
+ lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
2840
+ linear_cache_indices,
2841
+ self.lxu_cache_state,
2842
+ self.total_cache_hash_size,
2843
+ self.gather_uvm_cache_stats,
2844
+ self.local_uvm_cache_stats,
2845
+ )
2846
+ if self.record_cache_metrics.record_cache_miss_counter:
2847
+ self._update_cache_miss_counter(
2848
+ lxu_cache_locations, linear_cache_indices
2849
+ )
2850
+ if self.record_cache_metrics.record_tablewise_cache_miss:
2851
+ self._update_tablewise_cache_miss(
2852
+ lxu_cache_locations, linear_cache_indices, offsets
2853
+ )
2854
+
2855
+ if self.cache_algorithm == CacheAlgorithm.LRU:
2856
+ torch.ops.fbgemm.lru_cache_populate(
2857
+ self.weights_uvm,
2858
+ self.cache_hash_size_cumsum,
2859
+ self.total_cache_hash_size,
2860
+ self.cache_index_table_map,
2861
+ self.weights_offsets,
2862
+ self.D_offsets,
2863
+ linear_cache_indices,
2864
+ self.lxu_cache_state,
2865
+ self.lxu_cache_weights,
2866
+ self.timestep,
2867
+ self.lxu_state,
2868
+ self.stochastic_rounding,
2869
+ self.gather_uvm_cache_stats,
2870
+ self.local_uvm_cache_stats,
2871
+ self.lock_cache_line,
2872
+ self.lxu_cache_locking_counter,
2873
+ )
2874
+ elif self.cache_algorithm == CacheAlgorithm.LFU:
2875
+ torch.ops.fbgemm.lfu_cache_populate(
2876
+ self.weights_uvm,
2877
+ self.cache_hash_size_cumsum,
2878
+ self.total_cache_hash_size,
2879
+ self.cache_index_table_map,
2880
+ self.weights_offsets,
2881
+ self.D_offsets,
2882
+ linear_cache_indices,
2883
+ self.lxu_cache_state,
2884
+ self.lxu_cache_weights,
2885
+ self.lxu_state,
2886
+ self.stochastic_rounding,
2887
+ )
2888
+
2889
+ torch.ops.fbgemm.lxu_cache_lookup(
2890
+ linear_cache_indices,
2891
+ self.lxu_cache_state,
2892
+ self.total_cache_hash_size,
2893
+ self.gather_uvm_cache_stats,
2894
+ self.local_uvm_cache_stats,
2895
+ lxu_cache_locations_output=partial_lxu_cache_locations,
2896
+ )
2897
+
2898
+ assert (
2899
+ len(self.lxu_cache_locations_list) < self.max_prefetch_depth
2900
+ ), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"
2901
+ self.lxu_cache_locations_list.append(final_lxu_cache_locations)
2902
+
2903
+ if self.gather_uvm_cache_stats:
2904
+ # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
2905
+ # We may want to do this accumulation atomically, but as it's only
2906
+ # for monitoring, slightly inaccurate result may be acceptable.
2907
+ self.uvm_cache_stats = torch.add(
2908
+ self.uvm_cache_stats, self.local_uvm_cache_stats
2909
+ )
2910
+ self._report_uvm_cache_stats()
2911
+ if self.should_log():
2912
+ self.print_uvm_cache_stats(use_local_cache=False)
2913
+
2914
+ self._store_prefetched_tensors(
2915
+ indices,
2916
+ offsets,
2917
+ vbe_metadata,
2918
+ linear_cache_indices_merged,
2919
+ final_lxu_cache_locations,
2920
+ hash_zch_identities,
2921
+ hash_zch_runtime_meta,
2922
+ )
2923
+
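+ # Illustrative call pattern (not part of the module; assumes `tbe` is an
+ # instance of this class and refers to the public prefetch()/forward()
+ # wrappers mentioned in the assertion above): with pipelining, every
+ # prefetch() must be matched by a forward(), otherwise the
+ # max_prefetch_depth assertion fires.
+ #
+ #     tbe.prefetch(indices_next, offsets_next)   # stage batch i+1
+ #     out = tbe(indices_cur, offsets_cur)        # forward batch i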
2924
+ def should_log(self) -> bool:
2925
+ """Determines if we should log for this step, using exponentially decreasing frequency.
2926
+
2927
+ Logs for steps: 100 200 ... 1,000 2,000 ... 10,000 20,000 ... 100,000 200,000 ...
2928
+ """
2929
+ s = self.step + 1 # step starts at 0
2930
+ return s >= 100 and s % (10 ** int(math.log10(s))) == 0
2931
+
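+ # Illustrative sketch (not part of the module): the predicate above fires
+ # when (step + 1) is a multiple of its own leading power of ten, i.e. at
+ # step counts 100, 200, ..., 1,000, 2,000, ... It can be checked in
+ # isolation:
+ #
+ #     import math
+ #     def would_log(step: int) -> bool:
+ #         s = step + 1
+ #         return s >= 100 and s % (10 ** int(math.log10(s))) == 0
+ #     assert would_log(99) and would_log(1999) and not would_log(150)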
2932
+ def _prefetch_tensors_record_stream(
2933
+ self, forward_stream: torch.cuda.Stream
2934
+ ) -> None:
2935
+ # Record the tensors created by prefetch stream and consumed by forward/backward
2936
+ # to the forward stream. In PyTorch, each backward CUDA op runs on the same
2937
+ # stream that was used for its corresponding forward op.
2938
+
2939
+ for t in self.lxu_cache_locations_list:
2940
+ t.record_stream(forward_stream)
2941
+
2942
+ def _update_cache_miss_counter(
2943
+ self,
2944
+ lxu_cache_locations: Tensor,
2945
+ linear_cache_indices: Tensor,
2946
+ ) -> None:
2947
+ CACHE_MISS = -1
2948
+ CACHE_HIT = -2
2949
+
2950
+ cache_missed_locations = torch.where(
2951
+ lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT
2952
+ )
2953
+ unique_ids_list = torch.unique(cache_missed_locations)
2954
+ unique_ids_count_list = torch.where(unique_ids_list == CACHE_HIT, 0, 1)
2955
+
2956
+ miss_count = torch.sum(unique_ids_count_list)
2957
+
2958
+ # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
2959
+ # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
2960
+ self.cache_miss_counter[0] += (miss_count > 0).to(torch.int64)
2961
+
2962
+ # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
2963
+ # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
2964
+ self.cache_miss_counter[1] += miss_count
2965
+
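+ # Illustrative worked example (hypothetical values, not part of the
+ # module): the sentinel trick above counts unique missed ids. With
+ # lxu_cache_locations = [-1, 5, -1, -1] and linear_cache_indices =
+ # [10, 11, 10, 12], the miss/hit mapping gives [10, -2, 10, 12];
+ # unique() yields [-2, 10, 12], and summing the non-(-2) entries gives a
+ # miss_count of 2, so cache_miss_counter becomes
+ # [counter[0] + 1, counter[1] + 2].
+ #
+ #     locs = torch.tensor([-1, 5, -1, -1])
+ #     idx = torch.tensor([10, 11, 10, 12])
+ #     missed = torch.where(locs == -1, idx, -2)
+ #     miss_count = torch.sum(torch.where(torch.unique(missed) == -2, 0, 1))
+ #     assert miss_count.item() == 2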
2966
+ def _update_tablewise_cache_miss(
2967
+ self,
2968
+ lxu_cache_locations: Tensor,
2969
+ linear_cache_indices: Tensor,
2970
+ offsets: Tensor,
2971
+ ) -> None:
2972
+ CACHE_MISS = -1
2973
+ CACHE_HIT = -2
2974
+
2975
+ # pyre-fixme[6]: For 1st argument expected
2976
+ # `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Module, Tensor]`.
2977
+ num_tables = len(self.cache_hash_size_cumsum) - 1
2978
+ num_offsets_per_table = (len(offsets) - 1) // num_tables
2979
+ cache_missed_locations = torch.where(
2980
+ lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT
2981
+ )
2982
+
2983
+ for i in range(num_tables):
2984
+ start = offsets[i * num_offsets_per_table]
2985
+ end = offsets[(i + 1) * num_offsets_per_table]
2986
+
2987
+ current_cache_missed_locations = cache_missed_locations[start:end]
2988
+ unique_ids_list = torch.unique(current_cache_missed_locations)
2989
+ unique_ids_count_list = torch.where(unique_ids_list == CACHE_HIT, 0, 1)
2990
+
2991
+ miss_count = torch.sum(unique_ids_count_list)
2992
+
2993
+ self.table_wise_cache_miss[i] += miss_count
2994
+
2995
+ def init_embedding_weights_uniform(self, min_val: float, max_val: float) -> None:
2996
+ splits = self.split_embedding_weights()
2997
+ if self.weights_precision == SparseType.INT8:
2998
+ # TODO: add in-place FloatToFused8BitRowwiseQuantized conversion
2999
+ for emb in splits:
3000
+ assert (
3001
+ len(emb.shape) == 2
3002
+ ), "Int8 embedding only supported for 2D weight tensors."
3003
+ shape = [emb.shape[0], emb.shape[1] - self.int8_emb_row_dim_offset]
3004
+ tmp_emb = torch.zeros(shape, device=self.current_device)
3005
+ tmp_emb.uniform_(min_val, max_val)
3006
+ tmp_emb_i8 = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(tmp_emb)
3007
+ emb.data.copy_(tmp_emb_i8)
3008
+ # Torch doesn't implement direct fp8 distribution functions, so we need to start in higher precision.
3009
+ elif self.weights_precision == SparseType.NFP8:
3010
+ assert (
3011
+ self.current_device.type == "cuda"
3012
+ ), "NFP8 is currently only supportd on GPU."
3013
+ assert self.optimizer in [
3014
+ OptimType.EXACT_ADAGRAD,
3015
+ OptimType.ROWWISE_ADAGRAD,
3016
+ OptimType.EXACT_ROWWISE_ADAGRAD,
3017
+ OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
3018
+ OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
3019
+ ], "NFP8 is currently only supportd with adagrad optimizers."
3020
+ for param in splits:
3021
+ tmp_param = torch.zeros(param.shape, device=self.current_device)
3022
+ # Create initialized weights and cast to fp8.
3023
+ fp8_dtype = (
3024
+ torch.float8_e4m3fnuz
3025
+ if torch.version.hip is not None
3026
+ else torch.float8_e4m3fn
3027
+ )
3028
+ tmp_param.uniform_(min_val, max_val).to(fp8_dtype)
3029
+ param.data.copy_(tmp_param)
3030
+ else:
3031
+ for param in splits:
3032
+ param.uniform_(min_val, max_val)
3033
+
3034
+ @torch.jit.ignore
3035
+ def split_embedding_weights(self) -> list[Tensor]:
3036
+ """
3037
+ Returns a list of embedding weights (view), split by table
3038
+
3039
+ Returns:
3040
+ A list of weights. Length = the number of tables
3041
+ """
3042
+ splits = []
3043
+ for t, (rows, dim, _, _) in enumerate(self.embedding_specs):
3044
+ if self.weights_precision == SparseType.INT8:
3045
+ dim += self.int8_emb_row_dim_offset
3046
+ # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
3047
+ placement = self.weights_physical_placements[t]
3048
+ # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
3049
+ offset = self.weights_physical_offsets[t]
3050
+ if placement == EmbeddingLocation.DEVICE.value:
3051
+ weights = self.weights_dev
3052
+ elif placement == EmbeddingLocation.HOST.value:
3053
+ weights = self.weights_host
3054
+ else:
3055
+ weights = self.weights_uvm
3056
+ # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is
3057
+ # not a function.
3058
+ if weights.dim() == 2:
3059
+ weights = weights.flatten()
3060
+ splits.append(
3061
+ weights.detach()[offset : offset + rows * dim].view(rows, dim)
3062
+ )
3063
+ return splits
3064
+
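+ # Illustrative usage sketch (assumes `tbe` is an already-constructed
+ # instance of this class; not part of the module): the returned per-table
+ # views can be inspected or initialized table by table.
+ #
+ #     for t, w in enumerate(tbe.split_embedding_weights()):
+ #         print(f"table {t}: shape={tuple(w.shape)}, dtype={w.dtype}")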
3065
+ @torch.jit.ignore
3066
+ def get_optimizer_buffer(self, state: str) -> torch.Tensor:
3067
+ if self.optimizer == OptimType.NONE:
3068
+ raise NotImplementedError(
3069
+ f"Getting optimizer buffer is not supported for {self.optimizer}"
3070
+ )
3071
+ for name, buffer in self.named_buffers():
3072
+ if name == state:
3073
+ return buffer
3074
+ raise ValueError(f"Optimizer buffer {state} not found")
3075
+
3076
+ @torch.jit.export
3077
+ def get_optimizer_state(self) -> list[dict[str, torch.Tensor]]:
3078
+ r"""
3079
+ Get the optimizer state dict that matches the OSS PyTorch optimizers
3080
+ TODO: populate the supported list of optimizers
3081
+ """
3082
+ split_optimizer_states = self.split_optimizer_states()
3083
+ if (
3084
+ self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
3085
+ or self.optimizer == OptimType.EXACT_ADAGRAD
3086
+ or self.optimizer == OptimType.EMAINPLACE_ROWWISE_ADAGRAD
3087
+ ):
3088
+ list_of_state_dict = [
3089
+ (
3090
+ (
3091
+ {
3092
+ "sum": states[0],
3093
+ "prev_iter": states[1],
3094
+ "row_counter": states[2],
3095
+ "iter": self.iter,
3096
+ }
3097
+ if self.optimizer_args.regularization_mode
3098
+ == WeightDecayMode.COUNTER.value
3099
+ and self.optimizer_args.weight_decay_mode
3100
+ == CounterWeightDecayMode.ADAGRADW.value
3101
+ else {
3102
+ "sum": states[0],
3103
+ "prev_iter": states[1],
3104
+ "row_counter": states[2],
3105
+ }
3106
+ )
3107
+ if self._used_rowwise_adagrad_with_counter
3108
+ else (
3109
+ {
3110
+ "sum": states[0],
3111
+ "prev_iter": states[1],
3112
+ "iter": self.iter,
3113
+ }
3114
+ if self._used_rowwise_adagrad_with_global_weight_decay
3115
+ else {"sum": states[0]}
3116
+ )
3117
+ )
3118
+ for states in split_optimizer_states
3119
+ ]
3120
+ elif self.optimizer == OptimType.SGD or self.optimizer == OptimType.EXACT_SGD:
3121
+ list_of_state_dict = [
3122
+ {"momentum_buffer": states[0]} for states in split_optimizer_states
3123
+ ]
3124
+ elif self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction:
3125
+ list_of_state_dict = [
3126
+ {
3127
+ "exp_avg": states[0],
3128
+ "exp_avg_sq": states[1],
3129
+ "row_counter": states[2],
3130
+ }
3131
+ for states in split_optimizer_states
3132
+ ]
3133
+ elif (
3134
+ self.optimizer == OptimType.ADAM
3135
+ or self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
3136
+ or self.optimizer == OptimType.LAMB
3137
+ or self.optimizer == OptimType.PARTIAL_ROWWISE_LAMB
3138
+ ):
3139
+ list_of_state_dict = [
3140
+ {"exp_avg": states[0], "exp_avg_sq": states[1]}
3141
+ for states in split_optimizer_states
3142
+ ]
3143
+ elif self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
3144
+ list_of_state_dict = [
3145
+ {
3146
+ "sum": states[0],
3147
+ "sparse_ema": states[1],
3148
+ }
3149
+ for states in split_optimizer_states
3150
+ ]
3151
+ else:
3152
+ raise NotImplementedError(
3153
+ f"Getting optimizer state {self.optimizer} is not implmeneted"
3154
+ )
3155
+
3156
+ return list_of_state_dict
3157
+
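+ # Illustrative usage sketch (assumes `tbe` was constructed with
+ # optimizer=OptimType.EXACT_ROWWISE_ADAGRAD; not part of the module): in
+ # that case each per-table dict carries at least the row-wise "sum" state.
+ #
+ #     for t, state in enumerate(tbe.get_optimizer_state()):
+ #         print(f"table {t}: sum shape={tuple(state['sum'].shape)}")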
3158
+ @torch.jit.ignore
3159
+ def split_optimizer_states(
3160
+ self,
3161
+ ) -> list[list[torch.Tensor]]:
3162
+ """
3163
+ Returns a list of optimizer states (view), split by table
3164
+
3165
+ Returns:
3166
+ A list of list of states. Shape = (the number of tables, the number
3167
+ of states).
3168
+
3169
+ The following shows the list of states (in the returned order) for
3170
+ each optimizer:
3171
+
3172
+ (1) `ADAM`: `momentum1`, `momentum2`
3173
+
3174
+ (2) `EXACT_ADAGRAD`: `momentum1`
3175
+
3176
+ (3) `EXACT_ROWWISE_ADAGRAD`: `momentum1` (rowwise), `prev_iter`
3177
+ (rowwise; only when using `WeightDecayMode` = `COUNTER` or
3178
+ `COWCLIP` or `global_weight_decay` is not None), `row_counter`
3179
+ (rowwise; only when using `WeightDecayMode` = `COUNTER` or
3180
+ `COWCLIP`)
3181
+
3182
+ (4) `EXACT_SGD`: no states
3183
+
3184
+ (5) `LAMB`: `momentum1`, `momentum2`
3185
+
3186
+ (6) `LARS_SGD`: `momentum1`
3187
+
3188
+ (7) `PARTIAL_ROWWISE_ADAM`: `momentum1`, `momentum2` (rowwise)
3189
+
3190
+ (8) `PARTIAL_ROWWISE_LAMB`: `momentum1`, `momentum2` (rowwise)
3191
+
3192
+ (9) `ENSEMBLE_ROWWISE_ADAGRAD`: `momentum1` (rowwise), `momentum2`
3193
+
3194
+ (10) `NONE`: no states (throwing an error)
3195
+
3196
+ """
3197
+ if self.optimizer == OptimType.NONE:
3198
+ raise NotImplementedError(
3199
+ f"Getting optimizer states is not supported for {self.optimizer}"
3200
+ )
3201
+
3202
+ def get_optimizer_states(
3203
+ state_dev: Tensor,
3204
+ state_host: Tensor,
3205
+ state_uvm: Tensor,
3206
+ state_offsets: Tensor,
3207
+ state_placements: Tensor,
3208
+ rowwise: bool,
3209
+ ) -> list[torch.Tensor]:
3210
+ splits = []
3211
+ for t, (rows, dim, _, _) in enumerate(self.embedding_specs):
3212
+ offset = state_offsets[t]
3213
+ placement = state_placements[t]
3214
+ if placement == EmbeddingLocation.DEVICE:
3215
+ state = state_dev
3216
+ elif placement == EmbeddingLocation.HOST:
3217
+ state = state_host
3218
+ else:
3219
+ state = state_uvm
3220
+ if not rowwise:
3221
+ splits.append(
3222
+ state.detach()[offset : offset + rows * dim].view(rows, dim)
3223
+ )
3224
+ else:
3225
+ splits.append(state.detach()[offset : offset + rows].view(rows))
3226
+ return splits
3227
+
3228
+ states: list[list[torch.Tensor]] = []
3229
+ if self.optimizer not in (OptimType.EXACT_SGD,):
3230
+ states.append(
3231
+ get_optimizer_states(
3232
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
3233
+ # `Union[Module, Tensor]`.
3234
+ self.momentum1_dev,
3235
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
3236
+ # `Union[Module, Tensor]`.
3237
+ self.momentum1_host,
3238
+ # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
3239
+ # `Union[Module, Tensor]`.
3240
+ self.momentum1_uvm,
3241
+ # pyre-fixme[6]: For 4th argument expected `Tensor` but got
3242
+ # `Union[Module, Tensor]`.
3243
+ self.momentum1_physical_offsets,
3244
+ # pyre-fixme[6]: For 5th argument expected `Tensor` but got
3245
+ # `Union[Module, Tensor]`.
3246
+ self.momentum1_physical_placements,
3247
+ rowwise=self.optimizer
3248
+ in [
3249
+ OptimType.EXACT_ROWWISE_ADAGRAD,
3250
+ OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
3251
+ OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
3252
+ ],
3253
+ )
3254
+ )
3255
+ if self.optimizer in (
3256
+ OptimType.ADAM,
3257
+ OptimType.PARTIAL_ROWWISE_ADAM,
3258
+ OptimType.LAMB,
3259
+ OptimType.PARTIAL_ROWWISE_LAMB,
3260
+ OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
3261
+ ):
3262
+ states.append(
3263
+ get_optimizer_states(
3264
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
3265
+ # `Union[Module, Tensor]`.
3266
+ self.momentum2_dev,
3267
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
3268
+ # `Union[Module, Tensor]`.
3269
+ self.momentum2_host,
3270
+ # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
3271
+ # `Union[Module, Tensor]`.
3272
+ self.momentum2_uvm,
3273
+ # pyre-fixme[6]: For 4th argument expected `Tensor` but got
3274
+ # `Union[Module, Tensor]`.
3275
+ self.momentum2_physical_offsets,
3276
+ # pyre-fixme[6]: For 5th argument expected `Tensor` but got
3277
+ # `Union[Module, Tensor]`.
3278
+ self.momentum2_physical_placements,
3279
+ rowwise=self.optimizer
3280
+ in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.PARTIAL_ROWWISE_LAMB),
3281
+ )
3282
+ )
3283
+ if (
3284
+ self._used_rowwise_adagrad_with_counter
3285
+ or self._used_rowwise_adagrad_with_global_weight_decay
3286
+ ):
3287
+ states.append(
3288
+ get_optimizer_states(
3289
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
3290
+ # `Union[Module, Tensor]`.
3291
+ self.prev_iter_dev,
3292
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
3293
+ # `Union[Module, Tensor]`.
3294
+ self.prev_iter_host,
3295
+ # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
3296
+ # `Union[Module, Tensor]`.
3297
+ self.prev_iter_uvm,
3298
+ # pyre-fixme[6]: For 4th argument expected `Tensor` but got
3299
+ # `Union[Module, Tensor]`.
3300
+ self.prev_iter_physical_offsets,
3301
+ # pyre-fixme[6]: For 5th argument expected `Tensor` but got
3302
+ # `Union[Module, Tensor]`.
3303
+ self.prev_iter_physical_placements,
3304
+ rowwise=True,
3305
+ )
3306
+ )
3307
+ if self._used_rowwise_adagrad_with_counter or (
3308
+ self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction
3309
+ ):
3310
+ states.append(
3311
+ get_optimizer_states(
3312
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
3313
+ # `Union[Module, Tensor]`.
3314
+ self.row_counter_dev,
3315
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
3316
+ # `Union[Module, Tensor]`.
3317
+ self.row_counter_host,
3318
+ # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
3319
+ # `Union[Module, Tensor]`.
3320
+ self.row_counter_uvm,
3321
+ # pyre-fixme[6]: For 4th argument expected `Tensor` but got
3322
+ # `Union[Module, Tensor]`.
3323
+ self.row_counter_physical_offsets,
3324
+ # pyre-fixme[6]: For 5th argument expected `Tensor` but got
3325
+ # `Union[Module, Tensor]`.
3326
+ self.row_counter_physical_placements,
3327
+ rowwise=True,
3328
+ )
3329
+ )
3330
+ return_states = [list(s) for s in zip(*states)]
3331
+ return return_states
3332
+
3333
+ @torch.jit.export
3334
+ def set_learning_rate(self, lr: float) -> None:
3335
+ """
3336
+ Sets the learning rate.
3337
+
3338
+ Args:
3339
+ lr (float): The learning rate value to set
3340
+ """
3341
+ if self.optimizer == OptimType.NONE:
3342
+ raise NotImplementedError(
3343
+ f"Setting learning rate is not supported for {self.optimizer}"
3344
+ )
3345
+ self._set_learning_rate(lr)
3346
+
3347
+ def get_learning_rate(self) -> float:
3348
+ """
3349
+ Get and return the learning rate.
3350
+ """
3351
+ return self.learning_rate_tensor.item()
3352
+
3353
+ @torch.jit.ignore
3354
+ def update_hyper_parameters(self, params_dict: dict[str, float]) -> None:
3355
+ """
3356
+ Sets hyper-parameters from external control flow.
3357
+
3358
+ Args:
3359
+ params_dict (Dict[str, float]): The dict that contains the
3360
+ hyper-parameter names and their values
3361
+ """
3362
+ if self.optimizer == OptimType.NONE:
3363
+ raise NotImplementedError(
3364
+ f"Setting learning rate is not supported for {self.optimizer}"
3365
+ )
3366
+ for parameter_name, value in params_dict.items():
3367
+ if parameter_name == "lr":
3368
+ self._set_learning_rate(value)
3369
+ elif parameter_name == "eps":
3370
+ self.optimizer_args = self.optimizer_args._replace(eps=value)
3371
+ elif parameter_name == "beta1":
3372
+ self.optimizer_args = self.optimizer_args._replace(beta1=value)
3373
+ elif parameter_name == "beta2":
3374
+ self.optimizer_args = self.optimizer_args._replace(beta2=value)
3375
+ elif parameter_name == "weight_decay":
3376
+ self.optimizer_args = self.optimizer_args._replace(weight_decay=value)
3377
+ elif parameter_name == "lower_bound":
3378
+ self.gwd_lower_bound = value
3379
+ else:
3380
+ raise NotImplementedError(
3381
+ f"Setting hyper-parameter {parameter_name} is not supported"
3382
+ )
3383
+
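+ # Illustrative usage sketch (hypothetical values; assumes `tbe` is an
+ # instance of this class; not part of the module): only the parameter
+ # names handled above are accepted, anything else raises
+ # NotImplementedError.
+ #
+ #     tbe.update_hyper_parameters({"lr": 0.01, "eps": 1.0e-8, "weight_decay": 1.0e-4})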
3384
+ @torch.jit.ignore
3385
+ def _set_learning_rate(self, lr: float) -> float:
3386
+ """
3387
+ Helper function to script `set_learning_rate`.
3388
+ Note that returning None does not work.
3389
+ """
3390
+ self.learning_rate_tensor.fill_(lr)
3391
+ return 0.0
3392
+
3393
+ @torch.jit.ignore
3394
+ def set_optimizer_step(self, step: int) -> None:
3395
+ """
3396
+ Sets the optimizer step.
3397
+
3398
+ Args:
3399
+ step (int): The step value to set
3400
+ """
3401
+ self.log(f"set_optimizer_step from {self.iter[0]=} to {step=}")
3402
+ if self.optimizer == OptimType.NONE:
3403
+ raise NotImplementedError(
3404
+ f"Setting optimizer step is not supported for {self.optimizer}"
3405
+ )
3406
+ self.iter[0] = step
3407
+
3408
+ @torch.jit.export
3409
+ def flush(self) -> None:
3410
+ # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
3411
+ # a function.
3412
+ if not self.lxu_cache_weights.numel():
3413
+ return
3414
+ torch.ops.fbgemm.lxu_cache_flush(
3415
+ self.weights_uvm,
3416
+ self.cache_hash_size_cumsum,
3417
+ self.cache_index_table_map,
3418
+ self.weights_offsets,
3419
+ self.D_offsets,
3420
+ self.total_D,
3421
+ self.lxu_cache_state,
3422
+ self.lxu_cache_weights,
3423
+ self.stochastic_rounding,
3424
+ )
3425
+
3426
+ def _apply_split(
3427
+ self,
3428
+ split: SplitState,
3429
+ prefix: str,
3430
+ dtype: type[torch.dtype],
3431
+ enforce_hbm: bool = False,
3432
+ make_dev_param: bool = False,
3433
+ dev_reshape: Optional[tuple[int, ...]] = None,
3434
+ uvm_host_mapped: bool = False,
3435
+ ) -> None:
3436
+ apply_split_helper(
3437
+ self.register_buffer,
3438
+ functools.partial(setattr, self),
3439
+ self.current_device,
3440
+ self.use_cpu,
3441
+ self.feature_table_map,
3442
+ split,
3443
+ prefix,
3444
+ dtype,
3445
+ enforce_hbm,
3446
+ make_dev_param,
3447
+ dev_reshape,
3448
+ self._uvm_tensors_log,
3449
+ uvm_host_mapped=uvm_host_mapped,
3450
+ )
3451
+
3452
+ def _apply_cache_state(
3453
+ self,
3454
+ cache_state: CacheState,
3455
+ cache_algorithm: CacheAlgorithm,
3456
+ cache_load_factor: float,
3457
+ cache_sets: int,
3458
+ cache_reserved_memory: float,
3459
+ cache_precision: SparseType,
3460
+ ) -> None:
3461
+ self.cache_algorithm = cache_algorithm
3462
+ self.timestep = 1
3463
+ self.timesteps_prefetched = []
3464
+
3465
+ self.max_prefetch_depth = MAX_PREFETCH_DEPTH
3466
+ self.lxu_cache_locations_list = []
3467
+ self.lxu_cache_locations_empty = torch.empty(
3468
+ 0, device=self.current_device, dtype=torch.int32
3469
+ ).fill_(-1)
3470
+ self.lxu_cache_locations = self.lxu_cache_locations_empty
3471
+ self._indices = self.lxu_cache_locations_empty
3472
+ self._offsets = self.lxu_cache_locations_empty
3473
+ self._vbe_B_offsets = self.lxu_cache_locations_empty
3474
+ self._vbe_max_B = -1
3475
+ self.prefetch_stream: Optional[torch.cuda.Stream] = None
3476
+
3477
+ self._init_uvm_cache_stats()
3478
+
3479
+ if cache_precision == SparseType.FP32:
3480
+ dtype = torch.float32
3481
+ elif cache_precision == SparseType.FP16:
3482
+ dtype = torch.float16
3483
+ elif cache_precision == SparseType.NFP8:
3484
+ # NFP8 weights use floating point cache.
3485
+ dtype = torch.float16
3486
+ else:
3487
+ dtype = torch.float32 # not relevant, but setting it to keep linter happy
3488
+ if not self.use_cpu:
3489
+ raise AssertionError(
3490
+ f"cache_precision {cache_precision} not supported!"
3491
+ )
3492
+
3493
+ # NOTE: no cache for CPU mode!
3494
+ if cache_state.total_cache_hash_size == 0 or self.use_cpu:
3495
+ self.register_buffer(
3496
+ "lxu_cache_weights",
3497
+ torch.zeros(0, 0, device=self.current_device, dtype=dtype),
3498
+ )
3499
+ # NOTE: make TorchScript work!
3500
+ self.register_buffer(
3501
+ "cache_hash_size_cumsum",
3502
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
3503
+ persistent=False,
3504
+ )
3505
+ self.register_buffer(
3506
+ "total_cache_hash_size",
3507
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
3508
+ persistent=False,
3509
+ )
3510
+ self.register_buffer(
3511
+ "cache_index_table_map",
3512
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
3513
+ persistent=False,
3514
+ )
3515
+ self.register_buffer(
3516
+ "lxu_cache_state",
3517
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
3518
+ persistent=False,
3519
+ )
3520
+ self.register_buffer(
3521
+ "lxu_state",
3522
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
3523
+ persistent=False,
3524
+ )
3525
+ self.register_buffer(
3526
+ "cache_miss_counter",
3527
+ torch.tensor([0, 0], dtype=torch.int64),
3528
+ persistent=False,
3529
+ )
3530
+ self._init_uvm_cache_counter(cache_sets, persistent=False)
3531
+ return
3532
+
3533
+ assert cache_load_factor > 0
3534
+ element_size = 2 if dtype == torch.float16 else 4
3535
+ if cache_sets <= 0:
3536
+ total_memory = torch.cuda.get_device_properties(
3537
+ self.current_device
3538
+ ).total_memory
3539
+ free_memory = (
3540
+ total_memory
3541
+ - torch.cuda.memory_reserved(self.current_device)
3542
+ - int(cache_reserved_memory)
3543
+ )
3544
+ assert free_memory > 0
3545
+ cache_sets = (
3546
+ int(cache_state.total_cache_hash_size * cache_load_factor)
3547
+ + DEFAULT_ASSOC
3548
+ - 1
3549
+ ) // DEFAULT_ASSOC
3550
+ cache_sets = 1 if cache_sets == 0 else cache_sets
3551
+ cache_size = cache_sets * DEFAULT_ASSOC * element_size * self.max_D_cache
3552
+ if cache_size > free_memory:
3553
+ cache_sets = (
3554
+ int(1.0 * free_memory / self.max_D_cache / element_size)
3555
+ + DEFAULT_ASSOC
3556
+ - 1
3557
+ ) // DEFAULT_ASSOC
3558
+ cache_load_factor = (
3559
+ 1.0 * cache_sets * DEFAULT_ASSOC / int(cache_state.total_cache_hash_size)
3560
+ )
3561
+ assert cache_sets > 0
3562
+ if cache_algorithm == CacheAlgorithm.LFU:
3563
+ assert cache_sets < 2**24 - 1
3564
+ cache_size = cache_sets * DEFAULT_ASSOC * element_size * self.max_D_cache
3565
+ self.log(
3566
+ f"Using on-device cache with admission algorithm "
3567
+ f"{cache_algorithm}, {cache_sets} sets, "
3568
+ f"load_factor: {cache_load_factor : .3f}, "
3569
+ f"cache_size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
3570
+ f"cache_precision: {dtype}, "
3571
+ f"weights_precision: {self.weights_precision}"
3572
+ )
3573
+
3574
+ self.total_cache_hash_size = cache_state.total_cache_hash_size
3575
+ # 8x of # tables, trivial size
3576
+ self.register_buffer(
3577
+ "cache_hash_size_cumsum",
3578
+ torch.tensor(
3579
+ cache_state.cache_hash_size_cumsum,
3580
+ device=self.current_device,
3581
+ dtype=torch.int64,
3582
+ ),
3583
+ )
3584
+ # 4x total embedding hash size with uvm cache
3585
+ self.register_buffer(
3586
+ "cache_index_table_map",
3587
+ torch.tensor(
3588
+ cache_state.cache_index_table_map,
3589
+ device=self.current_device,
3590
+ dtype=torch.int32,
3591
+ ),
3592
+ )
3593
+ # 8x of total cache slots (embedding hash size * clf)
3594
+ self.register_buffer(
3595
+ "lxu_cache_state",
3596
+ torch.zeros(
3597
+ cache_sets, DEFAULT_ASSOC, device=self.current_device, dtype=torch.int64
3598
+ ).fill_(-1),
3599
+ )
3600
+ # Cache itself, not auxiliary size
3601
+ self.register_buffer(
3602
+ "lxu_cache_weights",
3603
+ torch.zeros(
3604
+ cache_sets * DEFAULT_ASSOC,
3605
+ self.max_D_cache,
3606
+ device=self.current_device,
3607
+ dtype=dtype,
3608
+ ),
3609
+ )
3610
+ # LRU: 8x of total cache slots (embedding hash size * clf)
3611
+ # LFU: 8x of total embedding hash size with uvm cache
3612
+ self.register_buffer(
3613
+ "lxu_state",
3614
+ torch.zeros(
3615
+ size=(
3616
+ (self.total_cache_hash_size + 1,)
3617
+ if cache_algorithm == CacheAlgorithm.LFU
3618
+ else (cache_sets, DEFAULT_ASSOC)
3619
+ ),
3620
+ device=self.current_device,
3621
+ dtype=torch.int64,
3622
+ ),
3623
+ )
3624
+ self.register_buffer(
3625
+ "cache_miss_counter",
3626
+ torch.tensor([0, 0], device=self.current_device, dtype=torch.int64),
3627
+ )
3628
+ self._init_uvm_cache_counter(cache_sets, persistent=True)
3629
+ if self.prefetch_pipeline:
3630
+ # using the placeholder_autograd_tensor to make sure
3631
+ # the hook is executed after the backward pass
3632
+ # not using register_module_full_backward_hook
3633
+ # due to https://github.com/pytorch/pytorch/issues/100528
3634
+ self.placeholder_autograd_tensor.register_hook(
3635
+ self._sync_stream_post_backward
3636
+ )
3637
+ self.register_full_backward_pre_hook(
3638
+ self._update_cache_counter_and_locations
3639
+ )
3640
+
3641
+ if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU):
3642
+ raise ValueError(
3643
+ f"cache_algorithm must be {CacheAlgorithm.LRU} "
3644
+ f"or {CacheAlgorithm.LFU}"
3645
+ )
3646
+
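+ # Illustrative arithmetic (hypothetical numbers; DEFAULT_ASSOC is taken as
+ # 32 purely for illustration, not asserted): with
+ # total_cache_hash_size = 1_000_000, cache_load_factor = 0.2, an FP16
+ # cache (element_size = 2) and max_D_cache = 128, the sizing above gives
+ #     cache_sets = (200_000 + 32 - 1) // 32 = 6_250
+ #     cache_size = 6_250 * 32 * 2 * 128 = 51_200_000 bytes (~48.8 MiB)
+ # If cache_size exceeded free HBM, cache_sets would be recomputed from
+ # free_memory and the load factor adjusted accordingly.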
3647
+ # pyre-ignore
3648
+ def _recording_to_timer(
3649
+ self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
3650
+ ) -> Any:
3651
+ if self.stats_reporter is not None and self.stats_reporter.should_report(
3652
+ self.step
3653
+ ):
3654
+ assert (
3655
+ timer
3656
+ ), "We shouldn't be here, async timer must have been initiated if reporter is present."
3657
+ return timer.recording(**kwargs)
3658
+ # No-Op context manager
3659
+ return contextlib.nullcontext()
3660
+
3661
+ def _sync_stream_post_backward(
3662
+ self,
3663
+ grad: Tensor,
3664
+ ) -> None:
3665
+ """
3666
+ Backward hook function used when prefetch_pipeline is enabled.
3667
+
3668
+ With the pipeline, prefetch(batch_{i+2}) may overlap with backward(batch_{i}).
3669
+ There is a race condition where backward(batch_i) writes to UVM memory and
3670
+ at the same time prefetch(batch_{i+2}) loads UVM memory to cache. This stream sync forces
3671
+ backward(batch_i) to finish before prefetch(batch_{i+2}).
3672
+ """
3673
+ if self.prefetch_stream is not None:
3674
+ self.prefetch_stream.wait_stream(torch.cuda.current_stream())
3675
+
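+ # Minimal sketch of the same stream-ordering pattern in plain PyTorch
+ # (illustrative only, not part of the module): a side stream is made to
+ # wait on the stream that issued the preceding work.
+ #
+ #     prefetch_stream = torch.cuda.Stream()
+ #     # ... enqueue backward work on the current stream ...
+ #     prefetch_stream.wait_stream(torch.cuda.current_stream())
+ #     with torch.cuda.stream(prefetch_stream):
+ #         pass  # prefetch work enqueued here now runs after the backward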
3676
+ def _update_cache_counter_and_locations(
3677
+ self,
3678
+ module: nn.Module,
3679
+ grad_input: Union[tuple[Tensor, ...], Tensor],
3680
+ ) -> None:
3681
+ """
3682
+ Backward prehook function when prefetch_pipeline is enabled.
3683
+
3684
+ This function does 3 things:
3685
+ 1. backward stream waits for prefetch stream to finish.
3686
+ Otherwise the prefetch(batch_{i+1}) might overlap with backward(batch_i).
3687
+ If an idx is not in cache in batch_i, but it is being inserted in batch_{i+1},
3688
+ there is a race condition where backward(batch_i) writes to UVM memory and
3689
+ at the same time prefetch(batch_{i+1}) loads UVM memory to cache.
3690
+
3691
+ 2. decrement the lxu_cache_locking_counter to indicate the current batch is finished.
3692
+ The lxu_cache_locking_counter is updated in both prefetch and TBE backward.
3693
+ As there is no overlap between prefetch and backward, we can decrement either before or
3694
+ after backward. It's better to decrement before lxu_cache_locations gets updated.
3695
+
3696
+ 3. update lxu_cache_locations to address the cache inconsistency issue.
3697
+ In the case that the same index is not inserted into cache in batch_i,
3698
+ but it is inserted in batch_{i+1}, the cache can be invalid in
3699
+ the sense that the cached weight for this index does not have the
3700
+ backward update of batch_i.
3701
+
3702
+ An example of the issue is as follows:
3703
+ idx is in batch_i, batch_{i+1}
3704
+ prefetch(batch_i)
3705
+ - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss)
3706
+ forward(batch_i)
3707
+ prefetch(batch_{i+1})
3708
+ - insert idx into cache, cache is loaded from host memory
3709
+ backward(batch_i)
3710
+ - cache_locations_batch_i of idx is -1, the host memory is updated
3711
+ forward(batch_{i+1})
3712
+ - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated.
3713
+
3714
+ The fix to this cache inconsistency is to update the cache_locations_batch_i before backward of batch_i,
3715
+ so that the cache gets updated correctly by the backward pass of TBE.
3716
+ """
3717
+
3718
+ if self.prefetch_stream is not None:
3719
+ # need to wait for the prefetch of next batch,
3720
+ # so that cache states are valid
3721
+ with self._recording_to_timer(
3722
+ self.bwd_wait_prefetch_timer,
3723
+ context=self.step,
3724
+ stream=torch.cuda.current_stream(),
3725
+ ):
3726
+ torch.cuda.current_stream().wait_stream(self.prefetch_stream)
3727
+
3728
+ torch.ops.fbgemm.lxu_cache_locking_counter_decrement(
3729
+ self.lxu_cache_locking_counter,
3730
+ self.lxu_cache_locations,
3731
+ )
3732
+ # Recompute linear_cache_indices
3733
+ linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
3734
+ self.cache_hash_size_cumsum,
3735
+ self._indices,
3736
+ self._offsets,
3737
+ self._vbe_B_offsets,
3738
+ self._vbe_max_B,
3739
+ )
3740
+ (
3741
+ linear_unique_indices,
3742
+ linear_unique_indices_length,
3743
+ _,
3744
+ ) = torch.ops.fbgemm.get_unique_indices(
3745
+ linear_cache_indices,
3746
+ self.total_cache_hash_size,
3747
+ compute_count=False,
3748
+ )
3749
+ torch.ops.fbgemm.lxu_cache_lookup(
3750
+ linear_unique_indices,
3751
+ self.lxu_cache_state,
3752
+ self.total_cache_hash_size,
3753
+ gather_cache_stats=False, # not collecting cache stats
3754
+ num_uniq_cache_indices=linear_unique_indices_length,
3755
+ lxu_cache_locations_output=self.lxu_cache_locations,
3756
+ )
3757
+
3758
+ def _init_uvm_cache_counter(self, cache_sets: int, persistent: bool) -> None:
3759
+ if self.prefetch_pipeline and persistent:
3760
+ self.register_buffer(
3761
+ "lxu_cache_locking_counter",
3762
+ torch.zeros(
3763
+ cache_sets,
3764
+ DEFAULT_ASSOC,
3765
+ device=self.current_device,
3766
+ dtype=torch.int32,
3767
+ ),
3768
+ )
3769
+ else:
3770
+ self.register_buffer(
3771
+ "lxu_cache_locking_counter",
3772
+ torch.zeros([0, 0], dtype=torch.int32, device=self.current_device),
3773
+ persistent=persistent,
3774
+ )
3775
+
3776
+ def _init_uvm_cache_stats(self) -> None:
3777
+ if not self.gather_uvm_cache_stats:
3778
+ # If uvm_cache_stats is not enabled, register stub entries via buffer to state_dict for TorchScript to JIT properly.
3779
+ # Since we're not using these variables, we can minimize the tensor size to keep the state_dict small.
3780
+ self.register_buffer(
3781
+ "uvm_cache_stats",
3782
+ torch.zeros(
3783
+ 1,
3784
+ device=self.current_device,
3785
+ dtype=torch.int64,
3786
+ ),
3787
+ persistent=False,
3788
+ )
3789
+ self.register_buffer(
3790
+ "local_uvm_cache_stats",
3791
+ torch.zeros(
3792
+ 1,
3793
+ device=self.current_device,
3794
+ dtype=torch.int32,
3795
+ ),
3796
+ persistent=False,
3797
+ )
3798
+ else:
3799
+ self.register_buffer(
3800
+ "uvm_cache_stats",
3801
+ torch.zeros(
3802
+ size=(self.uvm_cache_stats_size,),
3803
+ device=self.current_device,
3804
+ dtype=torch.int64,
3805
+ ),
3806
+ )
3807
+ self.register_buffer(
3808
+ "local_uvm_cache_stats",
3809
+ torch.zeros(
3810
+ size=(self.uvm_cache_stats_size,),
3811
+ device=self.current_device,
3812
+ dtype=torch.int32,
3813
+ ),
3814
+ )
3815
+ self.reset_uvm_cache_stats()
3816
+ self.last_uvm_cache_print_state = torch.zeros_like(self.uvm_cache_stats)
3817
+
3818
+ def reset_cache_states(self) -> None:
3819
+ # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
3820
+ # a function.
3821
+ if not self.lxu_cache_weights.numel():
3822
+ return
3823
+ self.lxu_cache_state.fill_(-1)
3824
+ self.lxu_state.fill_(0)
3825
+ self.timestep = 1
3826
+
3827
+ def reset_embedding_weight_momentum(
3828
+ self,
3829
+ pruned_indices: Tensor,
3830
+ pruned_indices_offsets: Tensor,
3831
+ logical_table_ids: Tensor,
3832
+ buffer_ids: Tensor,
3833
+ ) -> None:
3834
+ if self.optimizer == OptimType.NONE:
3835
+ raise NotImplementedError(
3836
+ f"Resetting embedding weight momentum is not supported for {self.optimizer}"
3837
+ )
3838
+ total_cache_hash_size = 0
3839
+ if isinstance(self.total_cache_hash_size, Tensor):
3840
+ total_cache_hash_size = self.total_cache_hash_size.item()
3841
+ else:
3842
+ total_cache_hash_size = self.total_cache_hash_size
3843
+
3844
+ rowwise = self.optimizer in [
3845
+ OptimType.EXACT_ROWWISE_ADAGRAD,
3846
+ ]
3847
+ if rowwise:
3848
+ torch.ops.fbgemm.reset_weight_momentum(
3849
+ dev_weights=self.weights_dev,
3850
+ uvm_weights=self.weights_uvm,
3851
+ lxu_cache_weights=self.lxu_cache_weights,
3852
+ weights_placements=self.weights_placements,
3853
+ weights_offsets=self.weights_offsets,
3854
+ momentum1_dev=self.momentum1_dev,
3855
+ momentum1_uvm=self.momentum1_uvm,
3856
+ momentum1_placements=self.momentum1_placements,
3857
+ momentum1_offsets=self.momentum1_offsets,
3858
+ D_offsets=self.D_offsets,
3859
+ pruned_indices=pruned_indices.to(device=self.current_device),
3860
+ pruned_indices_offsets=pruned_indices_offsets.to(
3861
+ device=self.current_device
3862
+ ),
3863
+ logical_table_ids=logical_table_ids.to(device=self.current_device),
3864
+ buffer_ids=buffer_ids.to(device=self.current_device),
3865
+ cache_hash_size_cumsum=self.cache_hash_size_cumsum,
3866
+ lxu_cache_state=self.lxu_cache_state,
3867
+ total_cache_hash_size=total_cache_hash_size,
3868
+ )
3869
+
3870
+ def prepare_inputs(
3871
+ self,
3872
+ indices: Tensor,
3873
+ offsets: Tensor,
3874
+ per_sample_weights: Optional[Tensor] = None,
3875
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
3876
+ force_cast_input_types: bool = True,
3877
+ prefetch_pipeline: bool = False,
3878
+ ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
3879
+ """
3880
+ Prepare TBE inputs as follows:
3881
+
3882
+ (1) Create VBE metadata
3883
+ (2) Convert input types if `force_cast_input_types=True`
3884
+ (3) Run `bounds_check_indices` if `bounds_check_mode` is not
3885
+ BoundsCheckMode.NONE
3886
+
3887
+ Args:
3888
+ indices (Tensor): Input indices
3889
+ offsets (Tensor): Input offsets
3890
+ per_sample_weights (Optional[Tensor]): Input per sample
3891
+ weights
3892
+ batch_size_per_feature_per_rank
3893
+ (Optional[List[List[int]]]): A 2D list of batch sizes
3894
+ for each rank and feature. Shape = (number of
3895
+ features, number of ranks)
3896
+ force_cast_input_types (bool): A flag to force convert
3897
+ input types if set to True
3898
+
3899
+ Returns:
3900
+ A tuple of indices, offsets, per_sample_weights, and VBE
3901
+ metadata
3902
+ """
3903
+
3904
+ # Generate VBE metadata
3905
+ vbe_metadata = self._generate_vbe_metadata(
3906
+ offsets, batch_size_per_feature_per_rank
3907
+ )
3908
+
3909
+ vbe = vbe_metadata.B_offsets is not None
3910
+ # Note: this check has already been done on the C++ side.
3911
+ # TODO: also check max_B <= self.info_B_mask in Python.
3912
+ # We cannot use assert, as it breaks PT2 compile for dynamic shapes; we
3913
+ # need torch._check for dynamic shapes, and it cannot take an f-string, only a constant string.
3914
+ # torch._check(
3915
+ # max_B <= self.info_B_mask,
3916
+ # "Not enough infos bits to accommodate T and B.",
3917
+ # )
3918
+ # We cannot use lambda as it fails jit script.
3919
+ # torch._check is also not supported in jitscript
3920
+
3921
+ # TODO: remove this and add an assert after updating
3922
+ # bounds_check_indices to support different indices type and offset
3923
+ # type
3924
+ force_cast_input_types = (
3925
+ indices.dtype != offsets.dtype or force_cast_input_types
3926
+ )
3927
+
3928
+ if force_cast_input_types:
3929
+ # NOTE: Force offsets to have the same dtype as indices since the
3930
+ # kernels assume same dtype. We might need to revisit the assumption
3931
+ # of same dtypes in the future.
3932
+ if self.embedding_table_index_type == torch.int32:
3933
+ self.log(
3934
+ "Casting indices to int32 based on embedding_table_index_type input."
3935
+ )
3936
+ indices = indices.to(torch.int32)
3937
+ if self.embedding_table_index_type != self.embedding_table_offset_type:
3938
+ self.log(
3939
+ f"Force casting offsets to {self.embedding_table_index_type} so that it is the same as the indices type."
3940
+ )
3941
+ offsets = offsets.to(dtype=indices.dtype)
3942
+
3943
+ # Force casting per_sample_weights to float
3944
+ if per_sample_weights is not None:
3945
+ per_sample_weights = per_sample_weights.float()
3946
+
3947
+ if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
3948
+ # Override the bounds check version based on prefetch_pipeline
3949
+ use_bounds_check_v2 = self.bounds_check_version == 2 or prefetch_pipeline
3950
+ bounds_check_version = (
3951
+ 2 if use_bounds_check_v2 else self.bounds_check_version
3952
+ )
3953
+
3954
+ vbe = vbe_metadata.B_offsets is not None
3955
+
3956
+ # Compute B info and VBE metadata for bounds_check_indices only if
3957
+ # VBE and bounds check indices v2 are used
3958
+ if vbe and use_bounds_check_v2:
3959
+ B_offsets = vbe_metadata.B_offsets
3960
+ B_offsets_rank_per_feature = vbe_metadata.B_offsets_rank_per_feature
3961
+ output_offsets_feature_rank = vbe_metadata.output_offsets_feature_rank
3962
+ assert isinstance(B_offsets, Tensor), "B_offsets must be tensor"
3963
+ assert isinstance(
3964
+ B_offsets_rank_per_feature, Tensor
3965
+ ), "B_offsets_rank_per_feature must be tensor"
3966
+ assert isinstance(
3967
+ output_offsets_feature_rank, Tensor
3968
+ ), "output_offsets_feature_rank must be tensor"
3969
+
3970
+ row_output_offsets, b_t_map = torch.ops.fbgemm.generate_vbe_metadata(
3971
+ B_offsets,
3972
+ B_offsets_rank_per_feature,
3973
+ output_offsets_feature_rank,
3974
+ self.D_offsets,
3975
+ self.max_D,
3976
+ self.is_nobag,
3977
+ vbe_metadata.max_B_feature_rank,
3978
+ self.info_B_num_bits,
3979
+ offsets.numel() - 1, # total_B
3980
+ )
3981
+ else:
3982
+ b_t_map = None
3983
+
3984
+ torch.ops.fbgemm.bounds_check_indices(
3985
+ self.rows_per_table,
3986
+ indices,
3987
+ offsets,
3988
+ self.bounds_check_mode_int,
3989
+ self.bounds_check_warning,
3990
+ per_sample_weights,
3991
+ B_offsets=vbe_metadata.B_offsets,
3992
+ max_B=vbe_metadata.max_B,
3993
+ b_t_map=b_t_map,
3994
+ info_B_num_bits=self.info_B_num_bits,
3995
+ info_B_mask=self.info_B_mask,
3996
+ bounds_check_version=bounds_check_version,
3997
+ prefetch_pipeline=prefetch_pipeline,
3998
+ )
3999
+
4000
+ return indices, offsets, per_sample_weights, vbe_metadata
4001
+
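+ # Illustrative usage sketch (assumes `tbe` is an instance of this class
+ # and `indices`/`offsets` are valid TBE inputs; not part of the module):
+ #
+ #     indices, offsets, per_sample_weights, vbe_metadata = tbe.prepare_inputs(
+ #         indices, offsets, per_sample_weights=None
+ #     )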
4002
+ def _debug_print_input_stats_factory(self) -> Callable[..., None]:
4003
+ """
4004
+ If the environment variable FBGEMM_DEBUG_PRINT_INPUT_STATS=1,
4005
+ return a function pointer of a function that prints input
4006
+ stats including weighted/unweighted, number of features,
4007
+ batch size, average pooling factor, total number of indices,
4008
+ number of unique indices, and number of indices that go
4009
+ through the different backward functions. Otherwise, return
4010
+ a dummy function pointer.
4011
+ """
4012
+
4013
+ @torch.jit.ignore
4014
+ def _debug_print_input_stats_factory_impl(
4015
+ indices: Tensor,
4016
+ offsets: Tensor,
4017
+ per_sample_weights: Optional[Tensor] = None,
4018
+ ) -> None:
4019
+ """
4020
+ Print input stats (for debugging purpose only)
4021
+
4022
+ Args:
4023
+ indices (Tensor): Input indices
4024
+ offsets (Tensor): Input offsets
4025
+ per_sample_weights (Optional[Tensor]): Input per
4026
+ sample weights
4027
+ """
4028
+ # pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
4029
+ # float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
4030
+ if self.debug_step % 100 == 0:
4031
+ # Get number of features (T) and batch size (B)
4032
+ T = len(self.feature_table_map)
4033
+ B = (offsets.numel() - 1) // T
4034
+
4035
+ # Transfer hash_size_cumsum, indices and offsets to CPU
4036
+ hash_size_cumsum_cpu = self.hash_size_cumsum.cpu()
4037
+ indices_cpu = indices.cpu()
4038
+ offsets_cpu = offsets.cpu()
4039
+
4040
+ # Compute linear indices
4041
+ for t in range(T):
4042
+ start = offsets_cpu[B * t].item()
4043
+ end = offsets_cpu[B * (t + 1)].item()
4044
+ indices_cpu[start:end] += hash_size_cumsum_cpu[t]
4045
+
4046
+ # Compute unique indices
4047
+ uniq_indices_cpu, counts = indices_cpu.unique(return_counts=True)
4048
+
4049
+ # Compute num unique indices
4050
+ num_uniq_indices = uniq_indices_cpu.numel()
4051
+
4052
+ # The warp_per_row kernel handles indices whose
4053
+ # segment lengths are <= 32.
4054
+ #
4055
+ # The cta_per_row kernel handles indices whose
4056
+ # segment lengths are > 32. A single thread block is used
4057
+ # if the segment length is <= 1024; otherwise, multiple
4058
+ # thread blocks are used.
4059
+ #
4060
+ # Counts of indices whose segment lengths are <= 32
4061
+ counts_warp_per_row = counts[counts <= 32]
4062
+ counts_cta_per_row = counts[counts > 32]
4063
+ # Counts of indices whose segment lengths are > 32 and <= 1024
4064
+ counts_cta_per_row_sth = counts_cta_per_row[counts_cta_per_row <= 1024]
4065
+ # Counts of indices whose segment lengths are > 1024
4066
+ counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024]
4067
+
4068
+ def compute_numel_and_avg(counts: Tensor) -> tuple[int, float]:
4069
+ numel = counts.numel()
4070
+ avg = (counts.sum().item() / numel) if numel != 0 else -1.0
4071
+ return numel, avg
4072
+
4073
+ # warp_per_row stats
4074
+ num_warp_per_row, avg_seglen_warp_per_row = compute_numel_and_avg(
4075
+ counts_warp_per_row
4076
+ )
4077
+ # cta_per_row using a single thread block stats
4078
+ num_cta_per_row_sth, avg_seglen_cta_per_row_sth = compute_numel_and_avg(
4079
+ counts_cta_per_row_sth
4080
+ )
4081
+ # cta_per_row using multiple thread block stats
4082
+ num_cta_per_row_mth, avg_seglen_cta_per_row_mth = compute_numel_and_avg(
4083
+ counts_cta_per_row_mth
4084
+ )
4085
+
4086
+ assert num_uniq_indices == (
4087
+ num_warp_per_row + num_cta_per_row_sth + num_cta_per_row_mth
4088
+ )
4089
+
4090
+ self.log(
4091
+ "TBE_DEBUG: "
4092
+ "weighted {} "
4093
+ "num features {} "
4094
+ "batch size {} "
4095
+ "avg pooling factor {:.2f} "
4096
+ "total num indices {} "
4097
+ "num unique indices {} "
4098
+ "num warp_per_row {} (avg segment length {:.2f}) "
4099
+ "num cta_per_row single thread block (avg segment length) {} ({:.2f}) "
4100
+ "num cta_per_row multiple thread blocks (avg segment length) {} ({:.2f})".format(
4101
+ per_sample_weights is not None,
4102
+ T,
4103
+ B,
4104
+ indices.numel() / (B * T),
4105
+ indices.numel(),
4106
+ num_uniq_indices,
4107
+ num_warp_per_row,
4108
+ avg_seglen_warp_per_row,
4109
+ num_cta_per_row_sth,
4110
+ avg_seglen_cta_per_row_sth,
4111
+ num_cta_per_row_mth,
4112
+ avg_seglen_cta_per_row_mth,
4113
+ )
4114
+ )
4115
+ # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no
4116
+ # attribute `debug_step`.
4117
+ # pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
4118
+ # float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
4119
+ self.debug_step += 1
4120
+
4121
+ @torch.jit.ignore
4122
+ def _debug_print_input_stats_factory_null(
4123
+ indices: Tensor,
4124
+ offsets: Tensor,
4125
+ per_sample_weights: Optional[Tensor] = None,
4126
+ ) -> None:
4127
+ pass
4128
+
4129
+ if int(os.environ.get("FBGEMM_DEBUG_PRINT_INPUT_STATS", "0")) == 1:
4130
+ # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no
4131
+ # attribute `debug_step`.
4132
+ self.debug_step = 0
4133
+ return _debug_print_input_stats_factory_impl
4134
+ return _debug_print_input_stats_factory_null
4135
+
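+ # Illustrative sketch (not part of the module): the debug implementation
+ # above is only selected when the environment variable is set before the
+ # module is constructed, e.g.
+ #
+ #     import os
+ #     os.environ["FBGEMM_DEBUG_PRINT_INPUT_STATS"] = "1"
+ #     # ... construct the TBE module afterwards; stats are then logged
+ #     # once every 100 calls.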
4136
+ @torch.jit.ignore
4137
+ def raw_embedding_stream(self) -> None:
4138
+ if not self.enable_raw_embedding_streaming:
4139
+ return None
4140
+ # When pipelining is enabled, the prefetch for iter i happens before
4141
+ # the backward sparse of iter i - 1, so the embeddings for iter i - 1's
4142
+ # changed ids have not been updated yet; we can therefore only stream
4143
+ # the indices from iter i - 2.
4144
+ # When pipelining is disabled, the prefetch for iter i happens before
4145
+ # the forward of iter i, so we can stream iter i - 1's changed ids
4146
+ # safely.
4147
+ target_prev_iter = 1
4148
+ if self.prefetch_pipeline:
4149
+ target_prev_iter = 2
4150
+ if not len(self.prefetched_info_list) > (target_prev_iter - 1):
4151
+ return None
4152
+ with record_function(
4153
+ "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
4154
+ ):
4155
+ prefetched_info = self.prefetched_info_list.pop(0)
4156
+ updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
4157
+ prefetched_info.linear_unique_cache_indices,
4158
+ self.lxu_cache_state,
4159
+ self.total_cache_hash_size,
4160
+ gather_cache_stats=False, # not collecting cache stats
4161
+ num_uniq_cache_indices=prefetched_info.linear_unique_indices_length,
4162
+ )
4163
+ updated_weights = torch.empty(
4164
+ [
4165
+ prefetched_info.linear_unique_cache_indices.size()[0],
4166
+ self.max_D_cache,
4167
+ ],
4168
+ # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
4169
+ dtype=self.lxu_cache_weights.dtype,
4170
+ # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
4171
+ device=self.lxu_cache_weights.device,
4172
+ )
4173
+ torch.ops.fbgemm.masked_index_select(
4174
+ updated_weights,
4175
+ updated_locations,
4176
+ self.lxu_cache_weights,
4177
+ prefetched_info.linear_unique_indices_length,
4178
+ )
4179
+ # TODO: this statement triggers a sync
4180
+ # added here to make this diff self-contained
4181
+ # will remove in later change
4182
+ cache_hit_mask_index = (
4183
+ updated_locations.narrow(
4184
+ 0, 0, prefetched_info.linear_unique_indices_length.item()
4185
+ )
4186
+ .not_equal(-1)
4187
+ .nonzero()
4188
+ .flatten()
4189
+ )
4190
+ # stream weights
4191
+ self._raw_embedding_streamer.stream(
4192
+ prefetched_info.linear_unique_indices.index_select(
4193
+ dim=0, index=cache_hit_mask_index
4194
+ ).to(device=torch.device("cpu")),
4195
+ updated_weights.index_select(dim=0, index=cache_hit_mask_index).to(
4196
+ device=torch.device("cpu")
4197
+ ),
4198
+ (
4199
+ prefetched_info.hash_zch_identities.index_select(
4200
+ dim=0, index=cache_hit_mask_index
4201
+ ).to(device=torch.device("cpu"))
4202
+ if prefetched_info.hash_zch_identities is not None
4203
+ else None
4204
+ ),
4205
+ (
4206
+ prefetched_info.hash_zch_runtime_meta.index_select(
4207
+ dim=0, index=cache_hit_mask_index
4208
+ ).to(device=torch.device("cpu"))
4209
+ if prefetched_info.hash_zch_runtime_meta is not None
4210
+ else None
4211
+ ),
4212
+ prefetched_info.linear_unique_indices_length.to(
4213
+ device=torch.device("cpu")
4214
+ ),
4215
+ False, # require_tensor_copy
4216
+ False, # blocking_tensor_copy
4217
+ )
4218
+
4219
+ @staticmethod
4220
+ @torch.jit.ignore
4221
+ def _get_prefetched_info(
4222
+ linear_indices: torch.Tensor,
4223
+ linear_cache_indices_merged: torch.Tensor,
4224
+ total_cache_hash_size: int,
4225
+ hash_zch_identities: Optional[torch.Tensor],
4226
+ hash_zch_runtime_meta: Optional[torch.Tensor],
4227
+ max_indices_length: int,
4228
+ ) -> PrefetchedInfo:
4229
+ (
4230
+ linear_unique_cache_indices,
4231
+ linear_unique_cache_indices_length,
4232
+ linear_unique_cache_indices_cnt,
4233
+ linear_unique_cache_inverse_indices,
4234
+ ) = torch.ops.fbgemm.get_unique_indices_with_inverse(
4235
+ linear_cache_indices_merged,
4236
+ total_cache_hash_size,
4237
+ compute_count=True,
4238
+ compute_inverse_indices=True,
4239
+ )
4240
+ # Pure CPU op, no sync needed; caps the indices so they do not outgrow the weights buffer
4241
+ max_len = min(
4242
+ max_indices_length,
4243
+ linear_unique_cache_indices.size(0),
4244
+ )
4245
+ if max_len < linear_unique_cache_indices.size(0):
4246
+ linear_unique_cache_indices_length.clamp_(max=max_len)
4247
+ # linear_unique_indices is the result after deduplication and sorting
4248
+ linear_unique_cache_indices = linear_unique_cache_indices.narrow(
4249
+ 0, 0, max_len
4250
+ )
4251
+ # Compute cumulative sum as indices for selecting unique elements to
4252
+ # map hash_zch_identities and hash_zch_runtime_meta to linear_unique_indices
4253
+ count_cum_sum = torch.ops.fbgemm.asynchronous_complete_cumsum(
4254
+ linear_unique_cache_indices_cnt
4255
+ )
4256
+ # count_cum_sum will have one more element than linear_unique_cache_indices_cnt
4257
+ count_cum_sum = count_cum_sum.narrow(0, 0, max_len)
4258
+ # clamp the uninitialized elements to avoid out-of-bounds access
4259
+ # the uninitialized elements will be sliced out by linear_unique_cache_indices_length
4260
+ # directly using linear_unique_cache_indices_length requires a sync
4261
+ count_cum_sum.clamp_(min=0, max=linear_unique_cache_inverse_indices.size(0) - 1)
4262
+
4263
+ # Select indices corresponding to first occurrence of each unique element
4264
+ linear_unique_inverse_indices = (
4265
+ linear_unique_cache_inverse_indices.index_select(dim=0, index=count_cum_sum)
4266
+ )
4267
+ # same as above clamp
4268
+ linear_unique_inverse_indices.clamp_(min=0, max=linear_indices.size(0) - 1)
4269
+ linear_unique_indices = linear_indices.index_select(
4270
+ dim=0, index=linear_unique_inverse_indices
4271
+ )
4272
+ if hash_zch_identities is not None:
4273
+ # Map hash_zch_identities to unique indices
4274
+ hash_zch_identities = hash_zch_identities.index_select(
4275
+ dim=0, index=linear_unique_inverse_indices
4276
+ )
4277
+
4278
+ if hash_zch_runtime_meta is not None:
4279
+ # Map hash_zch_runtime_meta to unique indices
4280
+ hash_zch_runtime_meta = hash_zch_runtime_meta.index_select(
4281
+ dim=0, index=linear_unique_inverse_indices
4282
+ )
4283
+
4284
+ return PrefetchedInfo(
4285
+ linear_unique_indices,
4286
+ linear_unique_cache_indices,
4287
+ linear_unique_cache_indices_length,
4288
+ hash_zch_identities,
4289
+ hash_zch_runtime_meta,
4290
+ )
4291
+
4292
+ @torch.jit.ignore
4293
+ def _store_prefetched_tensors(
4294
+ self,
4295
+ indices: torch.Tensor,
4296
+ offsets: torch.Tensor,
4297
+ vbe_metadata: Optional[invokers.lookup_args.VBEMetadata],
4298
+ linear_cache_indices_merged: torch.Tensor,
4299
+ final_lxu_cache_locations: torch.Tensor,
4300
+ hash_zch_identities: Optional[torch.Tensor],
4301
+ hash_zch_runtime_meta: Optional[torch.Tensor],
4302
+ ) -> None:
4303
+ """
4304
+ NOTE: this needs to be a method with jit.ignore as the identities tensor is conditional.
4305
+ This function stores the prefetched tensors for the raw embedding streaming.
4306
+ """
4307
+ if not self.enable_raw_embedding_streaming:
4308
+ return
4309
+
4310
+ with record_function(
4311
+ "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
4312
+ ):
4313
+ found_in_cache_mask = final_lxu_cache_locations != -1
4314
+ # only process the indices that are found in the cache
4315
+ # this will filter out the indices from tables that don't have UVM_CACHE enabled
4316
+ linear_cache_indices_merged_masked = torch.where(
4317
+ found_in_cache_mask,
4318
+ linear_cache_indices_merged,
4319
+ self.total_cache_hash_size,
4320
+ )
4321
+ linearize_indices = torch.ops.fbgemm.linearize_cache_indices(
4322
+ self.hash_size_cumsum,
4323
+ indices,
4324
+ offsets,
4325
+ vbe_metadata.B_offsets if vbe_metadata is not None else None,
4326
+ vbe_metadata.max_B if vbe_metadata is not None else -1,
4327
+ )
4328
+ # -1 indices are ignored in raw_embedding_streamer.
4329
+ linearize_indices_masked = torch.where(
4330
+ found_in_cache_mask,
4331
+ linearize_indices,
4332
+ -1,
4333
+ )
4334
+ # Process hash_zch_identities using helper function
4335
+ prefetched_info = self._get_prefetched_info(
4336
+ linearize_indices_masked,
4337
+ linear_cache_indices_merged_masked,
4338
+ self.total_cache_hash_size,
4339
+ hash_zch_identities,
4340
+ hash_zch_runtime_meta,
4341
+ self.lxu_cache_weights.size(0),
4342
+ )
4343
+
4344
+ self.prefetched_info_list.append(prefetched_info)
4345
+
4346
+ @torch.jit.ignore
4347
+ def __report_input_params_factory(
4348
+ self,
4349
+ ) -> Optional[Callable[..., None]]:
4350
+ """
4351
+ This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.
4352
+
4353
+ If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that:
4354
+ - Reports input parameters (TBEDataConfig).
4355
+ - Writes the output as a JSON file.
4356
+
4357
+ If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy function pointer that performs no action.
4358
+ """
4359
+ try:
4360
+ if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
4361
+ from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
4362
+
4363
+ reporter = TBEBenchmarkParamsReporter.create()
4364
+ return reporter.report_stats
4365
+ except Exception:
4366
+ return None
4367
+
4368
+ return None
4369
+
4370
+
4371
+ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
4372
+ """
4373
+ Table-batched version of nn.EmbeddingBag(sparse=False)
4374
+ """
4375
+
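+ # Illustrative construction sketch (hypothetical table sizes; not part of
+ # the module): embedding_specs is a list of (rows, dims) tuples, one per
+ # table.
+ #
+ #     dense_tbe = DenseTableBatchedEmbeddingBagsCodegen(
+ #         embedding_specs=[(1000, 16), (2000, 32)],
+ #         pooling_mode=PoolingMode.SUM,
+ #         use_cpu=True,
+ #     )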
4376
+ weights: Tensor
4377
+ weights_offsets: Tensor
4378
+ D_offsets: Tensor
4379
+ total_D: int
4380
+ max_D: int
4381
+ hash_size_cumsum: Tensor
4382
+ total_hash_size_bits: int
4383
+ embedding_specs: list[tuple[int, int]]
4384
+
4385
+ def __init__(
4386
+ self,
4387
+ embedding_specs: list[tuple[int, int]], # tuple of (rows, dims)
4388
+ feature_table_map: Optional[list[int]] = None, # [T]
4389
+ weights_precision: SparseType = SparseType.FP32,
4390
+ pooling_mode: PoolingMode = PoolingMode.SUM,
4391
+ use_cpu: bool = False,
4392
+ output_dtype: SparseType = SparseType.FP32,
4393
+ use_mtia: bool = False,
4394
+ ) -> None: # noqa C901 # tuple of (rows, dims,)
4395
+ super(DenseTableBatchedEmbeddingBagsCodegen, self).__init__()
4396
+ self.uuid = str(uuid.uuid4())
4397
+
4398
+ self.log(
4399
+ f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
4400
+ )
4401
+
4402
+ self.pooling_mode = pooling_mode
4403
+ self.weights_precision = weights_precision
4404
+ self.output_dtype: int = output_dtype.as_int()
4405
+ table_embedding_dtype = weights_precision.as_dtype()
4406
+
4407
+ self.use_cpu: bool = use_cpu
4408
+ self.use_mtia: bool = use_mtia
4409
+
4410
+ assert not (use_cpu and use_mtia), "Cannot use CPU and MTIA at the same time"
4411
+
4412
+ if self.use_cpu or self.pooling_mode == PoolingMode.NONE:
4413
+ assert output_dtype in [
4414
+ SparseType.FP32,
4415
+ SparseType.FP16,
4416
+ SparseType.BF16,
4417
+ ], "Fused pooled embedding quantization only supported for cuda."
4418
+
4419
+ # pyre-fixme[8]: Attribute has type `device`; used as `Union[int, device]`.
4420
+ self.current_device: torch.device = (
4421
+ torch.device("cpu")
4422
+ if self.use_cpu
4423
+ else (
4424
+ torch.device(f"mtia:{torch.mtia.current_device()}")
4425
+ if self.use_mtia
4426
+ else torch.cuda.current_device()
4427
+ )
4428
+ )
4429
+
4430
+ self.embedding_specs = embedding_specs
4431
+ (rows, dims) = zip(*embedding_specs)
4432
+ T_ = len(self.embedding_specs)
4433
+ assert T_ > 0
4434
+
4435
+ feature_table_map = (
4436
+ feature_table_map if feature_table_map is not None else list(range(T_))
4437
+ )
4438
+ T = len(feature_table_map)
4439
+ assert T_ <= T
4440
+
4441
+ feature_dims = [dims[t] for t in feature_table_map]
4442
+ D_offsets = [0] + list(accumulate(feature_dims))
4443
+ self.total_D = D_offsets[-1]
4444
+ self.max_D = max(dims)
4445
+ self.register_buffer(
4446
+ "D_offsets",
4447
+ torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
4448
+ )
4449
+ assert self.D_offsets.numel() == T + 1
4450
+
4451
+ # Required for VBE
4452
+ self.register_buffer(
4453
+ "feature_dims",
4454
+ torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
4455
+ )
4456
+
4457
+ hash_size_cumsum = [0] + list(accumulate(rows))
4458
+ if hash_size_cumsum[-1] == 0:
4459
+ self.total_hash_size_bits: int = 0
4460
+ else:
4461
+ self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
4462
+ # The last element is to easily access # of rows of each table by
4463
+ # hash_size_cumsum[t + 1] - hash_size_cumsum[t]
4464
+ hash_size_cumsum = [hash_size_cumsum[t] for t in feature_table_map] + [
4465
+ hash_size_cumsum[-1]
4466
+ ]
4467
+ self.register_buffer(
4468
+ "hash_size_cumsum",
4469
+ torch.tensor(
4470
+ hash_size_cumsum, device=self.current_device, dtype=torch.int64
4471
+ ),
4472
+ )
4473
+ weights_offsets = [0] + list(
4474
+ accumulate([row * dim for (row, dim) in embedding_specs])
4475
+ )
4476
+ self.weights = nn.Parameter(
4477
+ torch.randn(
4478
+ weights_offsets[-1],
4479
+ device=self.current_device,
4480
+ dtype=table_embedding_dtype,
4481
+ )
4482
+ )
4483
+ for feature in range(T):
4484
+ t = feature_table_map[feature]
4485
+ row, dim = embedding_specs[t]
4486
+ if (
4487
+ self.weights[weights_offsets[t] : weights_offsets[t + 1]].numel()
4488
+ != row * dim
4489
+ ):
4490
+ self.log(
4491
+ f"row {row} dim {dim} feature {feature} t {t} {self.weights[weights_offsets[t] : weights_offsets[t + 1]].numel()}"
4492
+ )
4493
+ assert (
4494
+ self.weights[weights_offsets[t] : weights_offsets[t + 1]].numel()
4495
+ == row * dim
4496
+ )
4497
+ assert self.hash_size_cumsum[feature] == sum(
4498
+ row for (row, _) in embedding_specs[:t]
4499
+ )
4500
+
4501
+ self.weights_physical_offsets: list[int] = weights_offsets
4502
+ weights_offsets = [weights_offsets[t] for t in feature_table_map]
4503
+ self.register_buffer(
4504
+ "weights_offsets",
4505
+ torch.tensor(
4506
+ weights_offsets, device=self.current_device, dtype=torch.int64
4507
+ ),
4508
+ )
4509
+
4510
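With the default identity `feature_table_map`, the buffers registered above describe one flat weight tensor: `weights_offsets` gives each table's starting element and `D_offsets` the column offsets of the pooled output. A small worked sketch with hypothetical table sizes:

from itertools import accumulate

# Hypothetical spec: table 0 is 100 rows x dim 4, table 1 is 50 rows x dim 8.
embedding_specs = [(100, 4), (50, 8)]
rows, dims = zip(*embedding_specs)

D_offsets = [0] + list(accumulate(dims))  # [0, 4, 12] -> total_D = 12
weights_offsets = [0] + list(accumulate(r * d for r, d in embedding_specs))
# weights_offsets == [0, 400, 800]: table 1 starts at element 400 of the flat
# `weights` parameter, whose total length is 800 elements.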
+    @torch.jit.ignore
+    def log(self, msg: str) -> None:
+        """
+        Log with TBE id prefix to distinguish between multiple TBE instances
+        per process
+
+        Args:
+            msg (str): The message to print
+
+        Returns:
+            None
+        """
+        logging.info(f"[TBE={self.uuid}] {msg}")
+
+    @torch.jit.ignore
+    def _generate_vbe_metadata(
+        self,
+        offsets: Tensor,
+        batch_size_per_feature_per_rank: Optional[list[list[int]]],
+    ) -> invokers.lookup_args.VBEMetadata:
+        # Blocking D2H copy, but only runs at first call
+        self.feature_dims = self.feature_dims.cpu()
+        return generate_vbe_metadata(
+            offsets,
+            batch_size_per_feature_per_rank,
+            self.pooling_mode,
+            self.feature_dims,
+            self.current_device,
+        )
+
+    def forward(
+        self,
+        indices: Tensor,
+        offsets: Tensor,
+        per_sample_weights: Optional[Tensor] = None,
+        feature_requires_grad: Optional[Tensor] = None,
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+    ) -> Tensor:
+        # Generate VBE metadata
+        vbe_metadata = self._generate_vbe_metadata(
+            offsets, batch_size_per_feature_per_rank
+        )
+
+        # NOTE: Force offsets to have the same dtype as indices since the
+        # kernels assume same dtype. We might need to revisit the assumption
+        # of same dtypes in the future.
+        offsets = offsets.to(dtype=indices.dtype)
+
+        # Force casting per_sample_weights to float
+        if per_sample_weights is not None:
+            per_sample_weights = per_sample_weights.float()
+
+        return torch.ops.fbgemm.dense_embedding_codegen_lookup_function(
+            dev_weights=self.weights,
+            weights_offsets=self.weights_offsets,
+            D_offsets=self.D_offsets,
+            total_D=self.total_D,
+            max_D=self.max_D,
+            hash_size_cumsum=self.hash_size_cumsum,
+            total_hash_size_bits=self.total_hash_size_bits,
+            indices=indices,
+            offsets=offsets,
+            pooling_mode=self.pooling_mode,
+            indice_weights=per_sample_weights,
+            feature_requires_grad=feature_requires_grad,
+            output_dtype=self.output_dtype,
+            B_offsets=vbe_metadata.B_offsets,
+            vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
+            vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
+            max_B=vbe_metadata.max_B,
+            max_B_feature_rank=vbe_metadata.max_B_feature_rank,
+            vbe_output_size=vbe_metadata.output_size,
+        )
+
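A minimal end-to-end usage sketch of this class, with hypothetical table sizes and inputs; it assumes a CUDA device and that the class is importable from `fbgemm_gpu.split_table_batched_embeddings_ops_training` (an assumption about the module path, not confirmed by this diff):

import torch
# Assumed import path for this module; adjust if the package lays it out differently.
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    DenseTableBatchedEmbeddingBagsCodegen,
)

# Two tables: 100 rows x dim 4 and 50 rows x dim 8.
emb = DenseTableBatchedEmbeddingBagsCodegen(embedding_specs=[(100, 4), (50, 8)])

# EmbeddingBag-style CSR inputs: offsets has T * B + 1 entries (here T=2 features, B=2).
indices = torch.tensor([3, 7, 1, 42, 5, 5], dtype=torch.int64, device="cuda")
offsets = torch.tensor([0, 2, 3, 5, 6], dtype=torch.int64, device="cuda")

out = emb(indices, offsets)  # expected pooled output shape: [B, total_D] == [2, 12]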
+    @torch.jit.export
+    def split_embedding_weights(self) -> list[Tensor]:
+        """
+        Returns a list of weights, split by table
+        """
+        splits = []
+        for t, (rows, dim) in enumerate(self.embedding_specs):
+            offset = self.weights_physical_offsets[t]
+            splits.append(
+                self.weights.detach()[offset : offset + rows * dim].view(rows, dim)
+            )
+        return splits
+
+    def init_embedding_weights_uniform(self, min_val: float, max_val: float) -> None:
+        splits = self.split_embedding_weights()
+        for param in splits:
+            param.uniform_(min_val, max_val)
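Continuing the sketch above: `split_embedding_weights` returns per-table views that share storage with the flat `weights` parameter, which is what lets `init_embedding_weights_uniform` initialize the real weights in place.

tables = emb.split_embedding_weights()
assert tables[0].shape == (100, 4) and tables[1].shape == (50, 8)

# In-place uniform init on each per-table view; the views alias emb.weights.
emb.init_embedding_weights_uniform(-0.01, 0.01)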