fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_table_batched_embeddings_ops_common.py
@@ -11,12 +11,11 @@
 
 import enum
 from dataclasses import dataclass
-from typing import List, NamedTuple
+from typing import FrozenSet, NamedTuple, Optional, Tuple
 
 import torch
 from torch import Tensor
 
-
 # Maximum number of times prefetch() can be called without
 # a corresponding forward() call
 MAX_PREFETCH_DEPTH = 100
@@ -33,6 +32,17 @@ class EmbeddingLocation(enum.IntEnum):
     HOST = 3
     MTIA = 4
 
+    @classmethod
+    # pyre-ignore[3]
+    def str_values(cls):
+        return [
+            "device",
+            "managed",
+            "managed_caching",
+            "host",
+            "mtia",
+        ]
+
     @classmethod
     # pyre-ignore[3]
     def from_str(cls, key: str):
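The str_values() helper added here enumerates the string forms that the existing from_str() parser (context in the next hunk) accepts. A minimal round-trip sketch, assuming only the enum as shown in this diff:

    from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation

    # Each string listed by str_values() parses back to an enum member.
    for name in EmbeddingLocation.str_values():
        assert isinstance(EmbeddingLocation.from_str(name), EmbeddingLocation)

    EmbeddingLocation.from_str("unknown")  # raises ValueError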
@@ -49,6 +59,246 @@ class EmbeddingLocation(enum.IntEnum):
             raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
 
 
+class EvictionPolicy(NamedTuple):
+    eviction_trigger_mode: int = (
+        0  # 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
+    )
+    eviction_strategy: int = (
+        0  # 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold, 5: feature score
+    )
+    eviction_step_intervals: Optional[int] = (
+        None  # trigger_step_interval if trigger mode is iteration
+    )
+    eviction_mem_threshold_gb: Optional[int] = (
+        None  # eviction trigger condition if trigger mode is mem_util
+    )
+    counter_thresholds: Optional[list[int]] = (
+        None  # count_thresholds for each table if eviction strategy is counter
+    )
+    ttls_in_mins: Optional[list[int]] = (
+        None  # ttls_in_mins for each table if eviction strategy is timestamp
+    )
+    counter_decay_rates: Optional[list[float]] = (
+        None  # count_decay_rates for each table if eviction strategy is counter
+    )
+    feature_score_counter_decay_rates: Optional[list[float]] = (
+        None  # feature_score_counter_decay_rates for each table if eviction strategy is feature score
+    )
+    training_id_eviction_trigger_count: Optional[list[int]] = (
+        None  # Number of training IDs that, when exceeded, will trigger eviction for each table.
+    )
+    training_id_keep_count: Optional[list[int]] = (
+        None  # Target number of training IDs to retain in each table after eviction.
+    )
+    l2_weight_thresholds: Optional[list[float]] = (
+        None  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
+    )
+    threshold_calculation_bucket_stride: Optional[float] = (
+        0.2  # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
+    )
+    threshold_calculation_bucket_num: Optional[int] = (
+        1000000  # 1M, total number of feature score buckets used for threshold calculation in feature score-based eviction.
+    )
+    interval_for_insufficient_eviction_s: int = (
+        # Wait at least this many seconds before triggering the next round of
+        # eviction if the last finished eviction was insufficient, i.e. it did
+        # not evict enough rows; waiting longer avoids another insufficient
+        # eviction.
+        600
+    )
+    interval_for_sufficient_eviction_s: int = (
+        # Wait at least this many seconds before triggering the next round of
+        # eviction if the last finished eviction was sufficient.
+        60
+    )
+    interval_for_feature_statistics_decay_s: int = (
+        24 * 3600  # 1 day, interval for feature statistics decay
+    )
+    meta_header_lens: Optional[list[int]] = None  # metaheader length for each table
+    eviction_free_mem_threshold_gb: Optional[int] = (
+        None  # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
+    )
+    eviction_free_mem_check_interval_batch: Optional[int] = (
+        None  # Number of batches between checks of the free-memory threshold when using free_mem trigger mode.
+    )
+    enable_eviction_for_feature_score_eviction_policy: Optional[list[bool]] = (
+        None  # enable eviction per table when the eviction policy is feature score; False means no eviction
+    )
+
+    def validate(self) -> None:
+        assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
+            "eviction_trigger_mode must be 0, 1, 2, 3, 4, or 5, "
+            f"actual {self.eviction_trigger_mode}"
+        )
+        if self.eviction_trigger_mode == 0:
+            return
+
+        assert self.eviction_strategy in [0, 1, 2, 3, 4, 5], (
+            "eviction_strategy must be 0, 1, 2, 3, 4, or 5, "
+            f"actual {self.eviction_strategy}"
+        )
+        if self.eviction_trigger_mode == 1:
+            assert (
+                self.eviction_step_intervals is not None
+                and self.eviction_step_intervals > 0
+            ), (
+                "eviction_step_intervals must be positive if eviction_trigger_mode is 1, "
+                f"actual {self.eviction_step_intervals}"
+            )
+        elif self.eviction_trigger_mode == 2:
+            assert (
+                self.eviction_mem_threshold_gb is not None
+            ), "eviction_mem_threshold_gb must be set if eviction_trigger_mode is 2"
+        elif self.eviction_trigger_mode == 4:
+            assert (
+                self.training_id_eviction_trigger_count is not None
+            ), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4"
+        elif self.eviction_trigger_mode == 5:
+            assert (
+                self.eviction_free_mem_threshold_gb is not None
+            ), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5"
+            assert (
+                self.eviction_free_mem_check_interval_batch is not None
+            ), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5"
+
+        if self.eviction_strategy == 0:
+            assert self.ttls_in_mins is not None, (
+                "ttls_in_mins must be set if eviction_strategy is 0, "
+                f"actual {self.ttls_in_mins}"
+            )
+        elif self.eviction_strategy == 1:
+            assert self.counter_thresholds is not None, (
+                "counter_thresholds must be set if eviction_strategy is 1, "
+                f"actual {self.counter_thresholds}"
+            )
+            assert self.counter_decay_rates is not None, (
+                "counter_decay_rates must be set if eviction_strategy is 1, "
+                f"actual {self.counter_decay_rates}"
+            )
+            assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
+                "counter_thresholds and counter_decay_rates must have the same length, "
+                f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
+            )
+        elif self.eviction_strategy == 2:
+            assert self.counter_thresholds is not None, (
+                "counter_thresholds must be set if eviction_strategy is 2, "
+                f"actual {self.counter_thresholds}"
+            )
+            assert self.counter_decay_rates is not None, (
+                "counter_decay_rates must be set if eviction_strategy is 2, "
+                f"actual {self.counter_decay_rates}"
+            )
+            assert self.ttls_in_mins is not None, (
+                "ttls_in_mins must be set if eviction_strategy is 2, "
+                f"actual {self.ttls_in_mins}"
+            )
+            assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
+                "counter_thresholds and counter_decay_rates must have the same length, "
+                f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
+            )
+            assert len(self.counter_thresholds) == len(self.ttls_in_mins), (
+                "counter_thresholds and ttls_in_mins must have the same length, "
+                f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
+            )
+        elif self.eviction_strategy == 5:
+            assert self.feature_score_counter_decay_rates is not None, (
+                "feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
+                f"actual {self.feature_score_counter_decay_rates}"
+            )
+            assert self.training_id_eviction_trigger_count is not None, (
+                "training_id_eviction_trigger_count must be set if eviction_strategy is 5, "
+                f"actual {self.training_id_eviction_trigger_count}"
+            )
+            assert self.training_id_keep_count is not None, (
+                "training_id_keep_count must be set if eviction_strategy is 5, "
+                f"actual {self.training_id_keep_count}"
+            )
+            assert self.threshold_calculation_bucket_stride is not None, (
+                "threshold_calculation_bucket_stride must be set if eviction_strategy is 5, "
+                f"actual {self.threshold_calculation_bucket_stride}"
+            )
+            assert self.threshold_calculation_bucket_num is not None, (
+                "threshold_calculation_bucket_num must be set if eviction_strategy is 5, "
+                f"actual {self.threshold_calculation_bucket_num}"
+            )
+            assert self.enable_eviction_for_feature_score_eviction_policy is not None, (
+                "enable_eviction_for_feature_score_eviction_policy must be set if eviction_strategy is 5, "
+                f"actual {self.enable_eviction_for_feature_score_eviction_policy}"
+            )
+            assert (
+                len(self.enable_eviction_for_feature_score_eviction_policy)
+                == len(self.training_id_keep_count)
+                == len(self.feature_score_counter_decay_rates)
+            ), (
+                "enable_eviction_for_feature_score_eviction_policy, training_id_keep_count, and feature_score_counter_decay_rates must have the same length, "
+                f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.enable_eviction_for_feature_score_eviction_policy}"
+            )
+
+
+class KVZCHParams(NamedTuple):
+    # Global bucket id start and end offsets for each logical table,
+    # where the start offset is inclusive and the end offset is exclusive.
+    bucket_offsets: list[tuple[int, int]] = []
+    # Bucket size for each logical table: the input space covered by each
+    # bucket id, e.g. 2^50 / total_num_buckets.
+    bucket_sizes: list[int] = []
+    # Enable optimizer offloading or not.
+    enable_optimizer_offloading: bool = False
+    # When enabled, the backend returns the whole row (metaheader + weight +
+    # optimizer) instead of the weight only; can only be enabled when
+    # enable_optimizer_offloading is enabled.
+    backend_return_whole_row: bool = False
+    eviction_policy: EvictionPolicy = EvictionPolicy()
+    embedding_cache_mode: bool = False
+    load_ckpt_without_opt: bool = False
+    optimizer_type_for_st: Optional[str] = None
+    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
+
+    def validate(self) -> None:
+        assert len(self.bucket_offsets) == len(self.bucket_sizes), (
+            "bucket_offsets and bucket_sizes must have the same length, "
+            f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
+        )
+        self.eviction_policy.validate()
+        assert (
+            not self.backend_return_whole_row or self.enable_optimizer_offloading
+        ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"
+
+
+class KVZCHTBEConfig(NamedTuple):
+    # Eviction trigger mode for the kvzch table: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
+    kvzch_eviction_trigger_mode: int = 2  # mem_util
+    # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
+    eviction_free_mem_threshold_gb: int = 200  # 200 GB
+    # Number of batches between checks of the free-memory threshold when using free_mem trigger mode.
+    eviction_free_mem_check_interval_batch: int = 1000
+    # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
+    threshold_calculation_bucket_stride: float = 0.2
+    # Total number of feature score buckets used for threshold calculation in feature score-based eviction.
+    threshold_calculation_bucket_num: Optional[int] = 1000000  # 1M
+    # When True, only the weight is saved to the kvzch backend, not the optimizer state.
+    load_ckpt_without_opt: bool = False
+    # [DO NOT USE] For st publish only; do not set this in your config.
+    optimizer_type_for_st: Optional[str] = None
+    # [DO NOT USE] For st publish only; do not set this in your config.
+    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
+
+
+class BackendType(enum.IntEnum):
+    SSD = 0
+    DRAM = 1
+    PS = 2
+
+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "ssd": BackendType.SSD,
+            "dram": BackendType.DRAM,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into BackendType: {key}")
+
+
 class CacheAlgorithm(enum.Enum):
     LRU = 0
     LFU = 1
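For orientation, a minimal sketch of how the new tuples compose; the values below are illustrative only, chosen to satisfy the validate() checks added above:

    # Iteration-triggered eviction (trigger mode 1) with the timestamp
    # strategy (0) for two logical tables.
    policy = EvictionPolicy(
        eviction_trigger_mode=1,      # 1: iteration
        eviction_strategy=0,          # 0: timestamp
        eviction_step_intervals=100,  # must be positive when trigger mode is 1
        ttls_in_mins=[60, 1440],      # required when strategy is 0
    )

    params = KVZCHParams(
        bucket_offsets=[(0, 8), (8, 16)],  # [start, end) bucket ids per table
        bucket_sizes=[2**40, 2**40],       # input space per bucket id
        enable_optimizer_offloading=True,
        backend_return_whole_row=True,     # only legal with optimizer offloading
        eviction_policy=policy,
    )
    params.validate()  # also runs policy.validate()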
@@ -106,6 +356,12 @@ class BoundsCheckMode(enum.IntEnum):
     V2_FATAL = 6
 
 
+class ComputeDevice(enum.IntEnum):
+    CPU = 0
+    CUDA = 1
+    MTIA = 2
+
+
 class EmbeddingSpecInfo(enum.IntEnum):
     feature_names = 0
     rows = 1
@@ -125,8 +381,8 @@ SplitState: NamedTuple = NamedTuple(
         ("dev_size", int),
         ("host_size", int),
         ("uvm_size", int),
-        ("placements", List[EmbeddingLocation]),
-        ("offsets", List[int]),
+        ("placements", list[EmbeddingLocation]),
+        ("offsets", list[int]),
     ],
 )
 
@@ -134,15 +390,15 @@ SplitState: NamedTuple = NamedTuple(
 @dataclass
 class CacheState:
     # T + 1 elements and cache_hash_size_cumsum[-1] == total_cache_hash_size
-    cache_hash_size_cumsum: List[int]
-    cache_index_table_map: List[int]
+    cache_hash_size_cumsum: list[int]
+    cache_index_table_map: list[int]
     total_cache_hash_size: int
 
 
 def construct_cache_state(
-    row_list: List[int],
-    location_list: List[EmbeddingLocation],
-    feature_table_map: List[int],
+    row_list: list[int],
+    location_list: list[EmbeddingLocation],
+    feature_table_map: list[int],
 ) -> CacheState:
     _cache_hash_size_cumsum = [0]
     total_cache_hash_size = 0
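This hunk and most of the remaining ones in both files are a mechanical migration from typing.List and typing.Tuple to the PEP 585 builtin generics, which are always available on the Python 3.11 (cp311) target of this wheel. The pattern throughout is simply:

    from typing import Optional  # the List and Tuple imports are dropped

    # Before: def f(xs: List[int]) -> Tuple[int, int]
    # After: builtin generics, no typing import needed for these
    def f(xs: list[int]) -> tuple[int, int]:
        return xs[0], xs[-1]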
@@ -215,3 +471,13 @@ def get_new_embedding_location(
     # UVM caching
     else:
         return EmbeddingLocation.MANAGED_CACHING
+
+
+def get_bounds_check_version_for_platform() -> int:
+    # NOTE: Use bounds_check_indices v2 on ROCm because ROCm has a
+    # constraint that the gridDim * blockDim has to be smaller than
+    # 2^32. The v1 kernel can be launched with gridDim * blockDim >
+    # 2^32 while the v2 kernel limits the gridDim size to 64 * # of
+    # SMs. Thus, its gridDim * blockDim is guaranteed to be smaller
+    # than 2^32.
+    return 2 if (torch.cuda.is_available() and torch.version.hip) else 1
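This new helper is consumed by the inference module below: it is imported there, stored as self.bounds_check_version in __init__, and forwarded to the bounds_check_indices op in both forward paths. A usage sketch; on this CPU-only wheel torch.cuda.is_available() is False, so version 1 is always selected:

    from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
        get_bounds_check_version_for_platform,
    )

    # CPU and CUDA builds select v1; ROCm builds (torch.version.hip is set)
    # select v2, whose gridDim cap of 64 * SM count keeps
    # gridDim * blockDim below 2^32.
    assert get_bounds_check_version_for_platform() in (1, 2)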
fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -12,7 +12,7 @@
 import logging
 import uuid
 from itertools import accumulate
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import fbgemm_gpu  # noqa: F401
 import torch  # usort:skip
@@ -28,6 +28,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
     EmbeddingLocation,
     EmbeddingSpecInfo,
+    get_bounds_check_version_for_platform,
     get_new_embedding_location,
     MAX_PREFETCH_DEPTH,
     PoolingMode,
@@ -91,14 +92,14 @@ def align_to_cacheline(a: int) -> int:
 
 
 def nbit_construct_split_state(
-    embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]],
+    embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]],
     cacheable: bool,
     row_alignment: int,
     scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
     cacheline_alignment: bool = True,
 ) -> SplitState:
-    placements = torch.jit.annotate(List[EmbeddingLocation], [])
-    offsets = torch.jit.annotate(List[int], [])
+    placements = torch.jit.annotate(list[EmbeddingLocation], [])
+    offsets = torch.jit.annotate(list[int], [])
     dev_size = 0
     host_size = 0
     uvm_size = 0
@@ -164,7 +165,7 @@ def inputs_to_device(
     offsets: torch.Tensor,
     per_sample_weights: Optional[torch.Tensor],
     bounds_check_warning: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     if bounds_check_warning.device.type == "meta":
         return indices, offsets, per_sample_weights
 
@@ -330,7 +331,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         Options are `torch.int32` and `torch.int64`.
     """
 
-    embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]]
+    embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
     cache_miss_counter: torch.Tensor
@@ -345,15 +346,15 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
     def __init__(  # noqa C901
         self,
-        embedding_specs: List[
-            Tuple[str, int, int, SparseType, EmbeddingLocation]
+        embedding_specs: list[
+            tuple[str, int, int, SparseType, EmbeddingLocation]
         ],  # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
-        feature_table_map: Optional[List[int]] = None,  # [T]
-        index_remapping: Optional[List[Tensor]] = None,
+        feature_table_map: Optional[list[int]] = None,  # [T]
+        index_remapping: Optional[list[Tensor]] = None,
         pooling_mode: PoolingMode = PoolingMode.SUM,
         device: Optional[Union[str, int, torch.device]] = None,
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
-        weight_lists: Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None,
+        weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
         pruning_hash_load_factor: float = 0.5,
         use_array_for_index_remapping: bool = True,
         output_dtype: SparseType = SparseType.FP16,
@@ -372,7 +373,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         cacheline_alignment: bool = True,
         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
         reverse_qparam: bool = False,  # True to load qparams at the end of each row; False to load qparams at the beginning of each row.
-        feature_names_per_table: Optional[List[List[str]]] = None,
+        feature_names_per_table: Optional[list[list[str]]] = None,
         indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()
@@ -405,14 +406,14 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self.indices_dtype = indices_dtype
         # (feature_names, rows, dims, weights_tys, locations) = zip(*embedding_specs)
         # Pyre workaround
-        self.feature_names: List[str] = [e[0] for e in embedding_specs]
+        self.feature_names: list[str] = [e[0] for e in embedding_specs]
         self.cache_load_factor: float = cache_load_factor
         self.cache_sets: int = cache_sets
         self.cache_reserved_memory: float = cache_reserved_memory
-        rows: List[int] = [e[1] for e in embedding_specs]
-        dims: List[int] = [e[2] for e in embedding_specs]
-        weights_tys: List[SparseType] = [e[3] for e in embedding_specs]
-        locations: List[EmbeddingLocation] = [e[4] for e in embedding_specs]
+        rows: list[int] = [e[1] for e in embedding_specs]
+        dims: list[int] = [e[2] for e in embedding_specs]
+        weights_tys: list[SparseType] = [e[3] for e in embedding_specs]
+        locations: list[EmbeddingLocation] = [e[4] for e in embedding_specs]
         # if target device is meta then we set use_cpu based on the embedding location
         # information in embedding_specs.
         if self.current_device.type == "meta":
@@ -452,7 +453,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         T_ = len(self.embedding_specs)
         assert T_ > 0
 
-        self.feature_table_map: List[int] = (
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
         T = len(self.feature_table_map)
@@ -635,6 +636,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.fp8_exponent_bits = -1
             self.fp8_exponent_bias = -1
 
+        self.bounds_check_version: int = get_bounds_check_version_for_platform()
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
@@ -673,7 +676,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         return self.table_wise_cache_miss
 
     @torch.jit.export
-    def get_feature_num_per_table(self) -> List[int]:
+    def get_feature_num_per_table(self) -> list[int]:
         if self.feature_names_per_table is None:
             return []
         return [len(feature_names) for feature_names in self.feature_names_per_table]
@@ -975,6 +978,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 self.bounds_check_mode_int,
                 self.bounds_check_warning,
                 per_sample_weights,
+                bounds_check_version=self.bounds_check_version,
             )
 
         # Index remapping changes input indices, and some of them become -1 (pruned rows).
@@ -1017,6 +1021,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 self.bounds_check_mode_int,
                 self.bounds_check_warning,
                 per_sample_weights,
+                bounds_check_version=self.bounds_check_version,
             )
         # Note: CPU and CUDA ops use the same interface to facilitate JIT IR
         # generation for CUDA/CPU. For CPU op, we don't need weights_uvm and
@@ -1206,8 +1211,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         dev_size: int,
         host_size: int,
         uvm_size: int,
-        placements: List[int],
-        offsets: List[int],
+        placements: list[int],
+        offsets: list[int],
         enforce_hbm: bool,
     ) -> None:
         assert not self.weight_initialized, "Weights have already been initialized."
@@ -1516,6 +1521,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         for i, weight in enumerate(weights):
             weights[i] = (
                 weight[0].to(device),
+                # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `to`.
                 weight[1].to(device) if weight[1] is not None else None,
             )
         (
@@ -1596,7 +1602,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.export
     def split_embedding_weights_with_scale_bias(
         self, split_scale_bias_mode: int = 1
-    ) -> List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
+    ) -> list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
         """
         Returns a list of weights, split by table
         split_scale_bias_mode:
@@ -1605,7 +1611,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
             2: return weights, scale, bias.
         """
         assert self.weight_initialized
-        splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
+        splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
         for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
             placement = self.weights_physical_placements[t]
             if (
@@ -1730,12 +1736,12 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         # the second with scale_bias.
         # This should've been named as split_scale_bias.
         # Keep as is for backward compatibility.
-    ) -> List[Tuple[Tensor, Optional[Tensor]]]:
+    ) -> list[tuple[Tensor, Optional[Tensor]]]:
         """
         Returns a list of weights, split by table
         """
         # fmt: off
-        splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
+        splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
             self.split_embedding_weights_with_scale_bias(
                 split_scale_bias_mode=(1 if split_scale_shifts else 0)
             )
@@ -1773,7 +1779,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         )
 
     def assign_embedding_weights(
-        self, q_weight_list: List[Tuple[Tensor, Optional[Tensor]]]
+        self, q_weight_list: list[tuple[Tensor, Optional[Tensor]]]
     ) -> None:
         """
         Assigns self.split_embedding_weights() with values from the input list of weights and scale_shifts.
@@ -1785,6 +1791,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
             dest_weight[0].copy_(input_weight[0])
             if input_weight[1] is not None:
                 assert dest_weight[1] is not None
+                # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `copy_`.
                 dest_weight[1].copy_(input_weight[1])
             else:
                 assert dest_weight[1] is None
@@ -1792,11 +1799,11 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.export
     def set_index_remappings_array(
         self,
-        index_remapping: List[Tensor],
+        index_remapping: list[Tensor],
     ) -> None:
-        rows: List[int] = [e[1] for e in self.embedding_specs]
+        rows: list[int] = [e[1] for e in self.embedding_specs]
         index_remappings_array_offsets = [0]
-        original_feature_rows = torch.jit.annotate(List[int], [])
+        original_feature_rows = torch.jit.annotate(list[int], [])
         last_offset = 0
         for t, mapping in enumerate(index_remapping):
             if mapping is not None:
@@ -1835,11 +1842,11 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
     def set_index_remappings(
         self,
-        index_remapping: List[Tensor],
+        index_remapping: list[Tensor],
         pruning_hash_load_factor: float = 0.5,
         use_array_for_index_remapping: bool = True,
     ) -> None:
-        rows: List[int] = [e[1] for e in self.embedding_specs]
+        rows: list[int] = [e[1] for e in self.embedding_specs]
         T = len(self.embedding_specs)
         # Hash mapping pruning
         if not use_array_for_index_remapping:
@@ -1909,7 +1916,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
     def _embedding_inplace_update_per_table(
         self,
         update_table_idx: int,
-        update_row_indices: List[int],
+        update_row_indices: list[int],
         update_weights: Tensor,
     ) -> None:
         row_size = len(update_row_indices)
@@ -1934,9 +1941,9 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.export
     def embedding_inplace_update(
         self,
-        update_table_indices: List[int],
-        update_row_indices: List[List[int]],
-        update_weights: List[Tensor],
+        update_table_indices: list[int],
+        update_row_indices: list[list[int]],
+        update_weights: list[Tensor],
     ) -> None:
         for i in range(len(update_table_indices)):
             self._embedding_inplace_update_per_table(
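Aside from the List to list rename, the inplace-update API is unchanged. A hedged usage sketch; the tbe handle and tensor below are assumptions for illustration, not part of this diff:

    # Overwrite rows 3 and 7 of table 0 with freshly quantized weights.
    # `tbe` is an IntNBitTableBatchedEmbeddingBagsCodegen instance and
    # `new_rows` a Tensor with one packed row per updated index.
    tbe.embedding_inplace_update(
        update_table_indices=[0],
        update_row_indices=[[3, 7]],
        update_weights=[new_rows],
    )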
@@ -1947,8 +1954,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
     def embedding_inplace_update_internal(
         self,
-        update_table_indices: List[int],
-        update_row_indices: List[int],
+        update_table_indices: list[int],
+        update_row_indices: list[int],
         update_weights: Tensor,
     ) -> None:
         assert len(update_table_indices) == len(update_row_indices)