fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/sparse_ops.py (new file, +1455 lines):
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import math
from collections.abc import Sequence
from typing import Callable, Optional

import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
from fbgemm_gpu.utils.loader import load_torch_module

try:
    # pyre-ignore
    from fbgemm_gpu import open_source  # noqa: F401
except Exception:
    load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings")
    load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")

    if torch.version.hip:
        torch.ops.load_library(
            "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip"
        )
    else:
        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops")

    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine")

    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu"
    )
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu"
    )
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
    )


import torch.utils._pytree as pytree
from torch import SymInt, Tensor
from torch.fx.experimental.symbolic_shapes import guard_or_true


if hasattr(torch.library, "register_fake"):
    # pyre-ignore[9]
    impl_abstract = torch.library.register_fake
elif hasattr(torch.library, "impl_abstract"):
    impl_abstract = torch.library.impl_abstract
else:
    # pyre-ignore
    def impl_abstract(schema: str) -> Callable[[Callable], Callable]:
        # no-op
        # pyre-ignore
        def wrapper(f: Callable) -> Callable:
            return f

        return wrapper

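
# --- Added commentary (not part of the original file) ---
# A minimal sketch of the pattern this module applies to dozens of ops below:
# define an op schema, then register a fake (meta) implementation so that
# torch.compile / FakeTensor tracing can infer output shapes without running
# the real kernel. "demo::double_numel" is a hypothetical op used purely for
# illustration.
def _example_register_fake_pattern() -> None:
    torch.library.define("demo::double_numel", "(Tensor x) -> Tensor")

    @impl_abstract("demo::double_numel")
    def double_numel_fake(x: torch.Tensor) -> torch.Tensor:
        # Under fake tracing only shapes/dtypes matter, so allocating an
        # empty tensor of the right size is a complete implementation.
        return x.new_empty([x.numel() * 2])
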
def permute_2D_sparse_data_input1D_meta(
    permute: Tensor,
    lengths: Tensor,
    values: Tensor,
    stride: int,
    weights: Optional[Tensor] = None,
    permuted_lengths_sum: Optional[int] = None,
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
    torch._check(
        lengths.dim() == 1, lambda: f"expected lengths.dim() == 1, got {lengths.dim()}"
    )
    T = permute.numel()
    B = stride
    indices = values
    permuted_lengths = lengths.new_empty([T * B])
    permuted_indices_size = 0
    if permuted_lengths_sum is not None:
        permuted_indices_size = permuted_lengths_sum
    else:
        ctx = torch.library.get_ctx()
        permuted_indices_size = ctx.new_dynamic_size()
    # pyre-fixme
    permuted_indices = indices.new_empty(permuted_indices_size)
    permuted_weights = None
    if weights is not None:
        # pyre-fixme
        permuted_weights = weights.new_empty(permuted_indices_size)
    return permuted_lengths, permuted_indices, permuted_weights


# pyre-ignore
def permute_2D_sparse_data_input1D_setup_context(ctx, inputs, output):
    permute, lengths, values, stride, weights, permuted_lengths_sum = inputs
    permuted_lengths, permuted_values, permuted_weights = output
    ctx.permute = permute
    ctx.permuted_lengths = permuted_lengths
    ctx.stride = stride


def permute_2D_sparse_data_input1D_backward(
    ctx,  # pyre-ignore
    grad_lengths: torch.Tensor,
    grad_values: torch.Tensor,
    grad_weights: torch.Tensor,
) -> tuple[None, Tensor, Tensor, None, Tensor, None]:
    inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute)
    permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = (
        torch.ops.fbgemm.permute_2D_sparse_data_input1D(
            inv_permute,
            ctx.permuted_lengths,
            grad_values,
            ctx.stride,
            grad_weights,
            None,
        )
    )
    return (
        None,
        permuted_grad_lengths,
        permuted_grad_values,
        None,
        permuted_grad_weights,
        None,
    )


def permute_2D_sparse_data_meta(
    permute: Tensor,
    lengths: Tensor,
    values: Tensor,
    weights: Optional[Tensor] = None,
    permuted_lengths_sum: Optional[int] = None,
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
    torch._check(
        lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}"
    )
    T = permute.numel()
    B = lengths.size(1)
    indices = values
    permuted_lengths = lengths.new_empty([T, B])
    permuted_indices_size = 0
    if permuted_lengths_sum is not None:
        permuted_indices_size = permuted_lengths_sum
    else:
        ctx = torch.library.get_ctx()
        permuted_indices_size = ctx.new_dynamic_size()
    # pyre-fixme
    permuted_indices = indices.new_empty(permuted_indices_size)
    permuted_weights = None
    if weights is not None:
        # pyre-fixme
        permuted_weights = weights.new_empty(permuted_indices_size)
    return permuted_lengths, permuted_indices, permuted_weights


def invert_permute_abstract(permute: Tensor) -> Tensor:
    return torch.empty_like(permute)

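
# Added commentary (not part of the original file): a pure-torch sketch of the
# semantics that invert_permute_abstract models, assuming the op returns the
# inverse permutation p_inv with p_inv[p[i]] = i.
def _invert_permute_reference(permute: Tensor) -> Tensor:
    inv = torch.empty_like(permute)
    inv[permute.long()] = torch.arange(
        permute.numel(), dtype=permute.dtype, device=permute.device
    )
    # E.g. p = [2, 0, 1] -> inv = [1, 2, 0], and x[p][inv] recovers x.
    return inv
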
def get_source_mask_meta(
    num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
) -> Tensor:
    if output_size is None:
        ctx = torch.library.get_ctx()
        output_size = ctx.new_dynamic_size()
    return torch.empty([output_size], dtype=torch.bool)


def get_source_mask(
    num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
) -> Tensor:
    """
    Generate a boolean mask indicating which elements are from sources vs targets.

    This is a Python wrapper that computes output_size when not provided,
    enabling the operation to work with meta tensors for compilation.

    Args:
        num_sources: 1D tensor of source counts per batch element
        num_targets: 1D tensor of target counts per batch element
        output_size: Optional pre-computed output size.

    Returns:
        A 1D boolean tensor where True indicates source elements and False
        indicates target elements

    Example:
        >>> num_sources = torch.tensor([2, 3])
        >>> num_targets = torch.tensor([1, 2])
        >>> get_source_mask(num_sources, num_targets)
        tensor([True, True, False, True, True, True, False, False])
    """
    # Compute output_size if not provided and tensors are regular (not meta/fake)
    if output_size is None:
        combined = num_sources + num_targets
        output_size = int(combined.sum().item())

    return torch.ops.fbgemm.get_source_mask(num_sources, num_targets, output_size)


# pyre-ignore
def permute_2D_sparse_data_setup_context(ctx, inputs, output):
    permute, lengths, values, weights, permuted_lengths_sum = inputs
    permuted_lengths, permuted_values, permuted_weights = output
    ctx.permute = permute
    ctx.permuted_lengths = permuted_lengths


# pyre-ignore
def permute_2D_sparse_data_backward(ctx, grad_lengths, grad_values, grad_weights):
    inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute)
    permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = (
        torch.ops.fbgemm.permute_2D_sparse_data(
            inv_permute, ctx.permuted_lengths, grad_values, grad_weights
        )
    )
    return (
        None,
        permuted_grad_lengths,
        permuted_grad_values,
        permuted_grad_weights,
        None,
    )


def permute_1D_sparse_data_meta(
    permute: Tensor,
    lengths: Tensor,
    values: Tensor,
    weights: Optional[Tensor] = None,
    permuted_lengths_sum: Optional[int] = None,
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
    indices = values
    permuted_lengths_size = permute.numel()
    permuted_lengths = lengths.new_empty([permuted_lengths_size])
    permuted_indices_size = 0
    if permuted_lengths_sum is not None:
        permuted_indices_size = permuted_lengths_sum
    else:
        ctx = torch.library.get_ctx()
        permuted_indices_size = ctx.new_dynamic_size()
    # pyre-fixme
    permuted_indices = indices.new_empty(permuted_indices_size)
    permuted_weights = None
    if weights is not None:
        # pyre-fixme
        permuted_weights = weights.new_empty(permuted_indices_size)
    return permuted_lengths, permuted_indices, permuted_weights


def masked_select_jagged_1d(
    values: Tensor, lengths: Tensor, mask: Tensor
) -> tuple[Tensor, Tensor]:
    torch._check(values.dim() == 1)
    torch._check(lengths.dim() == 1)
    torch._check(values.device == lengths.device)
    torch._check(values.device == mask.device)

    s0 = torch.library.get_ctx().new_dynamic_size()
    masked_values = values.new_empty([s0])
    masked_lengths = torch.empty_like(lengths)
    return masked_values, masked_lengths

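
# Added commentary (not part of the original file): an eager-mode sketch of
# the semantics the meta function above models. The output numel depends on
# the data in `mask`, which is why the meta function must allocate it with
# new_dynamic_size().
def _masked_select_jagged_1d_reference(
    values: Tensor, lengths: Tensor, mask: Tensor
) -> tuple[Tensor, Tensor]:
    masked_values = values[mask]
    # Map each value to its bag, then count the surviving values per bag.
    bag_ids = torch.repeat_interleave(
        torch.arange(lengths.numel(), device=values.device), lengths
    )
    masked_lengths = torch.bincount(
        bag_ids[mask], minlength=lengths.numel()
    ).to(lengths.dtype)
    return masked_values, masked_lengths
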
def tbe_input_combine_abstract(
    indices_list: list[Tensor],
    offsets_list: list[Tensor],
    per_sample_weights: list[Tensor],
    include_last_offsets: Tensor,
) -> tuple[Tensor, Tensor, Tensor]:
    torch._check(len(indices_list) > 0)
    torch._check(len(indices_list) == len(offsets_list))
    torch._check(len(indices_list) == len(per_sample_weights))
    torch._check(len(indices_list) == include_last_offsets.numel())
    total_indices = 0
    need_weight = False
    for index, offset, weight in zip(indices_list, offsets_list, per_sample_weights):
        torch._check(index.dtype == torch.int or index.dtype == torch.long)
        torch._check(offset.dtype == torch.int or offset.dtype == torch.long)
        torch._check(index.dim() == 1)
        torch._check(offset.dim() == 1)
        torch._check(index.is_contiguous())
        torch._check(offset.is_contiguous())
        total_indices = total_indices + index.numel()
        if guard_or_true(weight.numel() > 0):
            torch._check(weight.dim() == 1)
            torch._check(weight.numel() == index.numel())
            torch._check(weight.is_contiguous())
            need_weight = True
    total_offsets = torch.library.get_ctx().new_dynamic_size()
    combined_indices = indices_list[0].new_empty([total_indices], dtype=torch.int)
    combined_offsets = offsets_list[0].new_empty([total_offsets], dtype=torch.int)
    if need_weight:
        combined_weights = per_sample_weights[0].new_empty(
            [total_indices], dtype=torch.float
        )
    else:
        combined_weights = torch.empty(0)
    return combined_indices, combined_offsets, combined_weights


def tbe_input_combine_with_length_abstract(
    indices_list: list[Tensor],
    offsets_list: list[Tensor],
    per_sample_weights: list[Tensor],
) -> tuple[Tensor, Tensor, Tensor]:
    torch._check(len(indices_list) > 0)
    torch._check(len(indices_list) == len(offsets_list))
    torch._check(len(indices_list) == len(per_sample_weights))
    total_indices = 0
    total_offsets = 0
    need_weight = False
    for index, offset, weight in zip(indices_list, offsets_list, per_sample_weights):
        torch._check(index.dtype == torch.int or index.dtype == torch.long)
        torch._check(offset.dtype == torch.int or offset.dtype == torch.long)
        torch._check(index.dim() == 1)
        torch._check(offset.dim() == 1)
        torch._check(index.is_contiguous())
        torch._check(offset.is_contiguous())
        total_indices = total_indices + index.numel()
        total_offsets = total_offsets + offset.numel()
        if guard_or_true(weight.numel() > 0):
            torch._check(weight.dim() == 1)
            torch._check(weight.numel() == index.numel())
            torch._check(weight.is_contiguous())
            need_weight = True
    combined_indices = indices_list[0].new_empty([total_indices], dtype=torch.int)
    combined_offsets = offsets_list[0].new_empty([total_offsets], dtype=torch.int)
    if need_weight:
        combined_weights = per_sample_weights[0].new_empty(
            [total_indices], dtype=torch.float
        )
    else:
        combined_weights = torch.empty(0, device=indices_list[0].device)
    return combined_indices, combined_offsets, combined_weights

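
# Added commentary (not part of the original file): a worked size example for
# the two combine meta functions above. With two features of 3 and 2 indices,
# the combined indices tensor has 5 elements; for the with_length variant the
# combined offsets size is the sum of the per-feature offsets sizes (2 + 3 = 5
# here).
def _example_tbe_input_combine_sizes() -> None:
    indices_list = [torch.tensor([1, 2, 3]), torch.tensor([7, 8])]
    offsets_list = [torch.tensor([0, 3]), torch.tensor([0, 1, 2])]
    total_indices = sum(t.numel() for t in indices_list)  # 5
    total_offsets = sum(t.numel() for t in offsets_list)  # 5
    assert (total_indices, total_offsets) == (5, 5)
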
def jagged_index_select_2d_forward_v2_abstract(
    values: Tensor,
    indices: Tensor,
    input_offsets: Tensor,
    output_offsets: Tensor,
    num_dense_output_rows: Optional[int] = None,
) -> Tensor:
    torch._check(values.device == indices.device)
    torch._check(values.device == input_offsets.device)
    torch._check(values.device == output_offsets.device)
    torch._check(values.dim() == 2)
    dynamic_num_dense_output_rows = torch.library.get_ctx().new_dynamic_size()
    num_cols = values.size(1)
    return values.new_empty([dynamic_num_dense_output_rows, num_cols])


def jagged_index_add_2d_forward_v2_abstract(
    values: Tensor,
    indices: Tensor,
    input_offsets: Tensor,
    output_offsets: Tensor,
    num_output_rows: int,
    num_dense_input_rows: Optional[int] = None,
) -> Tensor:
    torch._check(values.device == indices.device)
    torch._check(values.device == input_offsets.device)
    torch._check(values.device == output_offsets.device)
    torch._check(values.dim() == 2)
    num_cols = values.size(1)
    return values.new_empty([num_output_rows, num_cols])


def expand_into_jagged_permute_meta(
    permute: Tensor,
    input_offsets: Tensor,
    output_offsets: Tensor,
    output_size: tuple[int, ...],
) -> Tensor:
    torch._check(permute.numel() > 0, lambda: f"expected {permute.numel()} > 0")
    torch._check(
        permute.numel() == input_offsets.numel() - 1,
        lambda: f"expected {permute.numel()} == {input_offsets.numel()} - 1",
    )
    torch._check(
        permute.numel() == output_offsets.numel() - 1,
        lambda: f"expected {permute.numel()} == {output_offsets.numel()} - 1",
    )
    output_permute = input_offsets.new_empty(output_size)
    return output_permute


def check_all_same_device(*tensors: Optional[Tensor]) -> None:
    # pyre-ignore[9]
    tensors, _ = pytree.tree_flatten(tensors)
    if len(tensors) == 0:
        return
    if all(t.device.type in ["cpu", "meta"] for t in tensors if t is not None):
        return
    first_tensor: Optional[Tensor] = None
    for tensor in tensors:
        if tensor is None:
            continue
        if first_tensor is None:
            first_tensor = tensor
        torch._check(tensor.device == first_tensor.device)


def pruned_array_lookup_meta(
    indices: Tensor,
    offsets: Tensor,
    index_remappings: Tensor,
    index_remappings_offsets: Tensor,
) -> Tensor:
    check_all_same_device(indices, offsets, index_remappings, index_remappings_offsets)
    return indices.new_empty(indices.shape)


def int_nbit_split_embedding_codegen_lookup_function_meta(
    dev_weights: torch.Tensor,
    uvm_weights: torch.Tensor,
    weights_placements: torch.Tensor,
    weights_offsets: torch.Tensor,
    weights_tys: torch.Tensor,
    D_offsets: torch.Tensor,
    total_D: int,
    max_int2_D: int,
    max_int4_D: int,
    max_int8_D: int,
    max_float16_D: int,
    max_float32_D: int,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    pooling_mode: int,
    indice_weights: Optional[torch.Tensor] = None,
    output_dtype_int: int = 1,
    lxu_cache_weights: Optional[torch.Tensor] = None,
    lxu_cache_locations: Optional[torch.Tensor] = None,
    row_alignment: Optional[int] = None,
    max_float8_D: Optional[int] = None,
    fp8_exponent_bits: Optional[int] = None,
    fp8_exponent_bias: Optional[int] = None,
) -> Tensor:
    check_all_same_device(
        dev_weights,
        uvm_weights,
        weights_placements,
        weights_offsets,
        weights_tys,
        D_offsets,
        indices,
        offsets,
        indice_weights,
    )
    output_dtype = SparseType.from_int(output_dtype_int).as_dtype()
    kINT8QparamsBytes = 8

    if pooling_mode == PoolingMode.NONE:
        kINT8QparamsBytes = 4
        D = max(
            [
                max_int2_D,
                max_int4_D,
                max_int8_D,
                max_float16_D,
                max_float32_D,
                max_float8_D if max_float8_D is not None else 0,
            ]
        )
        total_L = indices.numel()
        T = weights_offsets.numel()
        torch._check(D > 0)
        adjusted_D = D
        if SparseType.from_int(output_dtype_int) == SparseType.INT8:
            adjusted_D += kINT8QparamsBytes
        output = dev_weights.new_empty([total_L, adjusted_D], dtype=output_dtype)
        return output

    T = D_offsets.numel() - 1
    torch._check(T > 0)
    torch._check(total_D > 0)
    B = (offsets.size(0) - 1) // T
    total_adjusted_D = total_D
    if SparseType.from_int(output_dtype_int) == SparseType.INT8:
        total_adjusted_D += T * kINT8QparamsBytes
    output = dev_weights.new_empty([B, total_adjusted_D], dtype=output_dtype)
    return output

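
# Added commentary (not part of the original file): a worked example of the
# INT8 shape adjustment above. With T = 2 tables, total_D = 128, and B = 4,
# INT8 output appends kINT8QparamsBytes = 8 bytes of scale/bias per table, so
# the pooled output is [4, 128 + 2 * 8] = [4, 144].
def _example_int_nbit_int8_output_shape() -> None:
    T, total_D, B, kINT8QparamsBytes = 2, 128, 4, 8
    total_adjusted_D = total_D + T * kINT8QparamsBytes
    assert [B, total_adjusted_D] == [4, 144]
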
def block_bucketize_sparse_features_meta(
    lengths: torch.Tensor,
    indices: torch.Tensor,
    bucketize_pos: bool,
    sequence: bool,
    block_sizes: torch.Tensor,
    my_size: int,
    weights: Optional[torch.Tensor] = None,
    batch_size_per_feature: Optional[torch.Tensor] = None,
    max_B: int = -1,
    block_bucketize_pos: Optional[torch.Tensor] = None,
    keep_orig_idx: bool = False,
    total_num_blocks: Optional[torch.Tensor] = None,
    keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    # Outputs: lengths, indices, weights?, pos?, unbucketize_permute?
    num_buckets = my_size
    num_features = lengths.size(0)
    num_values = indices.size(0)
    return (
        lengths.new_empty([num_buckets * num_features]),
        indices.new_empty([num_values]),
        weights.new_empty(weights.shape) if weights is not None else None,
        indices.new_empty([num_values]) if bucketize_pos else None,
        indices.new_empty([num_values]),
    )


def block_bucketize_sparse_features_2d_weights_meta(
    lengths: torch.Tensor,
    indices: torch.Tensor,
    bucketize_pos: bool,
    sequence: bool,
    block_sizes: torch.Tensor,
    my_size: int,
    weights: torch.Tensor,
    weights_dim: int = 1,
    batch_size_per_feature: Optional[torch.Tensor] = None,
    max_B: int = -1,
    block_bucketize_pos: Optional[torch.Tensor] = None,
    keep_orig_idx: bool = False,
    total_num_blocks: Optional[torch.Tensor] = None,
    keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    # Outputs: lengths, indices, weights, pos?, unbucketize_permute?
    num_buckets = my_size
    num_features = lengths.size(0)
    num_values = indices.size(0)
    return (
        lengths.new_empty([num_buckets * num_features]),
        indices.new_empty([num_values]),
        weights.new_empty([num_values, weights_dim]),
        indices.new_empty([num_values]) if bucketize_pos else None,
        indices.new_empty([num_values]),
    )


def merge_pooled_embeddings(
    pooled_embeddings: list[torch.Tensor],
    uncat_dim_size: int,
    target_device: torch.device,
    cat_dim: int = 1,
) -> torch.Tensor:
    if len(pooled_embeddings) == 0:
        return torch.empty([], device=target_device)
    torch._check_is_size(cat_dim)
    torch._check(cat_dim >= 0)
    torch._check(cat_dim <= 1)
    total_cat_dim_size = 0
    for e in pooled_embeddings:
        torch._check(e.dim() == 2)
        torch._check(e.size(1 - cat_dim) == uncat_dim_size)
        total_cat_dim_size += e.size(cat_dim)
    torch._check_is_size(total_cat_dim_size)
    e = pooled_embeddings[0]
    if cat_dim == 0:
        return e.new_empty(
            [total_cat_dim_size, e.size(1)],
            device=target_device,
        )

    return e.new_empty(
        [e.size(0), total_cat_dim_size],
        device=target_device,
    )

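
# Added commentary (not part of the original file): an eager reference for the
# shape logic above. Concatenating pooled embeddings [B, D1] and [B, D2] along
# cat_dim = 1 yields [B, D1 + D2]; the meta function needs only the sizes.
def _example_merge_pooled_embeddings_shape() -> None:
    B, D1, D2 = 4, 8, 12
    out = torch.cat([torch.empty(B, D1), torch.empty(B, D2)], dim=1)
    assert out.shape == (B, D1 + D2)
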
def permute_sparse_features_abstract(
    permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
    torch._check(lengths.dtype == indices.dtype)
    torch._check(permute.device == lengths.device)
    torch._check(permute.device == indices.device)
    if weights is not None:
        torch._check(permute.device == weights.device)
    num_output_features = permute.numel()
    B = lengths.size(1)
    permuted_lengths = lengths.new_empty(num_output_features, B)
    output_size = torch.library.get_ctx().new_dynamic_size()
    # pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument,
    # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]`
    permuted_indices = indices.new_empty(output_size)
    permuted_weights = None
    if weights is not None:
        # pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument,
        # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]`
        permuted_weights = weights.new_empty(output_size)
    return (permuted_lengths, permuted_indices, permuted_weights)


def segment_sum_csr_abstract(
    batch_size: int, csr_seg: Tensor, values: Tensor
) -> Tensor:
    output_size = csr_seg.numel() - 1
    output = values.new_empty(output_size)
    return output


def dense_to_jagged_forward(
    dense: torch.Tensor,
    offsets: list[torch.Tensor],
    total_L: Optional[torch.SymInt] = None,
) -> torch.Tensor:
    if total_L is None:
        total_L = torch.library.get_ctx().new_dynamic_size()
    return dense.new_zeros(
        [total_L, dense.size()[-1]],
        dtype=dense.dtype,
        device=dense.device,
        layout=dense.layout,
    )


def dense_to_jagged(
    dense: torch.Tensor,
    offsets: list[torch.Tensor],
    total_L: Optional[torch.SymInt] = None,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    if total_L is None:
        total_L = torch.library.get_ctx().new_dynamic_size()
    return (dense_to_jagged_forward(dense, offsets, total_L), offsets)

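
# Added commentary (not part of the original file): an eager-mode sketch of
# what dense_to_jagged computes, assuming a single offsets tensor for a padded
# [B, max_L, D] input. total_L is data-dependent (offsets[-1]), which is why
# the meta function above falls back to new_dynamic_size() when no total_L
# hint is provided.
def _dense_to_jagged_reference(
    dense: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    lengths = offsets[1:] - offsets[:-1]
    return torch.cat(
        [dense[b, : int(n)] for b, n in enumerate(lengths)], dim=0
    )
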
def batch_index_select_dim0_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: list[int],
    input_rows: list[int],
    input_columns: list[int],
    permute_output_dim_0_1: bool,
) -> torch.Tensor:
    """
    This meta function is used to calculate the shape of output tensor
    from the original function `fbgemm::batch_index_select_dim0` without the actual data.
    """
    # input lists must have the same length
    torch._check(len(input_num_indices) == len(input_rows))
    torch._check(len(input_num_indices) == len(input_columns))

    if permute_output_dim_0_1 and len(input_num_indices) > 0:
        # All num_indices must be the same if permute_output_dim_0_1 is True
        for x in input_num_indices:
            torch._check(x == input_num_indices[0])

    size = sum([row * col for row, col in zip(input_rows, input_columns)])
    torch._check(inputs.size(0) == size)

    output_numel = 0
    for i, cols in enumerate(input_columns):
        output_numel += input_num_indices[i] * cols
    return inputs.new_empty([output_numel])

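
# Added commentary (not part of the original file): a worked example of the
# output_numel computation above. Selecting 3 and 5 rows from two inputs with
# 4 and 2 columns gives a flattened output of 3 * 4 + 5 * 2 = 22 elements.
def _example_batch_index_select_dim0_numel() -> None:
    input_num_indices, input_columns = [3, 5], [4, 2]
    output_numel = sum(n * c for n, c in zip(input_num_indices, input_columns))
    assert output_numel == 22
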
def batch_index_select_dim0_tensor_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: torch.Tensor,
    input_rows: torch.Tensor,
    input_columns: torch.Tensor,
    permute_output_dim_0_1: bool,
) -> torch.Tensor:
    torch._check(input_num_indices.size(0) == input_rows.size(0))
    torch._check(input_num_indices.size(0) == input_columns.size(0))
    output_numel = torch.library.get_ctx().new_dynamic_size()
    return inputs.new_empty([output_numel])


def batch_index_select_dim0_forward_cuda_impl_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: list[int],
    input_rows: list[int],
    input_columns: list[int],
    permute_output_dim_0_1: bool,
) -> list[torch.Tensor]:
    num_inputs = len(input_rows)
    torch._check(len(input_num_indices) == len(input_rows))
    torch._check(len(input_num_indices) == len(input_columns))

    output_numel = 0
    for i, cols in enumerate(input_columns):
        output_numel += input_num_indices[i] * cols

    output_offsets = (
        inputs.new_empty([0], dtype=torch.int64)
        if permute_output_dim_0_1
        else inputs.new_empty([num_inputs + 1], dtype=torch.int64)
    )

    if permute_output_dim_0_1:
        for i in range(num_inputs):
            torch._check(input_num_indices[0] == input_num_indices[i])

    return [
        inputs.new_empty([output_numel]),
        inputs.new_empty([num_inputs], dtype=torch.int64),
        inputs.new_empty([num_inputs + 1], dtype=torch.int64),
        inputs.new_empty([num_inputs + 1], dtype=torch.int32),  # D_offsets
        output_offsets,
        inputs.new_empty([num_inputs + 1], dtype=torch.int64),
        inputs.new_empty([4], dtype=torch.int64, device="cpu"),
    ]


def batch_index_select_dim0_tensor_forward_cuda_impl_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: torch.Tensor,
    input_rows: torch.Tensor,
    input_columns: torch.Tensor,
    permute_output_dim_0_1: bool,
) -> list[torch.Tensor]:
    num_inputs: int = input_rows.size(0)
    torch._check(input_num_indices.size(0) == input_rows.size(0))
    torch._check(input_num_indices.size(0) == input_columns.size(0))
    output_numel = torch.library.get_ctx().new_dynamic_size()
    if permute_output_dim_0_1:
        output_offsets = inputs.new_empty([0], dtype=torch.int64)
    else:
        output_offsets = inputs.new_empty([num_inputs + 1], dtype=torch.int64)

    return [
        inputs.new_empty([output_numel]),
        inputs.new_empty([num_inputs], dtype=torch.int64),
        inputs.new_empty([num_inputs + 1], dtype=torch.int64),
        inputs.new_empty([num_inputs + 1], dtype=torch.int32),  # D_offsets
        output_offsets,
        inputs.new_empty([num_inputs + 1], dtype=torch.int64),  # total_L_offsets
        inputs.new_empty([4], dtype=torch.int64, device="cpu"),
    ]


def batch_index_select_dim0_tensor_backward_cuda_impl_abstract(
    grad_output: torch.Tensor,
    dev_weights: torch.Tensor,
    weights_offsets: torch.Tensor,
    D_offsets: torch.Tensor,
    hash_size_cumsum: torch.Tensor,
    indices: torch.Tensor,
    max_segment_length_per_warp: int,
    grad_offsets: torch.Tensor,
    total_L_offsets: torch.Tensor,
    permute_output_dim_0_1: bool,
    saved_tensor: torch.Tensor,
) -> torch.Tensor:
    return grad_output.new_empty(dev_weights.shape)


def keyed_jagged_index_select_dim1_abstract(
    values: torch.Tensor,
    lengths: torch.Tensor,
    offsets: torch.Tensor,
    indices: torch.Tensor,
    batch_size: torch.SymInt,
    weights: Optional[torch.Tensor] = None,
    selected_lengths_sum: Optional[torch.SymInt] = None,
) -> list[torch.Tensor]:
    """
    This meta function is used to calculate the shape of output tensors
    from the original function `fbgemm::keyed_jagged_index_select_dim1` without the actual data.
    """
    # pyre-ignore
    num_batches = len(lengths) // batch_size
    # offsets = [0] + lengths.cumsum(0)
    torch._check(len(lengths) + 1 == len(offsets))
    # len(lengths) == batch_size * num_batches
    # pyre-ignore
    torch._check(len(lengths) % batch_size == 0)
    if weights is not None:
        # weights must have the same shape as values
        torch._check(values.shape == weights.shape)

    if selected_lengths_sum is None:
        length_indices = torch.cat(
            # pyre-ignore
            [indices + i * batch_size for i in range(num_batches)]
        )
        selected_lengths_sum = (
            torch.index_select(lengths, 0, length_indices).sum().item()
        )

    ret: list[torch.Tensor] = [
        # pyre-ignore
        values.new_empty([selected_lengths_sum]),
        lengths.new_empty([indices.shape[0] * num_batches]),
    ]

    if weights is not None:
        # pyre-ignore
        ret.append(weights.new_empty([selected_lengths_sum]))

    return ret

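
# Added commentary (not part of the original file): a worked example of the
# index arithmetic above. With 2 keys over batch_size = 3 (len(lengths) == 6)
# and indices = [0, 2], the selected lengths live at positions
# indices + i * batch_size for each key i.
def _example_keyed_jagged_index_select_sizes() -> None:
    lengths = torch.tensor([1, 2, 3, 4, 5, 6])
    batch_size, indices = 3, torch.tensor([0, 2])
    num_batches = lengths.numel() // batch_size  # 2 keys
    length_indices = torch.cat([indices + i * batch_size for i in range(num_batches)])
    assert length_indices.tolist() == [0, 2, 3, 5]
    assert int(lengths[length_indices].sum()) == 1 + 3 + 4 + 6  # selected_lengths_sum
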
def batch_index_select_dim0_backward_cuda_impl_abstract(
    grad_output: torch.Tensor,
    dev_weights: torch.Tensor,
    weights_offsets: torch.Tensor,
    D_offsets: torch.Tensor,
    hash_size_cumsum: torch.Tensor,
    indices: torch.Tensor,
    max_segment_length_per_warp: int,
    grad_offsets: torch.Tensor,
    total_L_offsets: torch.Tensor,
    permute_output_dim_0_1: bool,
    saved_tensor: torch.Tensor,
) -> torch.Tensor:
    return grad_output.new_empty(dev_weights.shape)


def batch_index_select_dim0_forward_cpu_impl_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: list[int],
    input_rows: list[int],
    input_columns: list[int],
    permute_output_dim_0_1: bool,
) -> list[torch.Tensor]:
    # input lists must have the same length
    num_inputs = len(input_num_indices)
    torch._check(num_inputs == len(input_rows))
    torch._check(num_inputs == len(input_columns))

    if permute_output_dim_0_1 and guard_or_true(len(input_num_indices) > 0):
        # All num_indices must be the same if permute_output_dim_0_1 is True
        for x in input_num_indices:
            torch._check(x == input_num_indices[0])

    output_numel: int = sum([i * c for i, c in zip(input_num_indices, input_columns)])

    return [
        inputs.new_empty([output_numel]),
        inputs.new_empty([len(input_num_indices)], dtype=torch.int64),
        inputs.new_empty([len(input_rows)], dtype=torch.int64),
        inputs.new_empty([len(input_columns)], dtype=torch.int64),
        inputs.new_empty([num_inputs], dtype=torch.int64),  # indices_numels
        inputs.new_empty([1], dtype=torch.int64),  # saved_tensor
    ]


def batch_index_select_dim0_tensor_forward_cpu_impl_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: torch.Tensor,
    input_rows: torch.Tensor,
    input_columns: torch.Tensor,
    permute_output_dim_0_1: bool,
) -> list[torch.Tensor]:
    # input lists must have the same length
    num_inputs = len(input_num_indices)
    torch._check(num_inputs == len(input_rows))
    torch._check(num_inputs == len(input_columns))

    output_numel = torch.library.get_ctx().new_dynamic_size()

    return [
        inputs.new_empty([output_numel]),
        inputs.new_empty([1], dtype=torch.int64),
    ]


def batch_index_select_dim0_backward_cpu_impl_abstract(
    grad_output: torch.Tensor,
    indices: torch.Tensor,
    indices_numels: torch.Tensor,
    input_num_indices: torch.Tensor,
    input_rows: torch.Tensor,
    input_columns: torch.Tensor,
    permute_output_dim_0_1: bool,
    saved_tensor: torch.Tensor,
) -> torch.Tensor:
    return grad_output.new_empty([torch.library.get_ctx().new_dynamic_size()])


def bounds_check_indices_abstract(
    rows_per_table: torch.Tensor,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    bounds_check_mode_int: int,
    bounds_check_warning: torch.Tensor,
    per_sample_weights: Optional[torch.Tensor] = None,
    B_offsets: Optional[torch.Tensor] = None,
    max_B: Optional[SymInt] = None,
    b_t_map: Optional[torch.Tensor] = None,
    info_B_num_bits: int = -1,
    info_B_mask: int = -1,
    bounds_check_version: int = 1,
    prefetch_pipeline: bool = False,
) -> None:
    """
    This meta function is used to fake the bounds checking
    from the original function `fbgemm::bounds_check_indices`
    """
    return


def group_index_select_dim0_gpu_impl_abstract(
    inputs: list[torch.Tensor], group_size: int
) -> list[torch.Tensor]:
    """
    Calculate output shapes for group_index_select_dim0_gpu_impl
    without the actual data.
    """
    indices_group = inputs[:group_size]
    input_group = inputs[group_size:]
    torch._check(len(input_group) == group_size)

    ret = []
    for i in range(group_size):
        size = list(input_group[i].size())
        ret.append(input_group[i].new_empty([indices_group[i].size(0)] + size[1:]))

    # divide by 2 since sizeof(int64_t) / sizeof(int32_t) = 2
    args_tensor_numel = 4 * group_size + 1 + int(math.ceil(group_size / 2))

    ret.append(
        # sizeof(int64_t) = 8, torch.uint8 = at::kByte
        input_group[0].new_empty(
            args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
        )
    )

    ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))

    return ret

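
# Added commentary (not part of the original file): a worked example of the
# args-tensor sizing above. For group_size = 3, args_tensor_numel
# = 4 * 3 + 1 + ceil(3 / 2) = 15 int64 slots, allocated as 15 * 8 = 120 bytes
# of torch.uint8.
def _example_group_index_select_args_numel() -> None:
    group_size = 3
    args_tensor_numel = 4 * group_size + 1 + int(math.ceil(group_size / 2))
    assert args_tensor_numel == 15 and args_tensor_numel * 8 == 120
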
def group_index_select_dim0_gpu_backward_abstract(
    all_inputs: list[torch.Tensor], output_shape_group_ref: list[torch.SymInt]
) -> list[torch.Tensor]:
    """
    Calculate output shapes for group_index_select_dim0_gpu_backward
    without the actual data.
    """
    torch._check(len(all_inputs) > 3)
    group_size = (len(all_inputs) - 3) // 2
    ret = []

    # indices
    for _ in range(group_size):
        ret.append(all_inputs[0].new_empty(0))

    # inputs
    output_dim = len(output_shape_group_ref) // group_size
    for i in range(group_size):
        ret.append(
            all_inputs[0].new_empty(
                output_shape_group_ref[i * output_dim : (i + 1) * output_dim]
            )
        )

    return ret


def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
    values: torch.Tensor,
    lengths: torch.Tensor,
    offsets: torch.Tensor,
    indices: torch.Tensor,
    batch_size: torch.SymInt,
    weights: Optional[torch.Tensor] = None,
    selected_lengths_sum: Optional[torch.SymInt] = None,
) -> list[torch.Tensor]:
    num_batches = lengths.size(0) // batch_size
    torch._check(lengths.size(0) + 1 == offsets.size(0))
    # pyre-ignore
    torch._check(lengths.size(0) % batch_size == 0)

    if weights is not None:
        # weights must have the same shape as values
        torch._check(values.shape == weights.shape)

    if selected_lengths_sum is None:
        selected_lengths_sum = torch.library.get_ctx().new_dynamic_size()

    torch._check_is_size(selected_lengths_sum)
    vlw: list[torch.Tensor] = [
        values.new_empty([selected_lengths_sum]),  # output
        lengths.new_empty([indices.shape[0] * num_batches]),  # output_lengths
    ]
    if weights is not None:
        vlw.append(weights.new_empty([selected_lengths_sum]))  # output_weights

    return [
        *vlw,
        offsets.new_empty([indices.shape[0] * num_batches]),  # output_offsets
        torch.empty([4], dtype=torch.int64, device="cpu"),  # saved_data_tensor
    ]


def keyed_jagged_index_select_dim1_backward_cuda_impl_abstract(
    grad: torch.Tensor,
    indices: torch.Tensor,
    grad_offsets: torch.Tensor,
    output_offsets: torch.Tensor,
    saved_tensor: torch.Tensor,
) -> torch.Tensor:
    return grad.new_empty([torch.library.get_ctx().new_dynamic_size()])


def permute_pooled_embs_split_abstract(
    pooled_embs: Tensor,
    offset_dim_list: Tensor,
    permute_list: Tensor,
    inv_offset_dim_list: Tensor,
    inv_permute_list: Tensor,
) -> Tensor:
    return torch.empty_like(pooled_embs)


def histogram_binning_calibration_abstract(
    logit: Tensor,
    bin_num_examples: Tensor,
    bin_num_positives: Tensor,
    positive_weight: float,
    lower_bound: float,
    upper_bound: float,
    bin_ctr_in_use_after: int,
    bin_ctr_weight_value: float,
) -> tuple[Tensor, Tensor]:
    return torch.empty_like(logit), torch.empty([logit.numel()], dtype=torch.int64)


def float_to_hfp8_quantized(
    input: Tensor, ebits: int, exponent_bias: int, max_pos: float
) -> Tensor:
    return torch.empty_like(input, dtype=torch.uint8)


def hfp8_quantized_to_float(input: Tensor, ebits: int, exponent_bias: int) -> Tensor:
    return torch.empty_like(input, dtype=torch.float32)


def float_or_half_to_fused_nbit_rowwise_quantized_sbhalf(
    input_t: Tensor,
    bit_rate: int,
) -> Tensor:
    input_sizes = input_t.size()
    torch._check(len(input_sizes) == 2)
    nrows = input_sizes[0]
    ncols = input_sizes[1]
    num_elem_per_byte = 8 // bit_rate

    torch._check(ncols % (2 * num_elem_per_byte) == 0)
    output_columns = (ncols + num_elem_per_byte - 1) // num_elem_per_byte + 2 * 2
    output = torch.empty(
        (nrows, output_columns), device=input_t.device, dtype=torch.uint8
    )
    return output


def fused_nbit_rowwise_quantized_sb_half_to_float_or_half(
    input_t: Tensor,
    bit_rate: int,
    output_dtype: int = 0,
) -> Tensor:
    torch._check(output_dtype in [SparseType.FP32.as_int(), SparseType.FP16.as_int()])
    nrows = input_t.size(0)
    ncols = input_t.size(1)
    if input_t.dtype == torch.quint2x4:
        ncols = (ncols + 3) // 4
    elif input_t.dtype == torch.quint4x2:
        ncols = (ncols + 1) // 2
    num_elem_per_byte = 8 // bit_rate
    output_columns = (ncols - 2 * 2) * num_elem_per_byte
    if output_dtype == SparseType.FP32.as_int():
        return torch.empty(
            (nrows, output_columns), dtype=torch.float32, device=input_t.device
        )
    else:  # output_dtype is SparseType.FP16
        return torch.empty(
            (nrows, output_columns), dtype=torch.float16, device=input_t.device
        )

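
# Added commentary (not part of the original file): a worked example of the
# column arithmetic above for bit_rate = 4 (two elements per byte). A
# quantized row of 36 bytes carries 2 * 2 = 4 bytes of fp16 scale/bias, so
# the dequantized row has (36 - 4) * 2 = 64 columns.
def _example_fused_nbit_dequant_columns() -> None:
    bit_rate, ncols = 4, 36
    num_elem_per_byte = 8 // bit_rate
    assert (ncols - 2 * 2) * num_elem_per_byte == 64
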
def fused_8_bit_rowwise_quantized_to_float_or_half(
    input_t: Tensor,
    output_dtype: int = 0,
    scale_bias_last: bool = True,
    quant_padding_float_type: bool = True,
) -> Tensor:
    torch._check(
        output_dtype
        in [
            SparseType.FP32.as_int(),
            SparseType.FP16.as_int(),
            SparseType.BF16.as_int(),
        ]
    )
    torch._check(quant_padding_float_type or not scale_bias_last)
    torch._check(input_t.dim() >= 2)
    last_dim = input_t.dim() - 1
    output_shape = list(input_t.shape)
    ncols = input_t.size(last_dim)
    quant_padding_size = 4 if quant_padding_float_type else 2
    ncols_aligned = (
        (ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
    )
    output_columns = ncols_aligned - 2 * quant_padding_size
    output_shape[last_dim] = output_columns
    if output_dtype == SparseType.FP32.as_int():
        return torch.empty(output_shape, dtype=torch.float32, device=input_t.device)
    elif output_dtype == SparseType.FP16.as_int():
        return torch.empty(output_shape, dtype=torch.float16, device=input_t.device)
    else:  # output_dtype is SparseType.BF16
        return torch.empty(output_shape, dtype=torch.bfloat16, device=input_t.device)


def float_or_half_to_fused_8_bit_rowwise(
    input_t: Tensor,
) -> Tensor:
    torch._check(input_t.dim() >= 2)
    last_dim = input_t.dim() - 1
    output_shape = list(input_t.shape)
    ncols = input_t.size(last_dim)
    ncols_aligned = (ncols + 4 - 1) // 4 * 4
    output_columns = ncols_aligned + 2 * 4
    output_shape[last_dim] = output_columns
    return torch.empty(output_shape, dtype=torch.uint8, device=input_t.device)

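
# Added commentary (not part of the original file): a worked example of the
# 8-bit row layout above. Quantizing a row of 10 floats pads the width up to
# a multiple of 4 (here 12) and appends 2 * 4 bytes of fp32 scale/bias,
# giving 20 uint8 columns; the dequantization functions below invert this.
def _example_fused_8_bit_rowwise_columns() -> None:
    ncols = 10
    ncols_aligned = (ncols + 4 - 1) // 4 * 4  # 12
    assert ncols_aligned + 2 * 4 == 20
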
def fused_8_bit_rowwise_quantized_to_float(
    input_t: Tensor,
    scale_bias_last: bool = True,
    quant_padding_float_type: bool = True,
) -> Tensor:
    torch._check(quant_padding_float_type or not scale_bias_last)
    torch._check(input_t.dim() >= 2)
    last_dim = input_t.dim() - 1
    output_shape = list(input_t.shape)
    ncols = input_t.size(last_dim)
    quant_padding_size = 4 if quant_padding_float_type else 2
    ncols_aligned = (
        (ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
    )
    output_columns = ncols_aligned - 2 * quant_padding_size
    output_shape[last_dim] = output_columns
    return torch.empty(output_shape, dtype=torch.float32, device=input_t.device)


def fused_8_bit_rowwise_quantized_to_half(
    input_t: Tensor,
    scale_bias_last: bool = True,
    quant_padding_float_type: bool = True,
) -> Tensor:
    torch._check(quant_padding_float_type or not scale_bias_last)
    torch._check(input_t.dim() >= 2)
    last_dim = input_t.dim() - 1
    output_shape = list(input_t.shape)
    ncols = input_t.size(last_dim)
    quant_padding_size = 4 if quant_padding_float_type else 2
    ncols_aligned = (
        (ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
    )
    output_columns = ncols_aligned - 2 * quant_padding_size
    output_shape[last_dim] = output_columns
    return torch.empty(output_shape, dtype=torch.float16, device=input_t.device)


def generic_histogram_binning_calibration_by_feature(
    logit: Tensor,
    segment_value: Tensor,
    segment_lengths: Tensor,
    num_segments: int,
    bin_num_examples: Tensor,
    bin_num_positives: Tensor,
    bin_boundaries: Tensor,
    positive_weight: float,
    bin_ctr_in_use_after: int,
    bin_ctr_weight_value: float,
) -> tuple[Tensor, Tensor]:
    torch._check(bin_num_examples.numel() == bin_num_positives.numel())
    torch._check(
        bin_num_examples.numel() == (num_segments + 1) * (bin_boundaries.numel() + 1)
    )
    return torch.empty_like(logit), torch.empty(
        [logit.numel()], dtype=torch.int64, device=logit.device
    )


def permute_multi_embedding_function_impl_abstract(
    pooled_embs: list[Tensor],
    permutes: Tensor,
    in_shapes: Tensor,
    out_shapes: Tensor,
    out_lengths: list[int],
    reverse: bool = False,
) -> list[Tensor]:
    out_dtype = pooled_embs[0].dtype
    bs = pooled_embs[0].shape[0]
    torch._check(permutes.shape[1] == 6, lambda: "permutes must have 6 columns")

    output = []
    for i in range(len(out_lengths)):
        output.append(torch.empty([bs, out_lengths[i]], dtype=out_dtype))
    return output


def lengths_range_abstract(
    lengths: Tensor,
    output_shape: Optional[Sequence[int]] = None,
) -> Tensor:
    torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
    output_size = 0
    if output_shape is not None:
        output_size = math.prod(output_shape)
    else:
        ctx = torch.library.get_ctx()
        output_size = ctx.new_dynamic_size()
    return lengths.new_empty([output_size], dtype=lengths.dtype)

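
# Added commentary (not part of the original file): an eager-mode sketch of
# the expected lengths_range semantics: concatenate arange(n) for each length
# n, e.g. [2, 3] -> [0, 1, 0, 1, 2]. The total size is lengths.sum(), modeled
# above with a dynamic size when no output_shape hint is given.
def _lengths_range_reference(lengths: Tensor) -> Tensor:
    return torch.cat(
        [
            torch.arange(int(n), dtype=lengths.dtype, device=lengths.device)
            for n in lengths
        ]
    )
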
def all_to_one_device(
    input_tensors: list[Tensor],
    target_device: torch.device,
) -> list[Tensor]:
    return [
        torch.empty_like(input_tensor, device=torch.device("meta"))
        for input_tensor in input_tensors
    ]


def sum_reduce_to_one(
    input_tensors: list[Tensor],
    target_device: torch.device,
) -> Tensor:
    torch._check(len(input_tensors) > 0, lambda: "reducing no tensor is undefined")
    # All tensors should have the same shape
    first_tensor = input_tensors[0]
    return torch.empty_like(first_tensor, device=torch.device("meta"))


def _setup() -> None:
    # pyre-ignore[16]
    _setup.done = getattr(_setup, "done", False)

    # pyre-ignore[2]
    def impl_abstract(op_name, fn) -> None:
        # NOTE: Failures have occasionally been observed with register_fake,
        # where the error signatures can be found in:
        # https://github.com/pytorch/pytorch/blob/main/torch/_library/fake_impl.py
        #
        # To work around this, we first check if the kernel is already registered
        # for the following dispatch keys, and if so, we skip the registration.
        for dkey in ["CompositeImplicitAutograd", "Meta"]:
            if torch._C._dispatch_has_kernel_for_dispatch_key(op_name, dkey):
                return
        torch.library.register_fake(op_name, fn)

    # pyre-ignore[2,24]
    def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None:
        name_split = op_name.split("::")
        key = f"{name_split[0]}/{name_split[-1]}/Autograd"
        if key not in torch.library._impls:
            torch.library.register_autograd(op_name, fn, setup_context=setup_context)

    if not _setup.done:
        impl_autograd(
            "fbgemm::permute_2D_sparse_data",
            permute_2D_sparse_data_backward,
            setup_context=permute_2D_sparse_data_setup_context,
        )

        impl_abstract("fbgemm::permute_2D_sparse_data", permute_2D_sparse_data_meta)
        impl_abstract("fbgemm::get_source_mask", get_source_mask_meta)
        impl_abstract(
            "fbgemm::permute_2D_sparse_data_input1D",
            permute_2D_sparse_data_input1D_meta,
        )
        impl_abstract("fbgemm::invert_permute", invert_permute_abstract)
        impl_abstract("fbgemm::permute_1D_sparse_data", permute_1D_sparse_data_meta)
        impl_abstract("fbgemm::masked_select_jagged_1d", masked_select_jagged_1d)
        impl_abstract("fbgemm::tbe_input_combine", tbe_input_combine_abstract)
        impl_abstract(
            "fbgemm::tbe_input_combine_with_length",
            tbe_input_combine_with_length_abstract,
        )
        impl_abstract(
            "fbgemm::jagged_index_select_2d_forward_v2",
            jagged_index_select_2d_forward_v2_abstract,
        )
        impl_abstract(
            "fbgemm::jagged_index_add_2d_forward_v2",
            jagged_index_add_2d_forward_v2_abstract,
        )
        impl_abstract(
            "fbgemm::expand_into_jagged_permute", expand_into_jagged_permute_meta
        )
        impl_abstract("fbgemm::pruned_array_lookup", pruned_array_lookup_meta)
        impl_abstract(
            "fbgemm::int_nbit_split_embedding_codegen_lookup_function",
            int_nbit_split_embedding_codegen_lookup_function_meta,
        )
        impl_abstract(
            "fbgemm::block_bucketize_sparse_features",
            block_bucketize_sparse_features_meta,
        )
        impl_abstract(
            "fbgemm::block_bucketize_sparse_features_2d_weights",
            block_bucketize_sparse_features_2d_weights_meta,
        )
        impl_abstract("fbgemm::merge_pooled_embeddings", merge_pooled_embeddings)
        impl_abstract(
            "fbgemm::permute_sparse_features", permute_sparse_features_abstract
        )
        impl_abstract("fbgemm::segment_sum_csr", segment_sum_csr_abstract)
        impl_abstract("fbgemm::dense_to_jagged_forward", dense_to_jagged_forward)
        impl_abstract("fbgemm::all_to_one_device", all_to_one_device)
        impl_abstract("fbgemm::sum_reduce_to_one", sum_reduce_to_one)
        impl_abstract(
            "fbgemm::batch_index_select_dim0", batch_index_select_dim0_abstract
        )
        impl_abstract(
            "fbgemm::batch_index_select_dim0_tensor",
            batch_index_select_dim0_tensor_abstract,
        )
        impl_abstract(
            "fbgemm::batch_index_select_dim0_forward_cuda_impl",
            batch_index_select_dim0_forward_cuda_impl_abstract,
        )
        impl_abstract(
            "fbgemm::batch_index_select_dim0_tensor_forward_cuda_impl",
            batch_index_select_dim0_tensor_forward_cuda_impl_abstract,
        )
        impl_abstract(
            "fbgemm::batch_index_select_dim0_tensor_backward_cuda_impl",
            batch_index_select_dim0_tensor_backward_cuda_impl_abstract,
        )
        impl_abstract(
            "fbgemm::batch_index_select_dim0_backward_cuda_impl",
            batch_index_select_dim0_backward_cuda_impl_abstract,
        )
        impl_abstract(
            "fbgemm::keyed_jagged_index_select_dim1",
            keyed_jagged_index_select_dim1_abstract,
        )
        impl_abstract(
            "fbgemm::batch_index_select_dim0_forward_cpu_impl",
            batch_index_select_dim0_forward_cpu_impl_abstract,
        )
        impl_abstract(
            "fbgemm::batch_index_select_dim0_tensor_forward_cpu_impl",
            batch_index_select_dim0_tensor_forward_cpu_impl_abstract,
        )
        impl_abstract(
            "fbgemm::batch_index_select_dim0_backward_cpu_impl",
            batch_index_select_dim0_backward_cpu_impl_abstract,
        )
        impl_abstract("fbgemm::bounds_check_indices", bounds_check_indices_abstract)
        impl_abstract(
            "fbgemm::group_index_select_dim0_gpu_impl",
            group_index_select_dim0_gpu_impl_abstract,
        )
        impl_abstract(
            "fbgemm::group_index_select_dim0_gpu_backward",
            group_index_select_dim0_gpu_backward_abstract,
        )
        impl_abstract(
            "fbgemm::keyed_jagged_index_select_dim1_forward",
            keyed_jagged_index_select_dim1_forward_cuda_impl_abstract,
        )
        impl_abstract(
            "fbgemm::keyed_jagged_index_select_dim1_backward",
            keyed_jagged_index_select_dim1_backward_cuda_impl_abstract,
        )
        impl_abstract(
            "fbgemm::permute_pooled_embs_split", permute_pooled_embs_split_abstract
        )
        impl_abstract(
            "fbgemm::histogram_binning_calibration",
            histogram_binning_calibration_abstract,
        )
        impl_abstract(
            "fbgemm::generic_histogram_binning_calibration_by_feature",
            generic_histogram_binning_calibration_by_feature,
        )
        impl_abstract(
            "fbgemm::lengths_range",
            lengths_range_abstract,
        )
        impl_abstract(
            "fbgemm::permute_multi_embedding_function",
            permute_multi_embedding_function_impl_abstract,
        )
        impl_abstract(
            "fbgemm::FloatToHFP8Quantized",
            float_to_hfp8_quantized,
        )
        impl_abstract(
            "fbgemm::HFP8QuantizedToFloat",
            hfp8_quantized_to_float,
        )
        impl_abstract(
            "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf",
            float_or_half_to_fused_nbit_rowwise_quantized_sbhalf,
        )
        impl_abstract(
            "fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
            fused_nbit_rowwise_quantized_sb_half_to_float_or_half,
        )
        impl_abstract(
            "fbgemm::Fused8BitRowwiseQuantizedToFloatOrHalf",
            fused_8_bit_rowwise_quantized_to_float_or_half,
        )
        impl_abstract(
            "fbgemm::FloatToFused8BitRowwiseQuantized",
            float_or_half_to_fused_8_bit_rowwise,
        )
        impl_abstract(
            "fbgemm::FloatOrHalfToFused8BitRowwiseQuantized",
            float_or_half_to_fused_8_bit_rowwise,
        )
        impl_abstract(
            "fbgemm::HalfToFused8BitRowwiseQuantized",
            float_or_half_to_fused_8_bit_rowwise,
        )
        impl_abstract(
            "fbgemm::Fused8BitRowwiseQuantizedToFloat",
            fused_8_bit_rowwise_quantized_to_float,
        )
        impl_abstract(
            "fbgemm::Fused8BitRowwiseQuantizedToHalf",
            fused_8_bit_rowwise_quantized_to_half,
        )
        _setup.done = True


_setup()
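
# Added commentary (not part of the original file): once _setup() has run,
# torch.compile can trace graphs that call these fbgemm ops, using the fake
# impls registered above to propagate shapes (including data-dependent ones).
# Illustrative sketch, assuming the fbgemm op libraries are loaded:
def _example_compile_with_fake_impls() -> torch.Tensor:
    fn = torch.compile(torch.ops.fbgemm.invert_permute)
    return fn(torch.tensor([2, 0, 1]))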