fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -8,10 +8,51 @@
  # pyre-strict

  import enum
+ import itertools
  from typing import Any, Dict # noqa: F401

  import torch

+ # fmt:skip
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+     EmbeddingLocation,
+     SplitState,
+ )
+
+
+ def pad4(value: int) -> int:
+     """
+     Compute the smallest multiple of 4 that is greater than or equal to the given value.
+
+     Parameters:
+         value (int): The integer to align (must be non-negative).
+
+     Returns:
+         int: The aligned value.
+
+     Raises:
+         ValueError: If the input is negative.
+         TypeError: If the input is not an integer.
+     """
+     return (int(value) + 3) & ~3
+
+
+ def pad16(value: int) -> int:
+     """
+     Compute the smallest multiple of 16 that is greater than or equal to the given value.
+
+     Parameters:
+         value (int): The integer to align (must be non-negative).
+
+     Returns:
+         int: The aligned value.
+
+     Raises:
+         ValueError: If the input is negative.
+         TypeError: If the input is not an integer.
+     """
+     return (int(value) + 15) & ~15
+

  @enum.unique
  class EmbOptimType(enum.Enum):
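The two helpers above are bit-mask round-up functions: adding 3 (or 15) and clearing the low bits is equivalent to ceil(value / 4) * 4 (or ceil(value / 16) * 16). A minimal sketch of the expected behavior, assuming the helpers are importable from fbgemm_gpu.split_embedding_configs as this hunk suggests:

# Sketch: round-up-to-multiple behavior of the new alignment helpers.
from fbgemm_gpu.split_embedding_configs import pad4, pad16  # assumed import path

assert pad4(0) == 0      # zero stays zero
assert pad4(10) == 12    # rounds up to the next multiple of 4
assert pad4(12) == 12    # multiples of 4 are unchanged
assert pad16(3) == 16    # rounds up to the next multiple of 16
assert pad16(33) == 48   # 33 -> 48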
@@ -40,6 +81,196 @@ class EmbOptimType(enum.Enum):
      def __str__(self) -> str:
          return self.value

+     def _extract_dtype(
+         self, optimizer_state_dtypes: dict[str, "SparseType"], name: str
+     ) -> torch.dtype:
+         if optimizer_state_dtypes is None or name not in optimizer_state_dtypes:
+             return torch.float32
+         return optimizer_state_dtypes[name].as_dtype()
+
+     def state_names(self) -> list[str]:
+         """
+         Returns the names of the optimizer states. The order of the states will
+         be the order in which they are processed and returned in
+         SSDTableBatchedEmbeddingBags.split_optimizer_states(), but this is not
+         necessarily the same as the order they are stored in the memory layout.
+         """
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             return ["momentum1"]
+         elif self in [EmbOptimType.PARTIAL_ROWWISE_ADAM, EmbOptimType.ADAM]:
+             return ["momentum1", "momentum2"]
+         else:
+             return []
+
+     def state_size_table(self, D: int) -> dict[str, int]:
+         """
+         Returns the table of state names to state sizes in terms of number of
+         elements (per table row)
+         """
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             return {"momentum1": 1}
+         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+             return {"momentum1": D, "momentum2": 1}
+         elif self == EmbOptimType.ADAM:
+             return {"momentum1": D, "momentum2": D}
+         else:
+             return {}
+
+     def state_size_nbytes(
+         self,
+         D: int,
+         optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+     ) -> int:
+         """
+         Returns the size of the data (in bytes) required to hold the optimizer
+         state (per table row). This size includes byte-padding.
+         """
+         momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+         momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             return momentum1_dtype.itemsize
+
+         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+             return pad4(1 * momentum2_dtype.itemsize) + D * momentum1_dtype.itemsize
+
+         elif self == EmbOptimType.ADAM:
+             return (D * momentum1_dtype.itemsize) + (D * momentum2_dtype.itemsize)
+
+         else:
+             return 0
+
+     def byte_offsets_along_row(
+         self,
+         D: int,
+         weights_precision: "SparseType",
+         optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+     ) -> dict[str, tuple[int, int]]:
+         """
+         Returns the start and end byte offsets of each optimizer state along a
+         cache row with optimizer state offloading enabled.
+         """
+         # Extract the optimizer state dtypes
+         momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+         momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+
+         # This is the pointer to where the optimizer state begins in the memory
+         p0 = pad4(D) * weights_precision.as_dtype().itemsize
+
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             return {"momentum1": (p0, p0 + momentum1_dtype.itemsize)}
+
+         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+             # momentum1 lies after momentum2
+             p1 = p0 + pad4(1 * momentum2_dtype.itemsize)
+             return {
+                 "momentum2": (p0, p0 + momentum2_dtype.itemsize),
+                 "momentum1": (
+                     p1,
+                     p1 + D * momentum1_dtype.itemsize,
+                 ),
+             }
+
+         elif self == EmbOptimType.ADAM:
+             # momentum2 lies after momentum1
+             p1 = p0 + (D * momentum1_dtype.itemsize)
+
+             return {
+                 "momentum1": (p0, p1),
+                 "momentum2": (p1, p1 + D * momentum2_dtype.itemsize),
+             }
+
+         else:
+             return {}
+
+     def empty_states(
+         self,
+         rows: list[int],
+         dims: list[int],
+         optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+     ) -> list[list[torch.Tensor]]:
+         """
+         Creates sets of empty tensors per table to hold optimizer states based
+         on the specified optimizer type, state dtypes, embedding specs, and
+         (optionally) local row counts.
+         """
+         # Else, check that the local row count for each table is set
+         assert len(rows) == len(dims)
+
+         opt_states_set: list[list[torch.Tensor]] = []
+
+         for r, D in zip(rows, dims):
+             # Set up the table of state names to state sizes, ordered by their
+             # memory layout
+             state_size_table = self.state_size_table(D)
+             ordered_state_sizes = [(k, state_size_table[k]) for k in self.state_names()]
+
+             # Create the optimizer states for this table
+             opt_states_set.append(
+                 [
+                     torch.empty(
+                         # If the state size is 1, then fix tensor to 1D to be
+                         # consistent with training.py code
+                         # pyre-ignore [6]
+                         (r, d) if d > 1 else r,
+                         dtype=self._extract_dtype(optimizer_state_dtypes, state_name),
+                         device="cpu",
+                     )
+                     for state_name, d in ordered_state_sizes
+                 ]
+             )
+
+         return opt_states_set
+
+     def ssd_state_splits(
+         self,
+         embedding_specs: list[tuple[int, int]], # Tuple of (rows, dims)
+         optimizer_state_dtypes: dict[str, "SparseType"] = {}, # noqa: B006
+         enable_optimizer_offloading: bool = False,
+     ) -> list[tuple[SplitState, str, torch.dtype]]:
+         """
+         Returns the split planning for the optimizer states
+         """
+         rows, _ = zip(*embedding_specs)
+         T_ = len(embedding_specs)
+
+         # This is the cumulative row counts for rowwise states
+         row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
+         # This is the cumulative element counts for elementwise states
+         table_size_cumsum: list[int] = [0] + list(
+             itertools.accumulate([r * d for r, d in embedding_specs])
+         )
+
+         if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+             params = {"momentum1": row_count_cumsum}
+         elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+             params = {"momentum1": table_size_cumsum, "momentum2": row_count_cumsum}
+         elif self == EmbOptimType.ADAM:
+             params = {
+                 "momentum1": table_size_cumsum,
+                 "momentum2": table_size_cumsum,
+                 "row_counter": row_count_cumsum,
+             }
+         else:
+             params = {}
+
+         return [
+             (
+                 SplitState(
+                     dev_size=(
+                         cumsum_table[-1] if not enable_optimizer_offloading else 0
+                     ),
+                     host_size=0,
+                     uvm_size=0,
+                     placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
+                     offsets=cumsum_table[:-1],
+                 ),
+                 name,
+                 self._extract_dtype(optimizer_state_dtypes, name),
+             )
+             for (name, cumsum_table) in params.items()
+         ]
+

  # Base class for quantization configuration (in case other numeric types have
  # configs)
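The methods above define a per-row memory layout for optimizer-state offloading: the weight payload is padded to a 4-element boundary, and the optimizer states follow in their layout order. A hedged walk-through for PARTIAL_ROWWISE_ADAM with a 128-dim table, FP16 weights, and the default FP32 optimizer-state dtypes (illustrative values only; the import path is assumed):

# Sketch: per-row byte layout with optimizer-state offloading enabled.
from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType  # assumed import path

opt, D = EmbOptimType.PARTIAL_ROWWISE_ADAM, 128

# Weights occupy pad4(128) * 2 bytes (FP16) = 256 bytes, so optimizer state starts at byte 256.
offsets = opt.byte_offsets_along_row(D, SparseType.FP16)
assert offsets["momentum2"] == (256, 260)  # one FP32 scalar, padded to a 4-byte boundary
assert offsets["momentum1"] == (260, 772)  # D = 128 FP32 elements follow momentum2

# Optimizer payload per row: pad4(1 * 4) + 128 * 4 = 516 bytes.
assert opt.state_size_nbytes(D) == 516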
@@ -79,14 +310,54 @@ def sparse_type_to_int(sparse_type: "SparseType") -> int:
          SparseType.BF16.value: 5,
          SparseType.FP8.value: 6,
          SparseType.MX4.value: 7,
+         SparseType.NFP8.value: 8,
      }[sparse_type.value]


+ def sparse_type_int_to_dtype(ty: int) -> torch.dtype:
+     """
+     TorchScript-compatible function to convert a SparseType enum (as an integer) to torch.dtype.
+
+     This is a standalone function equivalent to SparseType.from_int(dtype_int).as_dtype() that works
+     with TorchScript. TorchScript does not support @staticmethod on Enum classes,
+     so this function provides a workaround.
+     """
+     if ty == 0: # fp32
+         return torch.float32
+     elif ty == 1: # fp16
+         return torch.float16
+     elif ty == 2: # int8
+         return torch.uint8
+     elif ty == 3: # int4
+         return torch.quint4x2
+     elif ty == 4: # int2
+         return torch.quint2x4
+     elif ty == 5: # bf16
+         return torch.bfloat16
+     elif ty == 6: # fp8
+         return torch.uint8
+     elif ty == 7: # mx4
+         return torch.uint8
+     elif ty == 9:
+         return (
+             torch.float8_e4m3fnuz
+             if torch.version.hip is not None
+             else torch.float8_e4m3fn
+         )
+     else: # Invalid is 7 or non enumerated.
+         raise ValueError(f"Unsupported sparse type: {ty}")
+
+
  @enum.unique
  class SparseType(enum.Enum):
      FP32 = "fp32"
      FP16 = "fp16"
      FP8 = "fp8"
+     # NFP8 refers to "native" FP8 in that it uses the GPU implementations
+     # of E4M3 whereas the other FP8 sparsetype uses a custom format. Use of
+     # NFP8 allows us to use hardware casting intrinsics which can be much faster.
+     # Eventually, we should merge these two types.
+     NFP8 = "nfp8"
      INT8 = "int8"
      INT4 = "int4"
      INT2 = "int2"
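The hunk above adds NFP8 ("native" FP8) and a TorchScript-friendly sparse_type_int_to_dtype helper, since TorchScript cannot call @staticmethod members on Enum classes. A small hedged usage sketch (import path assumed; the integer codes come from the tables in this hunk):

# Sketch: resolving sparse-type integer codes to torch dtypes.
import torch
from fbgemm_gpu.split_embedding_configs import (  # assumed import path
    SparseType,
    sparse_type_int_to_dtype,
    sparse_type_to_int,
)

assert sparse_type_to_int(SparseType.NFP8) == 8
assert sparse_type_int_to_dtype(0) == torch.float32  # fp32
assert sparse_type_int_to_dtype(6) == torch.uint8    # the custom FP8 format is byte-packed
# Code 9 resolves to the native FP8 dtype, which differs between ROCm and CUDA/CPU builds.
native_fp8 = torch.float8_e4m3fnuz if torch.version.hip is not None else torch.float8_e4m3fn
assert sparse_type_int_to_dtype(9) == native_fp8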
@@ -112,9 +383,11 @@ class SparseType(enum.Enum):
              return SparseType("bf16")
          elif ty == 6:
              return SparseType("fp8")
-         elif ty == 7:
+         elif ty == 8:
              return SparseType("mx4")
-         else:
+         elif ty == 9:
+             return SparseType("nfp8")
+         else: # Invalid is 7 or non enumerated.
              raise ValueError(f"Unsupported sparse type: {ty}")

      def as_int(self) -> int:
@@ -136,6 +409,8 @@ class SparseType(enum.Enum):
              return SparseType("bf16")
          elif dtype == torch.uint8:
              return SparseType("mx4")
+         elif dtype == torch.float8_e4m3fnuz or dtype == torch.float8_e4m3fn:
+             return SparseType("nfp8")
          else:
              raise ValueError(f"Unsupported sparse dtype: {dtype}")

@@ -149,6 +424,11 @@
          SparseType.INT2.value: torch.quint2x4,
          SparseType.BF16.value: torch.bfloat16,
          SparseType.MX4.value: torch.uint8,
+         SparseType.NFP8.value: (
+             torch.float8_e4m3fnuz
+             if torch.version.hip is not None
+             else torch.float8_e4m3fn
+         ),
      }[self.value]

      def bit_rate(self) -> int:
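The two hunks above make the dtype mapping for the new type consistent in both directions: the dtype-to-enum branch recognizes both FP8 element types, and as_dtype() picks the platform-appropriate one. A hedged sketch (the method name from_dtype is assumed from context, since the hunk header does not show it):

# Sketch: NFP8 dtype round trip (platform dependent).
import torch
from fbgemm_gpu.split_embedding_configs import SparseType  # assumed import path

native_fp8 = torch.float8_e4m3fnuz if torch.version.hip is not None else torch.float8_e4m3fn
assert SparseType.from_dtype(native_fp8) == SparseType.NFP8  # assumed method name
assert SparseType.NFP8.as_dtype() == native_fp8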
@@ -161,6 +441,7 @@
          SparseType.INT2.value: 2,
          SparseType.BF16.value: 16,
          SparseType.MX4.value: 4,
+         SparseType.NFP8.value: 8,
      }[self.value]

      def align_size(self) -> int:
@@ -173,6 +454,7 @@
          SparseType.INT2.value: 16,
          SparseType.BF16.value: 2,
          SparseType.MX4.value: 8,
+         SparseType.NFP8.value: 4,
      }[self.value]

      def is_float(self) -> bool:
@@ -181,6 +463,7 @@
              or self.value == SparseType.FP16.value
              or self.value == SparseType.FP8.value
              or self.value == SparseType.BF16.value
+             or self.value == SparseType.NFP8.value
          ):
              return True
          else:
@@ -193,11 +476,12 @@
          return QuantizationConfig()


- ELEMENT_SIZE: Dict[SparseType, int] = {
+ ELEMENT_SIZE: dict[SparseType, int] = {
      SparseType.FP32: 4,
      SparseType.FP16: 2,
      SparseType.FP8: 1,
      SparseType.INT8: 1,
      SparseType.BF16: 2,
+     SparseType.NFP8: 1,
      # SparseType.INT4: 0.5,
  }
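Taken together, the NFP8 entries in the hunks above line up: a bit rate of 8 corresponds to 1 byte per element in ELEMENT_SIZE, and align_size() of 4 suggests embedding dimensions are expected to be padded to multiples of 4 elements. A brief hedged check (import path assumed):

# Sketch: relating bit_rate, ELEMENT_SIZE, and align_size for the new NFP8 type.
from fbgemm_gpu.split_embedding_configs import ELEMENT_SIZE, SparseType  # assumed import path

assert SparseType.NFP8.bit_rate() == 8
assert ELEMENT_SIZE[SparseType.NFP8] == 1  # 8 bits -> 1 byte per element

# With a 4-element alignment, a 130-dim NFP8 row would pad out to 132 elements (132 bytes of weights).
align = SparseType.NFP8.align_size()
padded_dim = (130 + align - 1) // align * align
assert padded_dim == 132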
@@ -10,10 +10,11 @@

  import logging
  import math
- from typing import cast, Optional, Tuple
+ from typing import cast, Optional

  import torch

+ # fmt:skip
  from fbgemm_gpu.split_embedding_configs import (
      FP8QuantizationConfig,
      QuantizationConfig,
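This import change goes hand in hand with the annotation updates in the hunks that follow (all within SplitEmbInferenceConverter): typing.Tuple is dropped in favor of the builtin generic tuple[...] syntax (PEP 585, Python 3.9+), and a few redundant parentheses around tuple unpacking are removed. A tiny hedged illustration of the target style; the function below is hypothetical:

# Sketch: builtin generics (PEP 585) replacing typing.Tuple in annotations.
import torch
from torch import Tensor

def prune_result(weights: Tensor, threshold: float) -> tuple[Tensor, float]:
    # Hypothetical helper, only to show `tuple[Tensor, float]` in place of
    # `typing.Tuple[Tensor, float]`; no typing import is needed on Python 3.9+.
    return weights, threshold

w, t = prune_result(torch.zeros(4), 0.5)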
@@ -53,7 +54,7 @@ class SplitEmbInferenceConverter:
          return model

      # pyre-fixme[2]: Parameter must be annotated.
-     def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> Tuple[Tensor, float]:
+     def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> tuple[Tensor, float]:
          assert new_num_rows > 0
          from numpy.linalg import norm

@@ -75,7 +76,7 @@ class SplitEmbInferenceConverter:
          idx: int,
          num_rows: int,
          module: SplitTableBatchedEmbeddingBagsCodegen,
-     ) -> Tuple[Tensor, Optional[Tensor]]:
+     ) -> tuple[Tensor, Optional[Tensor]]:
          # TODO(yingz): Avoid DtoH / HtoD overhead.
          weights = module.split_embedding_weights()[idx].cpu()
          if self.pruning_ratio is None:
@@ -84,7 +85,7 @@ class SplitEmbInferenceConverter:
          if new_num_rows == num_rows:
              return (weights, None)

-         (indicators, threshold) = self._prune_by_weights_l2_norm(new_num_rows, weights)
+         indicators, threshold = self._prune_by_weights_l2_norm(new_num_rows, weights)

          return torch.ops.fbgemm.embedding_bag_rowwise_prune(
              weights, indicators, threshold, torch.int32
@@ -100,7 +101,7 @@ class SplitEmbInferenceConverter:

      def _quantize_embs(
          self, weight: Tensor, weight_ty: SparseType
-     ) -> Tuple[Tensor, Optional[Tensor]]:
+     ) -> tuple[Tensor, Optional[Tensor]]:
          fp8_quant_config = cast(FP8QuantizationConfig, self.quantization_config)
          return quantize_embs(weight, weight_ty, fp8_quant_config)

@@ -129,7 +130,7 @@ class SplitEmbInferenceConverter:
          index_remapping_list = []
          for t, (_, E, D, weight_ty) in enumerate(embedding_specs):
              # Try to prune embeddings.
-             (pruned_weight, index_remapping) = self._prune_embs(t, E, child)
+             pruned_weight, index_remapping = self._prune_embs(t, E, child)
              new_embedding_specs.append(
                  (
                      "",
@@ -4,6 +4,8 @@
  ## Template Source: training/python/optimizer_args.py
  ################################################################################

+ __template_source_file__ = "training/python/optimizer_args.py"
+
  #!/usr/bin/env python3
  # Copyright (c) Meta Platforms, Inc. and affiliates.
  # All rights reserved.
@@ -4,6 +4,8 @@
  ## Template Source: training/python/split_embedding_optimizer_codegen.template
  ################################################################################

+ __template_source_file__ = "training/python/split_embedding_optimizer_codegen.template"
+
  #!/usr/bin/env python3

  # Copyright (c) Meta Platforms, Inc. and affiliates.