fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,73 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+ # pyre-ignore-all-errors[56]
10
+ # flake8: noqa F401
11
+
12
+ import torch # usort:skip
13
+ import warnings
14
+
15
+ # This module is a compatibility wrapper that re-exports the symbols from:
16
+ # fbgemm_gpu.split_table_batched_embeddings_ops_common
17
+ # fbgemm_gpu.split_table_batched_embeddings_ops_inference
18
+ # fbgemm_gpu.split_table_batched_embeddings_ops_training
19
+
20
+ from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
21
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
22
+ BoundsCheckMode,
23
+ CacheAlgorithm,
24
+ CacheState,
25
+ DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
26
+ EmbeddingLocation,
27
+ PoolingMode,
28
+ RecordCacheMetrics,
29
+ round_up,
30
+ SplitState,
31
+ )
32
+ from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
33
+ align_to_cacheline,
34
+ IntNBitTableBatchedEmbeddingBagsCodegen,
35
+ rounded_row_size_in_bytes,
36
+ unpadded_row_size_in_bytes,
37
+ )
38
+ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
39
+ ComputeDevice,
40
+ CounterBasedRegularizationDefinition,
41
+ CounterWeightDecayMode,
42
+ DEFAULT_ASSOC,
43
+ DenseTableBatchedEmbeddingBagsCodegen,
44
+ GradSumDecay,
45
+ INT8_EMB_ROW_DIM_OFFSET,
46
+ LearningRateMode,
47
+ SplitTableBatchedEmbeddingBagsCodegen,
48
+ TailIdThreshold,
49
+ WeightDecayMode,
50
+ )
51
+
52
# Best-effort loading of the embedding operator libraries. The arguments are
# build-target labels ("//deeplearning/..."), and any exception raised while
# loading them is deliberately swallowed so that importing this compatibility
# module never fails because of it.
# NOTE(review): presumably these labels only resolve inside the internal build
# environment; elsewhere the ops must already be registered — confirm.
try:
    if torch.version.hip:
        # PyTorch was built for ROCm: load the HIP variant of the embedding ops.
        torch.ops.load_library(
            "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip"
        )
    else:
        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops")
    # The CPU ops library is loaded regardless of which branch ran above.
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu")
except Exception:
    pass

# Emit a DeprecationWarning every time this module is imported. The \033[93m and
# \033[0m sequences are ANSI escape codes that switch terminal text to bright
# yellow and then reset it; the doubled braces render as a literal
# "{training, inference}" inside the f-string.
warnings.warn(
    f"""\033[93m
    The Python module {__name__} is now DEPRECATED and will be removed in the
    future. Users should instead declare dependencies on
    //deeplearning/fbgemm/fbgemm_gpu/split_table_batched_embeddings_ops_{{training, inference}}
    in their TARGETS file and import the
    fbgemm_gpu.split_table_batched_embeddings_ops_{{training, inference}}
    modules as needed in their scripts.
    \033[0m""",
    DeprecationWarning,
)
@@ -0,0 +1,484 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+
10
+ # pyre-ignore-all-errors[56]
11
+
12
+ import enum
13
+ from dataclasses import dataclass
14
+ from typing import FrozenSet, NamedTuple, Optional, Tuple
15
+
16
+ import torch
17
+ from torch import Tensor
18
+
19
+
20
# Upper bound on the number of prefetch() calls that may be issued before a
# matching forward() call is made.
MAX_PREFETCH_DEPTH = 100

# Quantized embedding bags in TBE carry a 16-bit scale and a 16-bit bias on both
# GPU and CPU: 2 bytes + 2 bytes.
DEFAULT_SCALE_BIAS_SIZE_IN_BYTES = 2 + 2
27
+
28
+
29
class EmbeddingLocation(enum.IntEnum):
    """Placement of an embedding table's weights.

    Every member has a string spelling equal to its lower-cased name; those
    spellings are what :meth:`str_values` lists and :meth:`from_str` accepts.
    """

    DEVICE = 0
    MANAGED = 1
    MANAGED_CACHING = 2
    HOST = 3
    MTIA = 4

    @classmethod
    # pyre-ignore[3]
    def str_values(cls):
        # Lower-cased member names, in declaration order.
        return [member.name.lower() for member in cls]

    @classmethod
    # pyre-ignore[3]
    def from_str(cls, key: str):
        """Parse a lower-case location name; raise ValueError if it is unknown."""
        by_name = {member.name.lower(): member for member in cls}
        if key not in by_name:
            raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
        return by_name[key]
61
+
62
+
63
class EvictionPolicy(NamedTuple):
    """Eviction settings, as carried by ``KVZCHParams.eviction_policy``.

    ``eviction_trigger_mode`` selects when eviction runs and
    ``eviction_strategy`` selects how rows are chosen. The remaining fields hold
    the per-mode and per-strategy parameters. Call :meth:`validate` to check
    that the fields required by the chosen mode and strategy are set.
    """

    # When eviction is triggered: 0: disabled (default), 1: iteration,
    # 2: mem_util, 3: manual, 4: id count, 5: free_mem.
    eviction_trigger_mode: int = 0
    # Eviction strategy: 0: timestamp (default), 1: counter,
    # 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold,
    # 5: feature score.
    eviction_strategy: int = 0
    # trigger_step_interval if trigger mode is iteration
    eviction_step_intervals: Optional[int] = None
    # eviction trigger condition if trigger mode is mem_util
    eviction_mem_threshold_gb: Optional[int] = None
    # count_thresholds for each table if eviction strategy is counter
    counter_thresholds: Optional[list[int]] = None
    # ttls_in_mins for each table if eviction strategy is timestamp
    ttls_in_mins: Optional[list[int]] = None
    # count_decay_rates for each table if eviction strategy is counter
    counter_decay_rates: Optional[list[float]] = None
    # feature_score_counter_decay_rates for each table if eviction strategy is
    # feature score
    feature_score_counter_decay_rates: Optional[list[float]] = None
    # Number of training IDs that, when exceeded, will trigger eviction for each
    # table.
    training_id_eviction_trigger_count: Optional[list[int]] = None
    # Target number of training IDs to retain in each table after eviction.
    training_id_keep_count: Optional[list[int]] = None
    # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
    l2_weight_thresholds: Optional[list[float]] = None
    # The width of each feature score bucket used for threshold calculation in
    # feature score-based eviction.
    threshold_calculation_bucket_stride: Optional[float] = 0.2
    # 1M, total number of feature score buckets used for threshold calculation in
    # feature score-based eviction.
    threshold_calculation_bucket_num: Optional[int] = 1000000
    # Wait at least this many seconds before triggering the next round of
    # eviction if the last finished eviction was insufficient (it did not evict
    # enough rows); waiting longer avoids another insufficient eviction.
    interval_for_insufficient_eviction_s: int = 600
    # Wait at least this many seconds before triggering the next round of
    # eviction if the last finished eviction was sufficient.
    interval_for_sufficient_eviction_s: int = 60
    # 1 day, interval for feature statistics decay
    interval_for_feature_statistics_decay_s: int = 24 * 3600
    # metaheader length for each table
    meta_header_lens: Optional[list[int]] = None
    # Minimum free memory (in GB) required before triggering eviction when using
    # free_mem trigger mode.
    eviction_free_mem_threshold_gb: Optional[int] = None
    # Number of batches between checks for free memory threshold when using
    # free_mem trigger mode.
    eviction_free_mem_check_interval_batch: Optional[int] = None
    # enable eviction if eviction policy is feature score, false means no eviction
    enable_eviction_for_feature_score_eviction_policy: Optional[list[bool]] = None

    def validate(self) -> None:
        """Check that the fields required by the trigger mode and strategy are set.

        When eviction is disabled (trigger mode 0), only the trigger mode itself
        is checked. Trigger modes 1, 2, 4 and 5 and strategies 0, 1, 2 and 5
        have required fields; mode 3 and strategies 3 and 4 have none here.

        Raises:
            AssertionError: on an invalid or incomplete configuration. The checks
                stay ``assert`` statements so callers catching AssertionError keep
                working; note they are skipped when Python runs with ``-O``.
        """
        assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
            "eviction_trigger_mode must be 0, 1, 2, 3, 4 or 5, "
            f"actual {self.eviction_trigger_mode}"
        )
        if self.eviction_trigger_mode == 0:
            # Eviction disabled: the remaining fields are irrelevant.
            return

        assert self.eviction_strategy in [0, 1, 2, 3, 4, 5], (
            "eviction_strategy must be 0, 1, 2, 3, 4 or 5, "
            f"actual {self.eviction_strategy}"
        )

        # Fields required by the trigger mode.
        if self.eviction_trigger_mode == 1:
            assert (
                self.eviction_step_intervals is not None
                and self.eviction_step_intervals > 0
            ), (
                "eviction_step_intervals must be positive if eviction_trigger_mode is 1, "
                f"actual {self.eviction_step_intervals}"
            )
        elif self.eviction_trigger_mode == 2:
            assert (
                self.eviction_mem_threshold_gb is not None
            ), "eviction_mem_threshold_gb must be set if eviction_trigger_mode is 2"
        elif self.eviction_trigger_mode == 4:
            assert (
                self.training_id_eviction_trigger_count is not None
            ), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4"
        elif self.eviction_trigger_mode == 5:
            assert (
                self.eviction_free_mem_threshold_gb is not None
            ), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5"
            assert (
                self.eviction_free_mem_check_interval_batch is not None
            ), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5"

        # Fields required by the eviction strategy.
        if self.eviction_strategy == 0:
            assert self.ttls_in_mins is not None, (
                "ttls_in_mins must be set if eviction_strategy is 0, "
                f"actual {self.ttls_in_mins}"
            )
        elif self.eviction_strategy == 1:
            assert self.counter_thresholds is not None, (
                "counter_thresholds must be set if eviction_strategy is 1, "
                f"actual {self.counter_thresholds}"
            )
            assert self.counter_decay_rates is not None, (
                "counter_decay_rates must be set if eviction_strategy is 1, "
                f"actual {self.counter_decay_rates}"
            )
            assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
                "counter_thresholds and counter_decay_rates must have the same length, "
                f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
            )
        elif self.eviction_strategy == 2:
            assert self.counter_thresholds is not None, (
                "counter_thresholds must be set if eviction_strategy is 2, "
                f"actual {self.counter_thresholds}"
            )
            assert self.counter_decay_rates is not None, (
                "counter_decay_rates must be set if eviction_strategy is 2, "
                f"actual {self.counter_decay_rates}"
            )
            assert self.ttls_in_mins is not None, (
                "ttls_in_mins must be set if eviction_strategy is 2, "
                f"actual {self.ttls_in_mins}"
            )
            assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
                "counter_thresholds and counter_decay_rates must have the same length, "
                f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
            )
            assert len(self.counter_thresholds) == len(self.ttls_in_mins), (
                "counter_thresholds and ttls_in_mins must have the same length, "
                f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
            )
        elif self.eviction_strategy == 5:
            assert self.feature_score_counter_decay_rates is not None, (
                "feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
                f"actual {self.feature_score_counter_decay_rates}"
            )
            assert self.training_id_eviction_trigger_count is not None, (
                "training_id_eviction_trigger_count must be set if eviction_strategy is 5, "
                f"actual {self.training_id_eviction_trigger_count}"
            )
            assert self.training_id_keep_count is not None, (
                "training_id_keep_count must be set if eviction_strategy is 5, "
                f"actual {self.training_id_keep_count}"
            )
            assert self.threshold_calculation_bucket_stride is not None, (
                "threshold_calculation_bucket_stride must be set if eviction_strategy is 5, "
                f"actual {self.threshold_calculation_bucket_stride}"
            )
            assert self.threshold_calculation_bucket_num is not None, (
                "threshold_calculation_bucket_num must be set if eviction_strategy is 5, "
                f"actual {self.threshold_calculation_bucket_num}"
            )
            assert self.enable_eviction_for_feature_score_eviction_policy is not None, (
                "enable_eviction_for_feature_score_eviction_policy must be set if eviction_strategy is 5, "
                f"actual {self.enable_eviction_for_feature_score_eviction_policy}"
            )
            # The three per-table lists must describe the same number of tables.
            assert (
                len(self.enable_eviction_for_feature_score_eviction_policy)
                == len(self.training_id_keep_count)
                == len(self.feature_score_counter_decay_rates)
            ), (
                "enable_eviction_for_feature_score_eviction_policy, training_id_keep_count "
                "and feature_score_counter_decay_rates must have the same length, "
                f"actual {self.enable_eviction_for_feature_score_eviction_policy} "
                f"vs {self.training_id_keep_count} "
                f"vs {self.feature_score_counter_decay_rates}"
            )
235
+
236
+
237
class KVZCHParams(NamedTuple):
    """KV ZCH parameters: bucket layout, optimizer offloading and eviction policy.

    Call :meth:`validate` to check the fields for internal consistency.
    """

    # global bucket id start and global bucket id end offsets for each logical table,
    # where start offset is inclusive and end offset is exclusive
    # NOTE(review): this `[]` default (like the one on bucket_sizes) is a single
    # list object shared by every instance that relies on the default, so it must
    # never be mutated in place.
    bucket_offsets: list[tuple[int, int]] = []
    # bucket size for each logical table
    # the value indicates corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets
    bucket_sizes: list[int] = []
    # enable optimizer offloading or not
    enable_optimizer_offloading: bool = False
    # when enabled, backend will return whole row(metaheader + weight + optimizer) instead of weight only
    # can only be enabled when enable_optimizer_offloading is enabled
    backend_return_whole_row: bool = False
    # Eviction configuration; validate() delegates to eviction_policy.validate().
    eviction_policy: EvictionPolicy = EvictionPolicy()
    embedding_cache_mode: bool = False
    # NOTE(review): presumably the same meaning as
    # KVZCHTBEConfig.load_ckpt_without_opt (save weights only, not optimizer
    # state) — confirm.
    load_ckpt_without_opt: bool = False
    # NOTE(review): presumably mirror KVZCHTBEConfig's "[DO NOT USE] st publish
    # only" fields of the same names — confirm.
    optimizer_type_for_st: Optional[str] = None
    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None

    def validate(self) -> None:
        """Check the parameters for internal consistency.

        Raises:
            AssertionError: if bucket_offsets and bucket_sizes differ in length,
                if eviction_policy.validate() fails, or if
                backend_return_whole_row is set without
                enable_optimizer_offloading.
        """
        assert len(self.bucket_offsets) == len(self.bucket_sizes), (
            "bucket_offsets and bucket_sizes must have the same length, "
            f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
        )
        self.eviction_policy.validate()
        # backend_return_whole_row implies enable_optimizer_offloading.
        assert (
            not self.backend_return_whole_row or self.enable_optimizer_offloading
        ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"
264
+
265
+
266
class KVZCHTBEConfig(NamedTuple):
    """TBE-level settings for KV ZCH tables.

    Covers eviction triggering, feature-score threshold bucketing and
    checkpoint loading.
    """

    # Eviction trigger mode for kvzch tables: 0 disabled, 1 iteration,
    # 2 mem_util, 3 manual, 4 id count, 5 free_mem. Defaults to 2 (mem_util).
    kvzch_eviction_trigger_mode: int = 2
    # free_mem trigger mode: minimum free memory, in GB, required before
    # eviction is triggered (default 200 GB).
    eviction_free_mem_threshold_gb: int = 200
    # free_mem trigger mode: number of batches between free-memory checks.
    eviction_free_mem_check_interval_batch: int = 1000
    # Feature-score eviction: width of each score bucket used when computing the
    # eviction threshold.
    threshold_calculation_bucket_stride: float = 0.2
    # Feature-score eviction: total number of score buckets (1M by default).
    threshold_calculation_bucket_num: Optional[int] = 1_000_000
    # If True, only weights (not optimizer state) are saved to the kvzch backend.
    load_ckpt_without_opt: bool = False
    # [DO NOT USE] Reserved for st publish; do not set this in your config.
    optimizer_type_for_st: Optional[str] = None
    # [DO NOT USE] Reserved for st publish; do not set this in your config.
    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
283
+
284
+
285
class BackendType(enum.IntEnum):
    """Storage backend kind: SSD, DRAM or PS."""

    SSD = 0
    DRAM = 1
    PS = 2

    @classmethod
    # pyre-ignore[3]
    def from_str(cls, key: str):
        """Parse a lower-case backend name ("ssd", "dram" or "ps").

        Raises:
            ValueError: if ``key`` is not one of the accepted names.
        """
        lookup = {
            "ssd": BackendType.SSD,
            "dram": BackendType.DRAM,
            # Every member is parseable, consistent with the other from_str
            # helpers in this module.
            "ps": BackendType.PS,
        }
        if key in lookup:
            return lookup[key]
        else:
            raise ValueError(f"Cannot parse value into BackendType: {key}")
301
+
302
+
303
class CacheAlgorithm(enum.Enum):
    """Cache replacement algorithm.

    This is a plain ``enum.Enum`` (not ``IntEnum``), so members do not compare
    equal to their integer values.
    """

    # Least recently used.
    LRU = 0
    # Least frequently used.
    LFU = 1
306
+
307
+
308
class MultiPassPrefetchConfig(NamedTuple):
    """Settings for splitting a prefetch's indices tensor into multiple passes."""

    # Number of passes to split indices tensor into. Actual number of passes may
    # be less if indices tensor is too small to split.
    num_passes: int = 12

    # The minimal number of elements in indices tensor to be able to split into
    # two passes. This is useful to prevent too many prefetch kernels spamming
    # the CUDA launch queue.
    # The default 6M indices means 6M * 8 * 6 = approx. 300MB of memory overhead
    # per pass.
    min_splitable_pass_size: int = 6 * 1024 * 1024
319
+
320
+
321
class PoolingMode(enum.IntEnum):
    """Pooling applied to looked-up embeddings; NONE disables pooling."""

    SUM = 0
    MEAN = 1
    NONE = 2

    def do_pooling(self) -> bool:
        """Return True for SUM and MEAN, False for NONE."""
        return not (self is PoolingMode.NONE)

    @classmethod
    # pyre-ignore[3]
    def from_str(cls, key: str):
        """Parse "sum", "mean" or "none"; raise ValueError for anything else."""
        parsed = {mode.name.lower(): mode for mode in cls}.get(key)
        # Compare against None explicitly: PoolingMode.SUM == 0 is falsy.
        if parsed is None:
            raise ValueError(f"Cannot parse value into PoolingMode: {key}")
        return parsed
341
+
342
+
343
class BoundsCheckMode(enum.IntEnum):
    """How out-of-bounds indices are handled by the bounds check.

    The V2_* members mirror IGNORE / WARNING / FATAL, but with V2 enabled.
    """

    # Raise an exception (CPU) or device-side assert (CUDA)
    FATAL = 0
    # Log the first out-of-bounds instance per kernel, and set to zero.
    WARNING = 1
    # Set to zero.
    IGNORE = 2
    # No bounds checks.
    NONE = 3
    # IGNORE with V2 enabled
    V2_IGNORE = 4
    # WARNING with V2 enabled
    V2_WARNING = 5
    # FATAL with V2 enabled
    V2_FATAL = 6
358
+
359
+
360
class ComputeDevice(enum.IntEnum):
    """Kind of device the computation runs on: CPU, CUDA or MTIA."""

    CPU = 0
    CUDA = 1
    MTIA = 2
364
+
365
+
366
class EmbeddingSpecInfo(enum.IntEnum):
    """Named integer positions for the fields of an embedding spec.

    NOTE(review): presumably used to index a per-table spec tuple laid out as
    (feature_names, rows, dims, sparse_type, embedding_location) — confirm
    against callers. Unlike the other enums in this module, member names are
    lower-case.
    """

    feature_names = 0
    rows = 1
    dims = 2
    sparse_type = 3
    embedding_location = 4
372
+
373
+
374
# Pair of boolean flags for cache-miss metric recording:
#   record_cache_miss_counter   - record a cache-miss counter
#   record_tablewise_cache_miss - record cache misses per table
# NOTE(review): exact semantics live in the code that consumes these flags —
# confirm there.
RecordCacheMetrics: NamedTuple = NamedTuple(
    "RecordCacheMetrics",
    [("record_cache_miss_counter", bool), ("record_tablewise_cache_miss", bool)],
)
378
+
379
# Placement summary for a set of embedding tables:
#   dev_size / host_size / uvm_size - sizes of the device, host and UVM portions
#   placements                      - EmbeddingLocation of each table
#   offsets                         - offset of each table
# NOTE(review): units of the sizes/offsets (bytes vs elements) and what each
# offset is relative to are not visible here — confirm at the construction site.
SplitState: NamedTuple = NamedTuple(
    "SplitState",
    [
        ("dev_size", int),
        ("host_size", int),
        ("uvm_size", int),
        ("placements", list[EmbeddingLocation]),
        ("offsets", list[int]),
    ],
)
389
+
390
+
391
@dataclass
class CacheState:
    """Index bookkeeping for MANAGED_CACHING tables, as built by construct_cache_state."""

    # T + 1 elements and cache_hash_size_cumsum[-1] == total_cache_hash_size.
    # construct_cache_state makes T the number of features and stores -1 for
    # features whose table is not MANAGED_CACHING.
    cache_hash_size_cumsum: list[int]
    # Length total_cache_hash_size: linear cache index -> index of the feature
    # owning that slot, or -1 if no feature maps to the owning table.
    cache_index_table_map: list[int]
    # Total number of rows across all MANAGED_CACHING tables.
    total_cache_hash_size: int
397
+
398
+
399
def construct_cache_state(
    row_list: list[int],
    location_list: list[EmbeddingLocation],
    feature_table_map: list[int],
) -> CacheState:
    """Build the CacheState for the MANAGED_CACHING tables.

    Args:
        row_list: number of rows of each physical table.
        location_list: placement of each physical table, parallel to row_list.
        feature_table_map: for each feature, the index of its physical table.

    Returns:
        A CacheState where:
          * cache_hash_size_cumsum[f] is the start offset of feature f's table
            in the linear cache index space, or -1 if that table is not
            MANAGED_CACHING; one final entry holds total_cache_hash_size.
          * cache_index_table_map[i] is the index of the last feature mapped to
            the table that owns linear cache index i, or -1 if no feature maps
            to that table.
          * total_cache_hash_size is the sum of rows over MANAGED_CACHING tables.
    """
    # Prefix sums of cached rows per physical table (non-cached tables add 0),
    # so table k occupies [table_offsets[k], table_offsets[k + 1]).
    table_offsets = [0]
    running_total = 0
    for rows, placement in zip(row_list, location_list):
        if placement == EmbeddingLocation.MANAGED_CACHING:
            running_total += rows
        table_offsets.append(running_total)

    # Each physical table is represented by the last feature that maps to it.
    last_feature_of_table = {
        table: feature for feature, table in enumerate(feature_table_map)
    }

    index_to_feature = [-1] * running_total
    for table, feature in last_feature_of_table.items():
        lo, hi = table_offsets[table], table_offsets[table + 1]
        index_to_feature[lo:hi] = [feature] * (hi - lo)

    # [T] per-feature table start offsets, -1 for non-cached tables, then total.
    per_feature_offsets = []
    for table in feature_table_map:
        cached = location_list[table] == EmbeddingLocation.MANAGED_CACHING
        per_feature_offsets.append(table_offsets[table] if cached else -1)
    per_feature_offsets.append(running_total)

    return CacheState(
        cache_hash_size_cumsum=per_feature_offsets,
        cache_index_table_map=index_to_feature,
        total_cache_hash_size=running_total,
    )
435
+
436
+
437
# NOTE: fbgemm_gpu.tbe.utils has its own copy of this helper. Depending on the
# :split_embedding_utils target here instead would break compatibility with
# Caffe2 module_factory, because that target pulls in numpy.
def round_up(a: int, b: int) -> int:
    """Round ``a`` up to the nearest multiple of ``b`` (for positive ``b``)."""
    num_blocks = int((a + b - 1) // b)
    return num_blocks * b
442
+
443
+
444
def tensor_to_device(tensor: torch.Tensor, device: torch.device) -> Tensor:
    """Return ``tensor`` placed on ``device``.

    Meta tensors hold no data to copy, so for them an uninitialized tensor with
    the same shape and dtype is allocated on ``device`` instead. Any other
    tensor is moved with ``Tensor.to``, which returns ``tensor`` itself when it
    is already on ``device``.
    """
    is_meta = tensor.device == torch.device("meta")
    return torch.empty_like(tensor, device=device) if is_meta else tensor.to(device)
448
+
449
+
450
def get_new_embedding_location(
    device: torch.device, cache_load_factor: float
) -> EmbeddingLocation:
    """
    Based on the cache_load_factor and device, return the embedding location intended
    for the TBE weights.

    Args:
        device: target device; only "cpu" and "cuda" devices are supported.
        cache_load_factor: value in [0.0, 1.0]. On CUDA, 0 selects UVM only
            (MANAGED), 1 selects HBM only (DEVICE), and anything in between
            selects UVM caching (MANAGED_CACHING).

    Returns:
        EmbeddingLocation.HOST for CPU devices; otherwise MANAGED, DEVICE or
        MANAGED_CACHING as described above.

    Raises:
        AssertionError: if ``device`` is neither a CPU nor a CUDA device.
        ValueError: if ``cache_load_factor`` is outside [0.0, 1.0] or is NaN.
    """
    # Only support CPU and GPU device
    assert device.type == "cpu" or device.type == "cuda"
    # Written as a negated closed-range check so that NaN (for which every
    # comparison is False) is rejected rather than falling through to
    # MANAGED_CACHING.
    if not 0.0 <= cache_load_factor <= 1.0:
        raise ValueError(
            f"cache_load_factor must be between 0.0 and 1.0, got {cache_load_factor}"
        )

    if device.type == "cpu":
        return EmbeddingLocation.HOST
    if cache_load_factor == 0:
        # UVM only
        return EmbeddingLocation.MANAGED
    if cache_load_factor == 1.0:
        # HBM only
        return EmbeddingLocation.DEVICE
    # UVM caching
    return EmbeddingLocation.MANAGED_CACHING
475
+
476
+
477
def get_bounds_check_version_for_platform() -> int:
    """Return the bounds_check_indices kernel version to use: 2 on ROCm, else 1.

    ROCm requires gridDim * blockDim to stay below 2^32. The v1 kernel can be
    launched with gridDim * blockDim > 2^32, whereas the v2 kernel caps gridDim
    at 64 * (number of SMs), which keeps the product below 2^32.
    """
    if torch.cuda.is_available() and torch.version.hip:
        # A GPU is available and this PyTorch is a ROCm (HIP) build.
        return 2
    return 1