fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_embedding_configs.py
@@ -9,10 +9,11 @@
 
 import enum
 import itertools
-from typing import Any, Dict, List, Optional, Tuple  # noqa: F401
+from typing import Any, Dict  # noqa: F401
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     EmbeddingLocation,
     SplitState,
@@ -36,6 +37,23 @@ def pad4(value: int) -> int:
     return (int(value) + 3) & ~3
 
 
+def pad16(value: int) -> int:
+    """
+    Compute the smallest multiple of 16 that is greater than or equal to the given value.
+
+    Parameters:
+        value (int): The integer to align (must be non-negative).
+
+    Returns:
+        int: The aligned value.
+
+    Raises:
+        ValueError: If the input is negative.
+        TypeError: If the input is not an integer.
+    """
+    return (int(value) + 15) & ~15
+
+
 @enum.unique
 class EmbOptimType(enum.Enum):
     SGD = "sgd"  # uses non-deterministic updates (atomicAdd(..)) with duplicate ids
@@ -64,13 +82,13 @@ class EmbOptimType(enum.Enum):
         return self.value
 
     def _extract_dtype(
-        self, optimizer_state_dtypes: Dict[str, "SparseType"], name: str
+        self, optimizer_state_dtypes: dict[str, "SparseType"], name: str
     ) -> torch.dtype:
         if optimizer_state_dtypes is None or name not in optimizer_state_dtypes:
             return torch.float32
         return optimizer_state_dtypes[name].as_dtype()
 
-    def state_names(self) -> List[str]:
+    def state_names(self) -> list[str]:
         """
         Returns the names of the optimizer states. The order of the states will
         be the order in which they are processed and returned in
@@ -79,12 +97,12 @@ class EmbOptimType(enum.Enum):
         """
         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
             return ["momentum1"]
-        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+        elif self in [EmbOptimType.PARTIAL_ROWWISE_ADAM, EmbOptimType.ADAM]:
             return ["momentum1", "momentum2"]
         else:
            return []
 
-    def state_size_table(self, D: int) -> Dict[str, int]:
+    def state_size_table(self, D: int) -> dict[str, int]:
         """
         Returns the table of state names to state sizes in terms of number of
         elements (per table row)
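
After this change, ADAM reports the same state names as PARTIAL_ROWWISE_ADAM; only the per-state sizes differ, as the next hunk shows. A minimal check, assuming the nightly wheel is installed and exposes these names as shown:

    from fbgemm_gpu.split_embedding_configs import EmbOptimType

    # Both optimizers carry two momentum states; for ADAM each is D elements
    # wide, while PARTIAL_ROWWISE_ADAM keeps momentum2 as one value per row.
    assert EmbOptimType.ADAM.state_names() == ["momentum1", "momentum2"]
    assert EmbOptimType.ADAM.state_size_table(D=128) == {"momentum1": 128, "momentum2": 128}
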
@@ -93,64 +111,84 @@ class EmbOptimType(enum.Enum):
             return {"momentum1": 1}
         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
             return {"momentum1": D, "momentum2": 1}
+        elif self == EmbOptimType.ADAM:
+            return {"momentum1": D, "momentum2": D}
         else:
             return {}
 
     def state_size_nbytes(
-        self, D: int, optimizer_state_dtypes: Dict[str, "SparseType"] = {}  # noqa: B006
+        self,
+        D: int,
+        optimizer_state_dtypes: dict[str, "SparseType"] = {},  # noqa: B006
     ) -> int:
         """
         Returns the size of the data (in bytes) required to hold the optimizer
-        state (per table row)
+        state (per table row). This size includes byte-padding.
         """
-        return sum(
-            [
-                # For each state, multiply the number of elements by the byte
-                # size of each element
-                (self._extract_dtype(optimizer_state_dtypes, name).itemsize * elem)
-                for name, elem in self.state_size_table(D).items()
-            ]
-        )
+        momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+        momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            return momentum1_dtype.itemsize
+
+        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+            return pad4(1 * momentum2_dtype.itemsize) + D * momentum1_dtype.itemsize
+
+        elif self == EmbOptimType.ADAM:
+            return (D * momentum1_dtype.itemsize) + (D * momentum2_dtype.itemsize)
+
+        else:
+            return 0
 
     def byte_offsets_along_row(
         self,
         D: int,
         weights_precision: "SparseType",
-        optimizer_state_dtypes: Dict[str, "SparseType"] = {},  # noqa: B006
-    ) -> Dict[str, Tuple[int, int]]:
+        optimizer_state_dtypes: dict[str, "SparseType"] = {},  # noqa: B006
+    ) -> dict[str, tuple[int, int]]:
         """
         Returns the start and end byte offsets of each optimizer state along a
         cache row with optimizer state offloading enabled.
         """
+        # Extract the optimizer state dtypes
+        momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+        momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
 
         # This is the pointer to where the optimizer state begins in the memory
         p0 = pad4(D) * weights_precision.as_dtype().itemsize
 
         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
-            momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
-            # Store one value for momentum per row
             return {"momentum1": (p0, p0 + momentum1_dtype.itemsize)}
 
         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
-            momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
-            momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+            # momentum1 lies after momentum2
+            p1 = p0 + pad4(1 * momentum2_dtype.itemsize)
             return {
                 "momentum2": (p0, p0 + momentum2_dtype.itemsize),
                 "momentum1": (
-                    p0 + momentum2_dtype.itemsize,
-                    p0 + momentum2_dtype.itemsize + D * momentum1_dtype.itemsize,
+                    p1,
+                    p1 + D * momentum1_dtype.itemsize,
                 ),
             }
 
+        elif self == EmbOptimType.ADAM:
+            # momentum2 lies after momentum1
+            p1 = p0 + (D * momentum1_dtype.itemsize)
+
+            return {
+                "momentum1": (p0, p1),
+                "momentum2": (p1, p1 + D * momentum2_dtype.itemsize),
+            }
+
         else:
             return {}
 
     def empty_states(
         self,
-        rows: List[int],
-        dims: List[int],
-        optimizer_state_dtypes: Dict[str, "SparseType"] = {},  # noqa: B006
-    ) -> List[List[torch.Tensor]]:
+        rows: list[int],
+        dims: list[int],
+        optimizer_state_dtypes: dict[str, "SparseType"] = {},  # noqa: B006
+    ) -> list[list[torch.Tensor]]:
         """
         Creates sets of empty tensors per table to hold optimizer states based
         on the specified optimizer type, state dtypes, embedding specs, and
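
The offset arithmetic is easiest to verify with concrete numbers. A standalone sketch of the PARTIAL_ROWWISE_ADAM row layout above, using illustrative values (D = 128, fp32 weights and optimizer states):

    def pad4(value: int) -> int:
        return (int(value) + 3) & ~3

    D = 128                        # embedding dimension (illustrative)
    weight_itemsize = 4            # fp32 weights
    m1_itemsize = m2_itemsize = 4  # fp32 momentum1 / momentum2

    # Optimizer state starts after the (4-element-padded) weights.
    p0 = pad4(D) * weight_itemsize          # 512
    # momentum2 (one scalar per row) comes first, then momentum1 (D values).
    momentum2 = (p0, p0 + m2_itemsize)      # (512, 516)
    p1 = p0 + pad4(1 * m2_itemsize)         # 516: momentum1 lies after momentum2
    momentum1 = (p1, p1 + D * m1_itemsize)  # (516, 1028)

    assert momentum2 == (512, 516) and momentum1 == (516, 1028)
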
@@ -159,7 +197,7 @@ class EmbOptimType(enum.Enum):
         # Else, check that the local row count for each table is set
         assert len(rows) == len(dims)
 
-        opt_states_set: List[List[torch.Tensor]] = []
+        opt_states_set: list[list[torch.Tensor]] = []
 
         for r, D in zip(rows, dims):
             # Set up the table of state names to state sizes, ordered by their
@@ -186,20 +224,20 @@ class EmbOptimType(enum.Enum):
 
     def ssd_state_splits(
         self,
-        embedding_specs: List[Tuple[int, int]],  # Tuple of (rows, dims)
-        optimizer_state_dtypes: Dict[str, "SparseType"] = {},  # noqa: B006
+        embedding_specs: list[tuple[int, int]],  # Tuple of (rows, dims)
+        optimizer_state_dtypes: dict[str, "SparseType"] = {},  # noqa: B006
         enable_optimizer_offloading: bool = False,
-    ) -> List[Tuple[SplitState, str, torch.dtype]]:
+    ) -> list[tuple[SplitState, str, torch.dtype]]:
         """
         Returns the split planning for the optimizer states
         """
-        (rows, _) = zip(*embedding_specs)
+        rows, _ = zip(*embedding_specs)
         T_ = len(embedding_specs)
 
         # This is the cumulative row counts for rowwise states
-        row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
         # This is the cumulative element counts for elementwise states
-        table_size_cumsum: List[int] = [0] + list(
+        table_size_cumsum: list[int] = [0] + list(
             itertools.accumulate([r * d for r, d in embedding_specs])
         )
 
@@ -207,6 +245,12 @@ class EmbOptimType(enum.Enum):
             params = {"momentum1": row_count_cumsum}
         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
             params = {"momentum1": table_size_cumsum, "momentum2": row_count_cumsum}
+        elif self == EmbOptimType.ADAM:
+            params = {
+                "momentum1": table_size_cumsum,
+                "momentum2": table_size_cumsum,
+                "row_counter": row_count_cumsum,
+            }
         else:
             params = {}
 
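
The two cumulative sums determine where each table's slice of a row-wise versus element-wise state begins. A worked standalone example with two illustrative tables:

    import itertools

    # (rows, dims) per table; values are illustrative.
    embedding_specs = [(10, 4), (20, 8)]
    rows, _ = zip(*embedding_specs)

    # Row-wise states (e.g. PARTIAL_ROWWISE_ADAM's momentum2): one element per row.
    row_count_cumsum = [0] + list(itertools.accumulate(rows))
    # Element-wise states (e.g. momentum1): rows * dims elements per table.
    table_size_cumsum = [0] + list(
        itertools.accumulate(r * d for r, d in embedding_specs)
    )

    assert row_count_cumsum == [0, 10, 30]
    assert table_size_cumsum == [0, 40, 200]  # 10*4 = 40, then + 20*8 = 200
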
@@ -266,14 +310,54 @@ def sparse_type_to_int(sparse_type: "SparseType") -> int:
         SparseType.BF16.value: 5,
         SparseType.FP8.value: 6,
         SparseType.MX4.value: 7,
+        SparseType.NFP8.value: 8,
     }[sparse_type.value]
 
 
+def sparse_type_int_to_dtype(ty: int) -> torch.dtype:
+    """
+    TorchScript-compatible function to convert a SparseType (enum as integer) to torch.dtype.
+
+    This is a standalone function equivalent to SparseType.from_int(dtype_int).as_dtype() that works
+    with TorchScript. TorchScript does not support @staticmethod on Enum classes,
+    so this function provides a workaround.
+    """
+    if ty == 0:  # fp32
+        return torch.float32
+    elif ty == 1:  # fp16
+        return torch.float16
+    elif ty == 2:  # int8
+        return torch.uint8
+    elif ty == 3:  # int4
+        return torch.quint4x2
+    elif ty == 4:  # int2
+        return torch.quint2x4
+    elif ty == 5:  # bf16
+        return torch.bfloat16
+    elif ty == 6:  # fp8
+        return torch.uint8
+    elif ty == 7:  # mx4
+        return torch.uint8
+    elif ty == 9:
+        return (
+            torch.float8_e4m3fnuz
+            if torch.version.hip is not None
+            else torch.float8_e4m3fn
+        )
+    else:  # Invalid is 7 or non enumerated.
+        raise ValueError(f"Unsupported sparse type: {ty}")
+
+
 @enum.unique
 class SparseType(enum.Enum):
     FP32 = "fp32"
     FP16 = "fp16"
     FP8 = "fp8"
+    # NFP8 refers to "native" FP8 in that it uses the GPU implementations
+    # of E4M3 whereas the other FP8 sparsetype uses a custom format. Use of
+    # NFP8 allows us to use hardware casting intrinsics which can be much faster.
+    # Eventually, we should merge these two types.
+    NFP8 = "nfp8"
     INT8 = "int8"
     INT4 = "int4"
     INT2 = "int2"
@@ -299,9 +383,11 @@ class SparseType(enum.Enum):
             return SparseType("bf16")
         elif ty == 6:
             return SparseType("fp8")
-        elif ty == 7:
+        elif ty == 8:
             return SparseType("mx4")
-        else:
+        elif ty == 9:
+            return SparseType("nfp8")
+        else:  # Invalid is 7 or non enumerated.
             raise ValueError(f"Unsupported sparse type: {ty}")
 
     def as_int(self) -> int:
@@ -323,6 +409,8 @@ class SparseType(enum.Enum):
             return SparseType("bf16")
         elif dtype == torch.uint8:
             return SparseType("mx4")
+        elif dtype == torch.float8_e4m3fnuz or dtype == torch.float8_e4m3fn:
+            return SparseType("nfp8")
         else:
             raise ValueError(f"Unsupported sparse dtype: {dtype}")
 
@@ -336,6 +424,11 @@ class SparseType(enum.Enum):
             SparseType.INT2.value: torch.quint2x4,
             SparseType.BF16.value: torch.bfloat16,
             SparseType.MX4.value: torch.uint8,
+            SparseType.NFP8.value: (
+                torch.float8_e4m3fnuz
+                if torch.version.hip is not None
+                else torch.float8_e4m3fn
+            ),
         }[self.value]
 
     def bit_rate(self) -> int:
@@ -348,6 +441,7 @@ class SparseType(enum.Enum):
             SparseType.INT2.value: 2,
             SparseType.BF16.value: 16,
             SparseType.MX4.value: 4,
+            SparseType.NFP8.value: 8,
         }[self.value]
 
     def align_size(self) -> int:
@@ -360,6 +454,7 @@ class SparseType(enum.Enum):
             SparseType.INT2.value: 16,
             SparseType.BF16.value: 2,
             SparseType.MX4.value: 8,
+            SparseType.NFP8.value: 4,
         }[self.value]
 
     def is_float(self) -> bool:
@@ -368,6 +463,7 @@ class SparseType(enum.Enum):
             or self.value == SparseType.FP16.value
             or self.value == SparseType.FP8.value
             or self.value == SparseType.BF16.value
+            or self.value == SparseType.NFP8.value
         ):
             return True
         else:
@@ -380,11 +476,12 @@ class SparseType(enum.Enum):
         return QuantizationConfig()
 
 
-ELEMENT_SIZE: Dict[SparseType, int] = {
+ELEMENT_SIZE: dict[SparseType, int] = {
     SparseType.FP32: 4,
     SparseType.FP16: 2,
     SparseType.FP8: 1,
     SparseType.INT8: 1,
     SparseType.BF16: 2,
+    SparseType.NFP8: 1,
     # SparseType.INT4: 0.5,
 }
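
Taken together, the NFP8 additions across these hunks are mutually consistent: an 8-bit float stored in one byte, aligned to 4 elements, and mapped to the hardware e4m3 dtype. A round-trip sketch, assuming the nightly wheel is installed and the PyTorch build has float8 support:

    import torch
    from fbgemm_gpu.split_embedding_configs import ELEMENT_SIZE, SparseType

    nfp8 = SparseType.from_dtype(torch.float8_e4m3fn)  # *_fnuz on ROCm builds
    assert nfp8 == SparseType.NFP8
    assert nfp8.bit_rate() == 8                  # 8 bits per element
    assert nfp8.align_size() == 4                # rows aligned to 4 elements
    assert ELEMENT_SIZE[SparseType.NFP8] == 1    # one byte per element
    assert nfp8.is_float()
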
fbgemm_gpu/split_embedding_inference_converter.py
@@ -10,10 +10,11 @@
 
 import logging
 import math
-from typing import cast, Optional, Tuple
+from typing import cast, Optional
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import (
     FP8QuantizationConfig,
     QuantizationConfig,
@@ -53,7 +54,7 @@ class SplitEmbInferenceConverter:
         return model
 
     # pyre-fixme[2]: Parameter must be annotated.
-    def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> Tuple[Tensor, float]:
+    def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> tuple[Tensor, float]:
         assert new_num_rows > 0
         from numpy.linalg import norm
 
@@ -75,7 +76,7 @@ class SplitEmbInferenceConverter:
         idx: int,
         num_rows: int,
         module: SplitTableBatchedEmbeddingBagsCodegen,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> tuple[Tensor, Optional[Tensor]]:
         # TODO(yingz): Avoid DtoH / HtoD overhead.
         weights = module.split_embedding_weights()[idx].cpu()
         if self.pruning_ratio is None:
@@ -84,7 +85,7 @@ class SplitEmbInferenceConverter:
         if new_num_rows == num_rows:
             return (weights, None)
 
-        (indicators, threshold) = self._prune_by_weights_l2_norm(new_num_rows, weights)
+        indicators, threshold = self._prune_by_weights_l2_norm(new_num_rows, weights)
 
         return torch.ops.fbgemm.embedding_bag_rowwise_prune(
             weights, indicators, threshold, torch.int32
@@ -100,7 +101,7 @@ class SplitEmbInferenceConverter:
 
     def _quantize_embs(
         self, weight: Tensor, weight_ty: SparseType
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> tuple[Tensor, Optional[Tensor]]:
         fp8_quant_config = cast(FP8QuantizationConfig, self.quantization_config)
         return quantize_embs(weight, weight_ty, fp8_quant_config)
 
@@ -129,7 +130,7 @@ class SplitEmbInferenceConverter:
         index_remapping_list = []
         for t, (_, E, D, weight_ty) in enumerate(embedding_specs):
             # Try to prune embeddings.
-            (pruned_weight, index_remapping) = self._prune_embs(t, E, child)
+            pruned_weight, index_remapping = self._prune_embs(t, E, child)
             new_embedding_specs.append(
                 (
                     "",
fbgemm_gpu/split_table_batched_embeddings_ops_common.py
@@ -11,12 +11,11 @@
 
 import enum
 from dataclasses import dataclass
-from typing import List, NamedTuple, Optional, Tuple
+from typing import FrozenSet, NamedTuple, Optional, Tuple
 
 import torch
 from torch import Tensor
 
-
 # Maximum number of times prefetch() can be called without
 # a corresponding forward() call
 MAX_PREFETCH_DEPTH = 100
@@ -62,10 +61,10 @@ class EmbeddingLocation(enum.IntEnum):
 
 class EvictionPolicy(NamedTuple):
     eviction_trigger_mode: int = (
-        0  # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
+        0  # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual 4: id count
     )
     eviction_strategy: int = (
-        0  # 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
+        0  # 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm 4: timestamp threshold 5: feature score
     )
     eviction_step_intervals: Optional[int] = (
         None  # trigger_step_interval if trigger mode is iteration
@@ -73,18 +72,33 @@ class EvictionPolicy(NamedTuple):
     eviction_mem_threshold_gb: Optional[int] = (
         None  # eviction trigger condition if trigger mode is mem_util
     )
-    counter_thresholds: Optional[List[int]] = (
-        None  # count_thresholds for each table if eviction strategy is feature score
+    counter_thresholds: Optional[list[int]] = (
+        None  # count_thresholds for each table if eviction strategy is counter
     )
-    ttls_in_mins: Optional[List[int]] = (
+    ttls_in_mins: Optional[list[int]] = (
         None  # ttls_in_mins for each table if eviction strategy is timestamp
     )
-    counter_decay_rates: Optional[List[float]] = (
-        None  # count_decay_rates for each table if eviction strategy is feature score
+    counter_decay_rates: Optional[list[float]] = (
+        None  # count_decay_rates for each table if eviction strategy is counter
+    )
+    feature_score_counter_decay_rates: Optional[list[float]] = (
+        None  # feature_score_counter_decay_rates for each table if eviction strategy is feature score
+    )
+    training_id_eviction_trigger_count: Optional[list[int]] = (
+        None  # Number of training IDs that, when exceeded, will trigger eviction for each table.
     )
-    l2_weight_thresholds: Optional[List[float]] = (
+    training_id_keep_count: Optional[list[int]] = (
+        None  # Target number of training IDs to retain in each table after eviction.
+    )
+    l2_weight_thresholds: Optional[list[float]] = (
         None  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
     )
+    threshold_calculation_bucket_stride: Optional[float] = (
+        0.2  # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
+    )
+    threshold_calculation_bucket_num: Optional[int] = (
+        1000000  # 1M, Total number of feature score buckets used for threshold calculation in feature score-based eviction.
+    )
     interval_for_insufficient_eviction_s: int = (
         # wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
         # insufficient means we didn't evict enough rows, so we want to wait longer time to
@@ -95,18 +109,30 @@ class EvictionPolicy(NamedTuple):
         # wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
         60
     )
-    meta_header_lens: Optional[List[int]] = None  # metaheader length for each table
+    interval_for_feature_statistics_decay_s: int = (
+        24 * 3600  # 1 day, interval for feature statistics decay
+    )
+    meta_header_lens: Optional[list[int]] = None  # metaheader length for each table
+    eviction_free_mem_threshold_gb: Optional[int] = (
+        None  # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
+    )
+    eviction_free_mem_check_interval_batch: Optional[int] = (
+        None  # Number of batches between checks for free memory threshold when using free_mem trigger mode.
+    )
+    enable_eviction_for_feature_score_eviction_policy: Optional[list[bool]] = (
+        None  # enable eviction if eviction policy is feature score, false means no eviction
+    )
 
     def validate(self) -> None:
-        assert self.eviction_trigger_mode in [0, 1, 2, 3], (
-            "eviction_trigger_mode must be 0, 1, 2, or 3, "
+        assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
+            "eviction_trigger_mode must be 0, 1, 2, 3, 4, 5, "
             f"actual {self.eviction_trigger_mode}"
         )
         if self.eviction_trigger_mode == 0:
             return
 
-        assert self.eviction_strategy in [0, 1, 2, 3], (
-            "eviction_strategy must be 0, 1, 2, or 3, "
+        assert self.eviction_strategy in [0, 1, 2, 3, 4, 5], (
+            "eviction_strategy must be 0, 1, 2, 3, 4 or 5, "
             f"actual {self.eviction_strategy}"
         )
         if self.eviction_trigger_mode == 1:
@@ -121,6 +147,17 @@ class EvictionPolicy(NamedTuple):
             assert (
                 self.eviction_mem_threshold_gb is not None
             ), "eviction_mem_threshold_gb must be set if eviction_trigger_mode is 2"
+        elif self.eviction_trigger_mode == 4:
+            assert (
+                self.training_id_eviction_trigger_count is not None
+            ), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4"
+        elif self.eviction_trigger_mode == 5:
+            assert (
+                self.eviction_free_mem_threshold_gb is not None
+            ), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5"
+            assert (
+                self.eviction_free_mem_check_interval_batch is not None
+            ), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5"
 
         if self.eviction_strategy == 0:
             assert self.ttls_in_mins is not None, (
@@ -161,21 +198,58 @@ class EvictionPolicy(NamedTuple):
                 "counter_thresholds and ttls_in_mins must have the same length, "
                 f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
             )
+        elif self.eviction_strategy == 5:
+            assert self.feature_score_counter_decay_rates is not None, (
+                "feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
+                f"actual {self.feature_score_counter_decay_rates}"
+            )
+            assert self.training_id_eviction_trigger_count is not None, (
+                "training_id_eviction_trigger_count must be set if eviction_strategy is 5, "
+                f"actual {self.training_id_eviction_trigger_count}"
+            )
+            assert self.training_id_keep_count is not None, (
+                "training_id_keep_count must be set if eviction_strategy is 5, "
+                f"actual {self.training_id_keep_count}"
+            )
+            assert self.threshold_calculation_bucket_stride is not None, (
+                "threshold_calculation_bucket_stride must be set if eviction_strategy is 5, "
+                f"actual {self.threshold_calculation_bucket_stride}"
+            )
+            assert self.threshold_calculation_bucket_num is not None, (
+                "threshold_calculation_bucket_num must be set if eviction_strategy is 5, "
+                f"actual {self.threshold_calculation_bucket_num}"
+            )
+            assert self.enable_eviction_for_feature_score_eviction_policy is not None, (
+                "enable_eviction_for_feature_score_eviction_policy must be set if eviction_strategy is 5, "
+                f"actual {self.enable_eviction_for_feature_score_eviction_policy}"
+            )
+            assert (
+                len(self.enable_eviction_for_feature_score_eviction_policy)
+                == len(self.training_id_keep_count)
+                == len(self.feature_score_counter_decay_rates)
+            ), (
+                "feature_score_counter_decay_rates, enable_eviction_for_feature_score_eviction_policy, and training_id_keep_count must have the same length, "
+                f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.enable_eviction_for_feature_score_eviction_policy}"
+            )
 
 
 class KVZCHParams(NamedTuple):
     # global bucket id start and global bucket id end offsets for each logical table,
     # where start offset is inclusive and end offset is exclusive
-    bucket_offsets: List[Tuple[int, int]] = []
+    bucket_offsets: list[tuple[int, int]] = []
     # bucket size for each logical table
     # the value indicates corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets
-    bucket_sizes: List[int] = []
+    bucket_sizes: list[int] = []
     # enable optimizer offloading or not
     enable_optimizer_offloading: bool = False
     # when enabled, backend will return whole row (metaheader + weight + optimizer) instead of weight only
     # can only be enabled when enable_optimizer_offloading is enabled
     backend_return_whole_row: bool = False
     eviction_policy: EvictionPolicy = EvictionPolicy()
+    embedding_cache_mode: bool = False
+    load_ckpt_without_opt: bool = False
+    optimizer_type_for_st: Optional[str] = None
+    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
 
     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (
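
Putting the new fields together: a policy that triggers on training-ID count (mode 4) and evicts by feature score (strategy 5) must populate every per-table list that validate() checks. A sketch for a hypothetical two-table model, assuming the fields and defaults shown in the hunks above:

    from fbgemm_gpu.split_table_batched_embeddings_ops_common import EvictionPolicy

    policy = EvictionPolicy(
        eviction_trigger_mode=4,   # trigger on training-ID count
        eviction_strategy=5,       # feature-score eviction
        feature_score_counter_decay_rates=[0.99, 0.95],
        training_id_eviction_trigger_count=[1_000_000, 2_000_000],
        training_id_keep_count=[800_000, 1_600_000],
        enable_eviction_for_feature_score_eviction_policy=[True, True],
        # threshold_calculation_bucket_stride / _num keep their defaults.
    )
    policy.validate()  # all strategy-5 lists are set and have equal length
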
@@ -188,6 +262,25 @@ class KVZCHParams(NamedTuple):
         ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"
 
 
+class KVZCHTBEConfig(NamedTuple):
+    # Eviction trigger mode for kvzch table: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
+    kvzch_eviction_trigger_mode: int = 2  # mem_util
+    # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
+    eviction_free_mem_threshold_gb: int = 200  # 200GB
+    # Number of batches between checks for free memory threshold when using free_mem trigger mode.
+    eviction_free_mem_check_interval_batch: int = 1000
+    # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
+    threshold_calculation_bucket_stride: float = 0.2
+    # Total number of feature score buckets used for threshold calculation in feature score-based eviction.
+    threshold_calculation_bucket_num: Optional[int] = 1000000  # 1M
+    # When true, we only save weight to kvzch backend and not optimizer state.
+    load_ckpt_without_opt: bool = False
+    # [DO NOT USE] This is for st publish only, do not set it in your config
+    optimizer_type_for_st: Optional[str] = None
+    # [DO NOT USE] This is for st publish only, do not set it in your config
+    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
+
+
 class BackendType(enum.IntEnum):
     SSD = 0
     DRAM = 1
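
Since every KVZCHTBEConfig field has a default, the config can be constructed empty and selectively overridden; the free-mem knobs only take effect under trigger mode 5. A small sketch, assuming the class is exported as shown above:

    from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHTBEConfig

    cfg = KVZCHTBEConfig()                        # all defaults
    assert cfg.kvzch_eviction_trigger_mode == 2   # mem_util
    assert cfg.eviction_free_mem_threshold_gb == 200

    # Override for a free-mem-triggered job (mode 5); NamedTuple._replace
    # returns a new config rather than mutating in place.
    cfg5 = cfg._replace(
        kvzch_eviction_trigger_mode=5,
        eviction_free_mem_check_interval_batch=500,
    )
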
@@ -288,8 +381,8 @@ SplitState: NamedTuple = NamedTuple(
         ("dev_size", int),
         ("host_size", int),
         ("uvm_size", int),
-        ("placements", List[EmbeddingLocation]),
-        ("offsets", List[int]),
+        ("placements", list[EmbeddingLocation]),
+        ("offsets", list[int]),
     ],
 )
 
@@ -297,15 +390,15 @@ SplitState: NamedTuple = NamedTuple(
 @dataclass
 class CacheState:
     # T + 1 elements and cache_hash_size_cumsum[-1] == total_cache_hash_size
-    cache_hash_size_cumsum: List[int]
-    cache_index_table_map: List[int]
+    cache_hash_size_cumsum: list[int]
+    cache_index_table_map: list[int]
     total_cache_hash_size: int
 
 
 def construct_cache_state(
-    row_list: List[int],
-    location_list: List[EmbeddingLocation],
-    feature_table_map: List[int],
+    row_list: list[int],
+    location_list: list[EmbeddingLocation],
+    feature_table_map: list[int],
 ) -> CacheState:
     _cache_hash_size_cumsum = [0]
     total_cache_hash_size = 0
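
The bookkeeping in construct_cache_state (truncated in this diff) follows the invariant stated on CacheState: for T cached tables, the cumulative sum has T + 1 entries and its last entry equals total_cache_hash_size. A standalone illustration with made-up row counts:

    # Illustrative only: mirrors the cumsum invariant documented on CacheState.
    row_list = [100, 50, 200]            # rows per cached table (T = 3)
    cache_hash_size_cumsum = [0]
    for rows in row_list:
        cache_hash_size_cumsum.append(cache_hash_size_cumsum[-1] + rows)

    assert len(cache_hash_size_cumsum) == len(row_list) + 1
    assert cache_hash_size_cumsum[-1] == sum(row_list)  # total_cache_hash_size
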