fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/sparse_ops.py
CHANGED
@@ -7,10 +7,12 @@
 # pyre-strict
 
 import math
-from typing import Callable, List, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Callable, Optional
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
 from fbgemm_gpu.utils.loader import load_torch_module
@@ -41,12 +43,14 @@ except Exception:
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+    )
 
 
 import torch.utils._pytree as pytree
 from torch import SymInt, Tensor
-from torch.fx.experimental.symbolic_shapes import
-
+from torch.fx.experimental.symbolic_shapes import guard_or_true
 
 if hasattr(torch.library, "register_fake"):
     # pyre-ignore[9]
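The newly imported guard_or_true drives the shape checks rewritten in later hunks of this file: conditions such as weight.numel() > 0 can involve unbacked SymInts during fake-tensor tracing, where a plain bool() would raise a data-dependent guard error. A minimal sketch of the assumed semantics (check_weight is an illustrative name, not fbgemm API):

import torch
from torch.fx.experimental.symbolic_shapes import guard_or_true

def check_weight(weight: torch.Tensor) -> None:
    # guard_or_true returns the condition's value when it is statically
    # decidable, and falls back to True (rather than erroring) when the
    # size is symbolic and unknown at trace time.
    if guard_or_true(weight.numel() > 0):
        torch._check(weight.dim() == 1)
        torch._check(weight.is_contiguous())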
@@ -71,7 +75,7 @@ def permute_2D_sparse_data_input1D_meta(
     stride: int,
     weights: Optional[Tensor] = None,
     permuted_lengths_sum: Optional[int] = None,
-) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+) -> tuple[Tensor, Tensor, Optional[Tensor]]:
     torch._check(
         lengths.dim() == 1, lambda: f"expected lengths.dim() == 1, got {lengths.dim()}"
     )
@@ -108,7 +112,7 @@ def permute_2D_sparse_data_input1D_backward(
     grad_lengths: torch.Tensor,
     grad_values: torch.Tensor,
     grad_weights: torch.Tensor,
-) -> Tuple[None, Tensor, Tensor, None, Tensor, None]:
+) -> tuple[None, Tensor, Tensor, None, Tensor, None]:
     inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute)
     permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = (
         torch.ops.fbgemm.permute_2D_sparse_data_input1D(
@@ -136,7 +140,7 @@ def permute_2D_sparse_data_meta(
     values: Tensor,
     weights: Optional[Tensor] = None,
     permuted_lengths_sum: Optional[int] = None,
-) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+) -> tuple[Tensor, Tensor, Optional[Tensor]]:
     torch._check(
         lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}"
     )
@@ -163,6 +167,89 @@ def invert_permute_abstract(permute: Tensor) -> Tensor:
     return torch.empty_like(permute)
 
 
+def get_source_mask_meta(
+    num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
+) -> Tensor:
+    if output_size is None:
+        ctx = torch.library.get_ctx()
+        output_size = ctx.new_dynamic_size()
+    return torch.empty([output_size], dtype=torch.bool)
+
+
+def get_source_mask(
+    num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
+) -> Tensor:
+    """
+    Generate a boolean mask indicating which elements are from sources vs targets.
+
+    This is a Python wrapper that computes output_size when not provided,
+    enabling the operation to work with meta tensors for compilation.
+
+    Args:
+        num_sources: 1D tensor of source counts per batch element
+        num_targets: 1D tensor of target counts per batch element
+        output_size: Optional pre-computed output size.
+
+    Returns:
+        A 1D boolean tensor where True indicates source elements and False
+        indicates target elements
+
+    Example:
+        >>> num_sources = torch.tensor([2, 3])
+        >>> num_targets = torch.tensor([1, 2])
+        >>> get_source_mask(num_sources, num_targets)
+        tensor([True, True, False, True, True, True, False, False])
+    """
+    # Compute output_size if not provided and tensors are regular (not meta/fake)
+    if output_size is None:
+        combined = num_sources + num_targets
+        output_size = int(combined.sum().item())
+
+    return torch.ops.fbgemm.get_source_mask(num_sources, num_targets, output_size)
+
+
+def repeat_arange_meta(lengths: Tensor) -> Tensor:
+    """Meta implementation for repeat_arange."""
+    # Output size is data-dependent (sum of lengths).
+    # For FakeTensors (used in torch.compile), we use dynamic sizing.
+    # For actual meta tensors, we cannot determine the size so return empty.
+    if lengths.device.type == "meta":
+        # Actual meta tensors: return a zero-sized tensor as placeholder
+        # since we cannot compute the data-dependent output size
+        return torch.empty([0], dtype=lengths.dtype, device=lengths.device)
+    else:
+        # FakeTensor context: use dynamic sizing for proper shape tracking
+        ctx = torch.library.get_ctx()
+        output_size = ctx.new_dynamic_size()
+        return torch.empty([output_size], dtype=lengths.dtype, device=lengths.device)
+
+
+def repeat_arange(lengths: Tensor) -> Tensor:
+    """
+    Creates a concatenated tensor of aranges based on a lengths tensor.
+
+    This is a high-performance CUDA kernel that replaces the inefficient PyTorch
+    implementation which uses 4+ separate kernels (cumsum, arange, repeat_interleave, sub).
+
+    Args:
+        lengths: 1D tensor of lengths for each arange sequence
+
+    Returns:
+        A 1D tensor containing concatenated arange sequences
+
+    Example:
+        >>> lengths = torch.tensor([3, 5, 2])
+        >>> repeat_arange(lengths)
+        tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])
+
+    Performance:
+        - PyTorch implementation: 4+ kernel launches + intermediate allocations
+        - CUDA implementation: 1 fused kernel, no intermediate allocations
+        - Typical speedup: 3-5x on realistic workloads
+    """
+    return torch.ops.fbgemm.repeat_arange(lengths)
+
+
 # pyre-ignore
 def permute_2D_sparse_data_setup_context(ctx, inputs, output):
     permute, lengths, values, weights, permuted_lengths_sum = inputs
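The docstrings above pin down the expected outputs. As a sanity reference, both ops can be reproduced with plain eager PyTorch; the sketch below is illustrative (reference-named helpers, not the fused fbgemm kernels) and spells out the multi-kernel pattern repeat_arange replaces:

import torch

def repeat_arange_reference(lengths: torch.Tensor) -> torch.Tensor:
    # The eager pattern the fused op replaces: cumsum -> arange ->
    # repeat_interleave -> sub.
    starts = torch.cumsum(lengths, dim=0) - lengths  # start offset of each segment
    total = int(lengths.sum().item())
    return torch.arange(total, device=lengths.device) - torch.repeat_interleave(
        starts, lengths
    )

def get_source_mask_reference(
    num_sources: torch.Tensor, num_targets: torch.Tensor
) -> torch.Tensor:
    # Per batch element: num_sources True entries followed by num_targets False.
    parts = []
    for ns, nt in zip(num_sources.tolist(), num_targets.tolist()):
        parts.append(torch.ones(ns, dtype=torch.bool))
        parts.append(torch.zeros(nt, dtype=torch.bool))
    return torch.cat(parts)

# repeat_arange_reference(torch.tensor([3, 5, 2]))
#   -> tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])
# get_source_mask_reference(torch.tensor([2, 3]), torch.tensor([1, 2]))
#   -> tensor([True, True, False, True, True, True, False, False])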
@@ -194,7 +281,7 @@ def permute_1D_sparse_data_meta(
     values: Tensor,
     weights: Optional[Tensor] = None,
     permuted_lengths_sum: Optional[int] = None,
-) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+) -> tuple[Tensor, Tensor, Optional[Tensor]]:
     indices = values
     permuted_lengths_size = permute.numel()
     permuted_lengths = lengths.new_empty([permuted_lengths_size])
@@ -215,7 +302,7 @@ def permute_1D_sparse_data_meta(
 
 def masked_select_jagged_1d(
     values: Tensor, lengths: Tensor, mask: Tensor
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     torch._check(values.dim() == 1)
     torch._check(lengths.dim() == 1)
     torch._check(values.device == lengths.device)
@@ -228,11 +315,11 @@ def masked_select_jagged_1d(
 
 
 def tbe_input_combine_abstract(
-    indices_list: List[Tensor],
-    offsets_list: List[Tensor],
-    per_sample_weights: List[Tensor],
+    indices_list: list[Tensor],
+    offsets_list: list[Tensor],
+    per_sample_weights: list[Tensor],
     include_last_offsets: Tensor,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
     torch._check(len(indices_list) > 0)
     torch._check(len(indices_list) == len(offsets_list))
     torch._check(len(indices_list) == len(per_sample_weights))
@@ -247,7 +334,7 @@ def tbe_input_combine_abstract(
         torch._check(index.is_contiguous())
         torch._check(offset.is_contiguous())
         total_indices = total_indices + index.numel()
-        if
+        if guard_or_true(weight.numel() > 0):
             torch._check(weight.dim() == 1)
             torch._check(weight.numel() == index.numel())
             torch._check(weight.is_contiguous())
@@ -265,10 +352,10 @@ def tbe_input_combine_abstract(
 
 
 def tbe_input_combine_with_length_abstract(
-    indices_list: List[Tensor],
-    offsets_list: List[Tensor],
-    per_sample_weights: List[Tensor],
-) -> Tuple[Tensor, Tensor, Tensor]:
+    indices_list: list[Tensor],
+    offsets_list: list[Tensor],
+    per_sample_weights: list[Tensor],
+) -> tuple[Tensor, Tensor, Tensor]:
     torch._check(len(indices_list) > 0)
     torch._check(len(indices_list) == len(offsets_list))
     torch._check(len(indices_list) == len(per_sample_weights))
@@ -284,7 +371,7 @@ def tbe_input_combine_with_length_abstract(
         torch._check(offset.is_contiguous())
         total_indices = total_indices + index.numel()
         total_offsets = total_offsets + offset.numel()
-        if
+        if guard_or_true(weight.numel() > 0):
             torch._check(weight.dim() == 1)
             torch._check(weight.numel() == index.numel())
             torch._check(weight.is_contiguous())
@@ -336,7 +423,7 @@ def expand_into_jagged_permute_meta(
     permute: Tensor,
     input_offsets: Tensor,
     output_offsets: Tensor,
-    output_size: Tuple[int, ...],
+    output_size: tuple[int, ...],
 ) -> Tensor:
     torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0")
     torch._check(
@@ -417,6 +504,7 @@ def int_nbit_split_embedding_codegen_lookup_function_meta(
     kINT8QparamsBytes = 8
 
     if pooling_mode == PoolingMode.NONE:
+        kINT8QparamsBytes = 4
         D = max(
             [
                 max_int2_D,
@@ -432,7 +520,7 @@ def int_nbit_split_embedding_codegen_lookup_function_meta(
         torch._check(D > 0)
         adjusted_D = D
         if SparseType.from_int(output_dtype_int) == SparseType.INT8:
-            adjusted_D +=
+            adjusted_D += kINT8QparamsBytes
         output = dev_weights.new_empty([total_L, adjusted_D], dtype=output_dtype)
         return output
 
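The INT8 padding above is shape bookkeeping: when the output dtype is INT8, each output row carries its quantization parameters inline after the embedding bytes, so the row width grows by kINT8QparamsBytes (8 in the pooled case, 4 under PoolingMode.NONE per the hunk above). A worked instance of the arithmetic, with an assumed D:

# Assumed example values, not fbgemm constants:
D = 128                             # max embedding dim across tables
kINT8QparamsBytes = 8               # 4 when pooling_mode == PoolingMode.NONE
adjusted_D = D + kINT8QparamsBytes  # 136 columns in the INT8 output buffer
# output = dev_weights.new_empty([total_L, adjusted_D], dtype=output_dtype)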
@@ -460,7 +548,8 @@ def block_bucketize_sparse_features_meta(
     block_bucketize_pos: Optional[torch.Tensor] = None,
     keep_orig_idx: bool = False,
     total_num_blocks: Optional[torch.Tensor] = None,
-) -> Tuple[
+    keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
+) -> tuple[
     torch.Tensor,
     torch.Tensor,
     Optional[torch.Tensor],
@@ -480,8 +569,43 @@ def block_bucketize_sparse_features_meta(
     )
 
 
+def block_bucketize_sparse_features_2d_weights_meta(
+    lengths: torch.Tensor,
+    indices: torch.Tensor,
+    bucketize_pos: bool,
+    sequence: bool,
+    block_sizes: torch.Tensor,
+    my_size: int,
+    weights: torch.Tensor,
+    weights_dim: int = 1,
+    batch_size_per_feature: Optional[torch.Tensor] = None,
+    max_B: int = -1,
+    block_bucketize_pos: Optional[torch.Tensor] = None,
+    keep_orig_idx: bool = False,
+    total_num_blocks: Optional[torch.Tensor] = None,
+    keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]:
+    # Output: lengths, indices, weights, pos?, unbucketize_permute?
+    num_buckets = my_size
+    num_features = lengths.size(0)
+    num_values = indices.size(0)
+    return (
+        lengths.new_empty([num_buckets * num_features]),
+        indices.new_empty([num_values]),
+        weights.new_empty([num_values, weights_dim]),
+        indices.new_empty([num_values]) if bucketize_pos else None,
+        indices.new_empty([num_values]),
+    )
+
+
 def merge_pooled_embeddings(
-    pooled_embeddings: List[torch.Tensor],
+    pooled_embeddings: list[torch.Tensor],
     uncat_dim_size: int,
     target_device: torch.device,
     cat_dim: int = 1,
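Since the meta function above touches only sizes, it can be exercised directly on shape-only tensors. A sketch with assumed illustrative numbers (lengths being the flattened per-feature lengths tensor):

import torch

outs = block_bucketize_sparse_features_2d_weights_meta(
    lengths=torch.empty([6], dtype=torch.int64, device="meta"),
    indices=torch.empty([40], dtype=torch.int64, device="meta"),
    bucketize_pos=False,
    sequence=False,
    block_sizes=torch.empty([3], dtype=torch.int64, device="meta"),
    my_size=2,
    weights=torch.empty([40, 4], device="meta"),
    weights_dim=4,
)
assert outs[0].shape == (12,)    # my_size * lengths.size(0) bucketized lengths
assert outs[2].shape == (40, 4)  # 2D weights: one row per index value
assert outs[3] is None           # no positions since bucketize_pos=False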
@@ -512,7 +636,7 @@ def merge_pooled_embeddings(
 
 def permute_sparse_features_abstract(
     permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None
-) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+) -> tuple[Tensor, Tensor, Optional[Tensor]]:
     torch._check(lengths.dtype == indices.dtype)
     torch._check(permute.device == lengths.device)
     torch._check(permute.device == indices.device)
@@ -543,7 +667,7 @@ def segment_sum_csr_abstract(
 
 def dense_to_jagged_forward(
     dense: torch.Tensor,
-    offsets: List[torch.Tensor],
+    offsets: list[torch.Tensor],
     total_L: Optional[torch.SymInt] = None,
 ) -> torch.Tensor:
     if total_L is None:
@@ -558,9 +682,9 @@ def dense_to_jagged_forward(
 
 def dense_to_jagged(
     dense: torch.Tensor,
-    offsets: List[torch.Tensor],
+    offsets: list[torch.Tensor],
     total_L: Optional[torch.SymInt] = None,
-) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+) -> tuple[torch.Tensor, list[torch.Tensor]]:
     if total_L is None:
         total_L = torch.library.get_ctx().new_dynamic_size()
     return (dense_to_jagged_forward(dense, offsets, total_L), offsets)
@@ -569,9 +693,9 @@ def dense_to_jagged(
 def batch_index_select_dim0_abstract(
     inputs: torch.Tensor,
     indices: torch.Tensor,
-    input_num_indices: List[int],
-    input_rows: List[int],
-    input_columns: List[int],
+    input_num_indices: list[int],
+    input_rows: list[int],
+    input_columns: list[int],
     permute_output_dim_0_1: bool,
 ) -> torch.Tensor:
     """
@@ -613,11 +737,11 @@ def batch_index_select_dim0_tensor_abstract(
 def batch_index_select_dim0_forward_cuda_impl_abstract(
     inputs: torch.Tensor,
     indices: torch.Tensor,
-    input_num_indices: List[int],
-    input_rows: List[int],
-    input_columns: List[int],
+    input_num_indices: list[int],
+    input_rows: list[int],
+    input_columns: list[int],
     permute_output_dim_0_1: bool,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     num_inputs = len(input_rows)
     torch._check(len(input_num_indices) == len(input_rows))
     torch._check(len(input_num_indices) == len(input_columns))
@@ -654,7 +778,7 @@ def batch_index_select_dim0_tensor_forward_cuda_impl_abstract(
     input_rows: torch.Tensor,
     input_columns: torch.Tensor,
     permute_output_dim_0_1: bool,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     num_inputs: int = input_rows.size(0)
     torch._check(input_num_indices.size(0) == input_rows.size(0))
     torch._check(input_num_indices.size(0) == input_columns.size(0))
@@ -699,7 +823,7 @@ def keyed_jagged_index_select_dim1_abstract(
     batch_size: torch.SymInt,
     weights: Optional[torch.Tensor] = None,
     selected_lengths_sum: Optional[torch.SymInt] = None,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     """
     This meta function is used to calculate the shape of output tensors
     from the original function `fbgemm::keyed_jagged_index_select_dim1` without the actual data.
@@ -724,7 +848,7 @@ def keyed_jagged_index_select_dim1_abstract(
             torch.index_select(lengths, 0, length_indices).sum().item()
         )
 
-    ret: List[torch.Tensor] = [
+    ret: list[torch.Tensor] = [
        # pyre-ignore
        values.new_empty([selected_lengths_sum]),
        lengths.new_empty([indices.shape[0] * num_batches]),
@@ -756,17 +880,17 @@ def batch_index_select_dim0_backward_cuda_impl_abstract(
 def batch_index_select_dim0_forward_cpu_impl_abstract(
     inputs: torch.Tensor,
     indices: torch.Tensor,
-    input_num_indices: List[int],
-    input_rows: List[int],
-    input_columns: List[int],
+    input_num_indices: list[int],
+    input_rows: list[int],
+    input_columns: list[int],
     permute_output_dim_0_1: bool,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     # input lists must have the same length
     num_inputs = len(input_num_indices)
     torch._check(num_inputs == len(input_rows))
     torch._check(num_inputs == len(input_columns))
 
-    if permute_output_dim_0_1 and
+    if permute_output_dim_0_1 and guard_or_true(len(input_num_indices) > 0):
         # All num_indices must be the same if permute_output_dim_0_1 is True
         for x in input_num_indices:
             torch._check(x == input_num_indices[0])
@@ -790,7 +914,7 @@ def batch_index_select_dim0_tensor_forward_cpu_impl_abstract(
     input_rows: torch.Tensor,
     input_columns: torch.Tensor,
     permute_output_dim_0_1: bool,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     # input lists must have the same length
     num_inputs = len(input_num_indices)
     torch._check(num_inputs == len(input_rows))
@@ -829,6 +953,8 @@ def bounds_check_indices_abstract(
     b_t_map: Optional[torch.Tensor] = None,
     info_B_num_bits: int = -1,
     info_B_mask: int = -1,
+    bounds_check_version: int = 1,
+    prefetch_pipeline: bool = False,
 ) -> None:
     """
     This meta function is used to fake the bounds checking
@@ -838,8 +964,8 @@
 
 
 def group_index_select_dim0_gpu_impl_abstract(
-    inputs: List[torch.Tensor], group_size: int
-) -> List[torch.Tensor]:
+    inputs: list[torch.Tensor], group_size: int
+) -> list[torch.Tensor]:
     """
     Calculate output shapes for group_index_select_dim0_gpu_impl
     without the actual data.
@@ -869,8 +995,8 @@ def group_index_select_dim0_gpu_impl_abstract(
 
 
 def group_index_select_dim0_gpu_backward_abstract(
-    all_inputs: List[torch.Tensor], output_shape_group_ref: List[torch.SymInt]
-) -> List[torch.Tensor]:
+    all_inputs: list[torch.Tensor], output_shape_group_ref: list[torch.SymInt]
+) -> list[torch.Tensor]:
     """
     Calculate output shapes for group_index_select_dim0_gpu_backward
     without the actual data.
@@ -903,7 +1029,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
     batch_size: torch.SymInt,
     weights: Optional[torch.Tensor] = None,
     selected_lengths_sum: Optional[torch.SymInt] = None,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     num_batches = lengths.size(0) // batch_size
     torch._check(lengths.size(0) + 1 == offsets.size(0))
     # pyre-ignore
@@ -917,7 +1043,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
         selected_lengths_sum = torch.library.get_ctx().new_dynamic_size()
 
     torch._check_is_size(selected_lengths_sum)
-    vlw: List[torch.Tensor] = [
+    vlw: list[torch.Tensor] = [
         values.new_empty([selected_lengths_sum]),  # output
         lengths.new_empty([indices.shape[0] * num_batches]),  # output_lengths
     ]
@@ -960,7 +1086,7 @@ def histogram_binning_calibration_abstract(
     upper_bound: float,
     bin_ctr_in_use_after: int,
     bin_ctr_weight_value: float,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     return torch.empty_like(logit), torch.empty([logit.numel()], dtype=torch.int64)
 
 
@@ -1111,7 +1237,7 @@ def generic_histogram_binning_calibration_by_feature(
     positive_weight: float,
     bin_ctr_in_use_after: int,
     bin_ctr_weight_value: float,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     torch._check(bin_num_examples.numel() == bin_num_positives.numel())
     torch._check(
         bin_num_examples.numel() == (num_segments + 1) * (bin_boundaries.numel() + 1)
@@ -1121,6 +1247,58 @@ def generic_histogram_binning_calibration_by_feature(
     )
 
 
+def permute_multi_embedding_function_impl_abstract(
+    pooled_embs: list[Tensor],
+    permutes: Tensor,
+    in_shapes: Tensor,
+    out_shapes: Tensor,
+    out_lengths: list[int],
+    reverse: bool = False,
+) -> list[Tensor]:
+    out_dtype = pooled_embs[0].dtype
+    bs = pooled_embs[0].shape[0]
+    torch._check(permutes.shape[1] == 6, lambda: "permutes must have 6 columns")
+
+    output = []
+    for i in range(len(out_lengths)):
+        output.append(torch.empty([bs, out_lengths[i]], dtype=out_dtype))
+    return output
+
+
+def lengths_range_abstract(
+    lengths: Tensor,
+    output_shape: Optional[Sequence[int]] = None,
+) -> Tensor:
+    torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
+    output_size = 0
+    if output_shape is not None:
+        output_size = math.prod(output_shape)
+    else:
+        ctx = torch.library.get_ctx()
+        output_size = ctx.new_dynamic_size()
+    return lengths.new_empty([output_size], dtype=lengths.dtype)
+
+
+def all_to_one_device(
+    input_tensors: list[Tensor],
+    target_device: torch.device,
+) -> list[Tensor]:
+    return [
+        torch.empty_like(input_tensor, device=torch.device("meta"))
+        for input_tensor in input_tensors
+    ]
+
+
+def sum_reduce_to_one(
+    input_tensors: list[Tensor],
+    target_device: torch.device,
+) -> Tensor:
+    torch._check(len(input_tensors) > 0, lambda: "reducing no tensor is undefined")
+    # All tensors should have the same shape
+    first_tensor = input_tensors[0]
+    return torch.empty_like(first_tensor, device=torch.device("meta"))
+
+
 def _setup() -> None:
     # pyre-ignore[16]
     _setup.done = getattr(_setup, "done", False)
@@ -1153,6 +1331,8 @@ def _setup() -> None:
         )
 
         impl_abstract("fbgemm::permute_2D_sparse_data", permute_2D_sparse_data_meta)
+        impl_abstract("fbgemm::get_source_mask", get_source_mask_meta)
+        impl_abstract("fbgemm::repeat_arange", repeat_arange_meta)
         impl_abstract(
             "fbgemm::permute_2D_sparse_data_input1D",
             permute_2D_sparse_data_input1D_meta,
@@ -1185,12 +1365,18 @@ def _setup() -> None:
             "fbgemm::block_bucketize_sparse_features",
             block_bucketize_sparse_features_meta,
        )
+        impl_abstract(
+            "fbgemm::block_bucketize_sparse_features_2d_weights",
+            block_bucketize_sparse_features_2d_weights_meta,
+        )
         impl_abstract("fbgemm::merge_pooled_embeddings", merge_pooled_embeddings)
         impl_abstract(
             "fbgemm::permute_sparse_features", permute_sparse_features_abstract
         )
         impl_abstract("fbgemm::segment_sum_csr", segment_sum_csr_abstract)
         impl_abstract("fbgemm::dense_to_jagged_forward", dense_to_jagged_forward)
+        impl_abstract("fbgemm::all_to_one_device", all_to_one_device)
+        impl_abstract("fbgemm::sum_reduce_to_one", sum_reduce_to_one)
         impl_abstract(
             "fbgemm::batch_index_select_dim0", batch_index_select_dim0_abstract
         )
@@ -1258,6 +1444,14 @@ def _setup() -> None:
             "fbgemm::generic_histogram_binning_calibration_by_feature",
             generic_histogram_binning_calibration_by_feature,
         )
+        impl_abstract(
+            "fbgemm::lengths_range",
+            lengths_range_abstract,
+        )
+        impl_abstract(
+            "fbgemm::permute_multi_embedding_function",
+            permute_multi_embedding_function_impl_abstract,
+        )
         impl_abstract(
             "fbgemm::FloatToHFP8Quantized",
             float_to_hfp8_quantized,
@@ -1302,29 +1496,3 @@ def _setup() -> None:
 
 
 _setup()
-
-
-@torch.library.register_fake("fbgemm::lengths_range")
-def lengths_range_abstract(
-    lengths: Tensor,
-    output_shape: Optional[Sequence[int]] = None,
-) -> Tensor:
-    torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
-    output_size = 0
-    if output_shape is not None:
-        output_size = math.prod(output_shape)
-    else:
-        ctx = torch.library.get_ctx()
-        output_size = ctx.new_dynamic_size()
-    return lengths.new_empty([output_size], dtype=lengths.dtype)
-
-
-@torch.library.register_fake("fbgemm::all_to_one_device")
-def all_to_one_device(
-    input_tensors: List[Tensor],
-    target_device: torch.device,
-) -> List[Tensor]:
-    return [
-        torch.empty_like(input_tensor, device=torch.device("meta"))
-        for input_tensor in input_tensors
-    ]
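The tail of the file changes registration style rather than behavior: the module-scope @torch.library.register_fake definitions of lengths_range_abstract and all_to_one_device move above _setup() and are registered through the same impl_abstract path as the other meta functions. A minimal sketch of the shared pattern (the op name my_op is illustrative; impl_abstract stands for torch.library.impl_abstract or the register_fake-based alias chosen near the top of the file):

import torch

def my_op_meta(lengths: torch.Tensor) -> torch.Tensor:
    # Shape-only implementation: the data-dependent output size becomes an
    # unbacked SymInt while tracing with fake tensors.
    ctx = torch.library.get_ctx()
    return lengths.new_empty([ctx.new_dynamic_size()])

# Decorator style (previously used at module scope):
#     @torch.library.register_fake("fbgemm::my_op")
#     def my_op_meta(...): ...
#
# Call style (now used inside _setup()):
#     impl_abstract("fbgemm::my_op", my_op_meta)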
fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py
CHANGED
@@ -4,6 +4,8 @@
 ## Template Source: training/python/__init__.template
 ################################################################################
 
+__template_source_file__ = "training/python/__init__.template"
+
 #!/usr/bin/env python3
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -150,6 +152,30 @@ except:
         DeprecationWarning,
     )
 
+try:
+    # Import is placed under a try-except bc the op is experimental and can be
+    # removed/updated in the future
+    import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam_ssd as lookup_adam_ssd  # noqa: F401
+except:
+    warnings.warn(
+        f"""\033[93m
+Failed to import: fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam_ssd
+\033[0m""",
+        DeprecationWarning,
+    )
+
+try:
+    # Import is placed under a try-except bc the op is experimental and can be
+    # removed/updated in the future
+    import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_adam_ssd as lookup_partial_rowwise_adam_ssd  # noqa: F401
+except:
+    warnings.warn(
+        f"""\033[93m
+Failed to import: fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_adam_ssd
+\033[0m""",
+        DeprecationWarning,
+    )
+
 try:
     # Import is placed under a try-except bc the op is experimental and can be
     # removed/updated in the future