fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +112 -19
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +118 -0
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +190 -54
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
- fbgemm_gpu/split_embedding_configs.py +134 -37
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +6 -1
- fbgemm_gpu/tbe/bench/bench_config.py +14 -3
- fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
- fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
- fbgemm_gpu/tbe/ssd/common.py +1 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +1292 -267
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +15 -15
- fbgemm_gpu/tbe_input_multiplexer.py +10 -11
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +6 -2
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +1 -0
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/sparse_ops.py
CHANGED
@@ -7,10 +7,12 @@
 # pyre-strict
 
 import math
-from
+from collections.abc import Sequence
+from typing import Callable, Optional
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
 from fbgemm_gpu.utils.loader import load_torch_module
@@ -48,8 +50,7 @@ except Exception:
 
 import torch.utils._pytree as pytree
 from torch import SymInt, Tensor
-from torch.fx.experimental.symbolic_shapes import
-
+from torch.fx.experimental.symbolic_shapes import guard_or_true
 
 if hasattr(torch.library, "register_fake"):
     # pyre-ignore[9]
@@ -74,7 +75,7 @@ def permute_2D_sparse_data_input1D_meta(
     stride: int,
     weights: Optional[Tensor] = None,
     permuted_lengths_sum: Optional[int] = None,
-) ->
+) -> tuple[Tensor, Tensor, Optional[Tensor]]:
     torch._check(
         lengths.dim() == 1, lambda: f"expected lengths.dim() == 1, got {lengths.dim()}"
     )
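Note: most hunks in this file follow the same pattern as the one above. The deleted annotations are truncated to fragments like `) ->` in this rendering; given the replacements, they presumably used `typing.Tuple`/`typing.List` style generics, while the new code uses the PEP 585 builtin generics available since Python 3.9, which need no `typing` import. A minimal sketch of the before/after shape, with the "before" form assumed rather than recovered from the diff:

from typing import Optional, Tuple
from torch import Tensor

# before (assumed): typing-module generics
def meta_old(lengths: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]: ...

# after: builtin generics, as in the updated signatures in this diff
def meta_new(lengths: Tensor) -> tuple[Tensor, Tensor, Optional[Tensor]]: ...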
@@ -111,7 +112,7 @@ def permute_2D_sparse_data_input1D_backward(
     grad_lengths: torch.Tensor,
     grad_values: torch.Tensor,
     grad_weights: torch.Tensor,
-) ->
+) -> tuple[None, Tensor, Tensor, None, Tensor, None]:
     inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute)
     permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = (
         torch.ops.fbgemm.permute_2D_sparse_data_input1D(
@@ -139,7 +140,7 @@ def permute_2D_sparse_data_meta(
     values: Tensor,
     weights: Optional[Tensor] = None,
     permuted_lengths_sum: Optional[int] = None,
-) ->
+) -> tuple[Tensor, Tensor, Optional[Tensor]]:
     torch._check(
         lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}"
     )
@@ -166,6 +167,89 @@ def invert_permute_abstract(permute: Tensor) -> Tensor:
     return torch.empty_like(permute)
 
 
+def get_source_mask_meta(
+    num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
+) -> Tensor:
+    if output_size is None:
+        ctx = torch.library.get_ctx()
+        output_size = ctx.new_dynamic_size()
+    return torch.empty([output_size], dtype=torch.bool)
+
+
+def get_source_mask(
+    num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
+) -> Tensor:
+    """
+    Generate a boolean mask indicating which elements are from sources vs targets.
+
+    This is a Python wrapper that computes output_size when not provided,
+    enabling the operation to work with meta tensors for compilation.
+
+    Args:
+        num_sources: 1D tensor of source counts per batch element
+        num_targets: 1D tensor of target counts per batch element
+        output_size: Optional pre-computed output size.
+
+    Returns:
+        A 1D boolean tensor where True indicates source elements and False
+        indicates target elements
+
+    Example:
+        >>> num_sources = torch.tensor([2, 3])
+        >>> num_targets = torch.tensor([1, 2])
+        >>> get_source_mask(num_sources, num_targets)
+        tensor([True, True, False, True, True, True, False, False])
+    """
+    # Compute output_size if not provided and tensors are regular (not meta/fake)
+    if output_size is None:
+        combined = num_sources + num_targets
+        output_size = int(combined.sum().item())
+
+    return torch.ops.fbgemm.get_source_mask(num_sources, num_targets, output_size)
+
+
+def repeat_arange_meta(lengths: Tensor) -> Tensor:
+    """Meta implementation for repeat_arange."""
+    # Output size is data-dependent (sum of lengths).
+    # For FakeTensors (used in torch.compile), we use dynamic sizing.
+    # For actual meta tensors, we cannot determine the size so return empty.
+    if lengths.device.type == "meta":
+        # Actual meta tensors: return a zero-sized tensor as placeholder
+        # since we cannot compute the data-dependent output size
+        return torch.empty([0], dtype=lengths.dtype, device=lengths.device)
+    else:
+        # FakeTensor context: use dynamic sizing for proper shape tracking
+        ctx = torch.library.get_ctx()
+        output_size = ctx.new_dynamic_size()
+        return torch.empty([output_size], dtype=lengths.dtype, device=lengths.device)
+
+
+def repeat_arange(lengths: Tensor) -> Tensor:
+    """
+    Creates a concatenated tensor of aranges based on a lengths tensor.
+
+    This is a high-performance CUDA kernel that replaces the inefficient PyTorch
+    implementation which uses 4+ separate kernels (cumsum, arange, repeat_interleave, sub).
+
+    Args:
+        lengths: 1D tensor of lengths for each arange sequence
+
+    Returns:
+        A 1D tensor containing concatenated arange sequences
+
+    Example:
+        >>> lengths = torch.tensor([3, 5, 2])
+        >>> repeat_arange(lengths)
+        tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])
+
+    Performance:
+        - PyTorch implementation: 4+ kernel launches + intermediate allocations
+        - CUDA implementation: 1 fused kernel, no intermediate allocations
+        - Typical speedup: 3-5x on realistic workloads
+    """
+    return torch.ops.fbgemm.repeat_arange(lengths)
+
+
 # pyre-ignore
 def permute_2D_sparse_data_setup_context(ctx, inputs, output):
     permute, lengths, values, weights, permuted_lengths_sum = inputs
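Note: the two ops introduced above ship as fused kernels; their docstrings give the intended semantics. As a rough orientation only, here is a hedged pure-PyTorch sketch of the same logic, reconstructed from the docstrings and their doctest examples (the `*_reference` names are illustrative, not part of the package):

import torch

def repeat_arange_reference(lengths: torch.Tensor) -> torch.Tensor:
    # The multi-kernel pattern the fused op replaces, per its docstring:
    # cumsum -> arange -> repeat_interleave -> sub.
    offsets = torch.cumsum(lengths, dim=0) - lengths  # start index of each sequence
    total = int(lengths.sum().item())
    positions = torch.arange(total, dtype=lengths.dtype)
    return positions - torch.repeat_interleave(offsets, lengths)

def get_source_mask_reference(num_sources: torch.Tensor, num_targets: torch.Tensor) -> torch.Tensor:
    # Per batch element i: num_sources[i] True values followed by
    # num_targets[i] False values, concatenated across the batch.
    parts = []
    for s, t in zip(num_sources.tolist(), num_targets.tolist()):
        parts.append(torch.ones(s, dtype=torch.bool))
        parts.append(torch.zeros(t, dtype=torch.bool))
    return torch.cat(parts)

print(repeat_arange_reference(torch.tensor([3, 5, 2])))
# tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])
print(get_source_mask_reference(torch.tensor([2, 3]), torch.tensor([1, 2])))
# tensor([ True,  True, False,  True,  True,  True, False, False])

Both sketches reproduce the doctest outputs shown in the diff.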
@@ -197,7 +281,7 @@ def permute_1D_sparse_data_meta(
     values: Tensor,
     weights: Optional[Tensor] = None,
     permuted_lengths_sum: Optional[int] = None,
-) ->
+) -> tuple[Tensor, Tensor, Optional[Tensor]]:
     indices = values
     permuted_lengths_size = permute.numel()
     permuted_lengths = lengths.new_empty([permuted_lengths_size])
@@ -218,7 +302,7 @@
 
 def masked_select_jagged_1d(
     values: Tensor, lengths: Tensor, mask: Tensor
-) ->
+) -> tuple[Tensor, Tensor]:
     torch._check(values.dim() == 1)
     torch._check(lengths.dim() == 1)
     torch._check(values.device == lengths.device)
@@ -231,11 +315,11 @@ def masked_select_jagged_1d(
 
 
 def tbe_input_combine_abstract(
-    indices_list:
-    offsets_list:
-    per_sample_weights:
+    indices_list: list[Tensor],
+    offsets_list: list[Tensor],
+    per_sample_weights: list[Tensor],
     include_last_offsets: Tensor,
-) ->
+) -> tuple[Tensor, Tensor, Tensor]:
     torch._check(len(indices_list) > 0)
     torch._check(len(indices_list) == len(offsets_list))
     torch._check(len(indices_list) == len(per_sample_weights))
@@ -250,7 +334,7 @@ def tbe_input_combine_abstract(
         torch._check(index.is_contiguous())
         torch._check(offset.is_contiguous())
         total_indices = total_indices + index.numel()
-        if
+        if guard_or_true(weight.numel() > 0):
            torch._check(weight.dim() == 1)
            torch._check(weight.numel() == index.numel())
            torch._check(weight.is_contiguous())
@@ -268,10 +352,10 @@
 
 
 def tbe_input_combine_with_length_abstract(
-    indices_list:
-    offsets_list:
-    per_sample_weights:
-) ->
+    indices_list: list[Tensor],
+    offsets_list: list[Tensor],
+    per_sample_weights: list[Tensor],
+) -> tuple[Tensor, Tensor, Tensor]:
     torch._check(len(indices_list) > 0)
     torch._check(len(indices_list) == len(offsets_list))
     torch._check(len(indices_list) == len(per_sample_weights))
@@ -287,7 +371,7 @@ def tbe_input_combine_with_length_abstract(
         torch._check(offset.is_contiguous())
         total_indices = total_indices + index.numel()
         total_offsets = total_offsets + offset.numel()
-        if
+        if guard_or_true(weight.numel() > 0):
            torch._check(weight.dim() == 1)
            torch._check(weight.numel() == index.numel())
            torch._check(weight.is_contiguous())
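Note: both `if` rewrites above (and the one in `batch_index_select_dim0_forward_cpu_impl_abstract` further down) route the data-dependent condition through `guard_or_true`. A hedged illustration of why, based on the public `torch.fx.experimental.symbolic_shapes` API; exact semantics may evolve across PyTorch releases:

from torch.fx.experimental.symbolic_shapes import guard_or_true

# With plain Python bools the call is a pass-through:
assert guard_or_true(True) is True
assert guard_or_true(False) is False

# Under torch.compile, `weight.numel() > 0` may be an unbacked SymBool whose
# truth value is unknown at trace time. guard_or_true then assumes True
# instead of raising a data-dependent guard error, so the weight checks in
# the branch still run during fake-tensor tracing.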
@@ -339,7 +423,7 @@ def expand_into_jagged_permute_meta(
     permute: Tensor,
     input_offsets: Tensor,
     output_offsets: Tensor,
-    output_size:
+    output_size: tuple[int, ...],
 ) -> Tensor:
     torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0")
     torch._check(
@@ -465,7 +549,7 @@ def block_bucketize_sparse_features_meta(
     keep_orig_idx: bool = False,
     total_num_blocks: Optional[torch.Tensor] = None,
     keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
-) ->
+) -> tuple[
     torch.Tensor,
     torch.Tensor,
     Optional[torch.Tensor],
@@ -485,8 +569,43 @@ def block_bucketize_sparse_features_meta(
     )
 
 
+def block_bucketize_sparse_features_2d_weights_meta(
+    lengths: torch.Tensor,
+    indices: torch.Tensor,
+    bucketize_pos: bool,
+    sequence: bool,
+    block_sizes: torch.Tensor,
+    my_size: int,
+    weights: torch.Tensor,
+    weights_dim: int = 1,
+    batch_size_per_feature: Optional[torch.Tensor] = None,
+    max_B: int = -1,
+    block_bucketize_pos: Optional[torch.Tensor] = None,
+    keep_orig_idx: bool = False,
+    total_num_blocks: Optional[torch.Tensor] = None,
+    keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]:
+    # Output: lengths, indices, weights, pos?, unbucketize_permute?
+    num_buckets = my_size
+    num_features = lengths.size(0)
+    num_values = indices.size(0)
+    return (
+        lengths.new_empty([num_buckets * num_features]),
+        indices.new_empty([num_values]),
+        weights.new_empty([num_values, weights_dim]),
+        indices.new_empty([num_values]) if bucketize_pos else None,
+        indices.new_empty([num_values]),
+    )
+
+
 def merge_pooled_embeddings(
-    pooled_embeddings:
+    pooled_embeddings: list[torch.Tensor],
     uncat_dim_size: int,
     target_device: torch.device,
     cat_dim: int = 1,
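Note: the new meta function above models output shapes only; no data flows through it. A small sketch of the shape contract it encodes, on meta tensors (all sizes below are illustrative, not taken from the package):

import torch

num_features, nnz, weights_dim, my_size = 8, 100, 4, 2  # illustrative sizes
lengths = torch.empty([num_features], device="meta")
indices = torch.empty([nnz], device="meta")
weights = torch.empty([nnz, weights_dim], device="meta")

# Mirrors the returns of block_bucketize_sparse_features_2d_weights_meta:
out_lengths = lengths.new_empty([my_size * lengths.size(0)])     # bucketized lengths
out_weights = weights.new_empty([indices.size(0), weights_dim])  # 2D weights preserved
print(out_lengths.shape, out_weights.shape)
# torch.Size([16]) torch.Size([100, 4])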
@@ -517,7 +636,7 @@ def merge_pooled_embeddings(
 
 def permute_sparse_features_abstract(
     permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None
-) ->
+) -> tuple[Tensor, Tensor, Optional[Tensor]]:
     torch._check(lengths.dtype == indices.dtype)
     torch._check(permute.device == lengths.device)
     torch._check(permute.device == indices.device)
@@ -548,7 +667,7 @@ def segment_sum_csr_abstract(
 
 def dense_to_jagged_forward(
     dense: torch.Tensor,
-    offsets:
+    offsets: list[torch.Tensor],
     total_L: Optional[torch.SymInt] = None,
 ) -> torch.Tensor:
     if total_L is None:
@@ -563,9 +682,9 @@ def dense_to_jagged_forward(
 
 def dense_to_jagged(
     dense: torch.Tensor,
-    offsets:
+    offsets: list[torch.Tensor],
     total_L: Optional[torch.SymInt] = None,
-) ->
+) -> tuple[torch.Tensor, list[torch.Tensor]]:
     if total_L is None:
         total_L = torch.library.get_ctx().new_dynamic_size()
     return (dense_to_jagged_forward(dense, offsets, total_L), offsets)
@@ -574,9 +693,9 @@ def dense_to_jagged(
 def batch_index_select_dim0_abstract(
     inputs: torch.Tensor,
     indices: torch.Tensor,
-    input_num_indices:
-    input_rows:
-    input_columns:
+    input_num_indices: list[int],
+    input_rows: list[int],
+    input_columns: list[int],
     permute_output_dim_0_1: bool,
 ) -> torch.Tensor:
     """
@@ -618,11 +737,11 @@ def batch_index_select_dim0_tensor_abstract(
 def batch_index_select_dim0_forward_cuda_impl_abstract(
     inputs: torch.Tensor,
     indices: torch.Tensor,
-    input_num_indices:
-    input_rows:
-    input_columns:
+    input_num_indices: list[int],
+    input_rows: list[int],
+    input_columns: list[int],
     permute_output_dim_0_1: bool,
-) ->
+) -> list[torch.Tensor]:
     num_inputs = len(input_rows)
     torch._check(len(input_num_indices) == len(input_rows))
     torch._check(len(input_num_indices) == len(input_columns))
@@ -659,7 +778,7 @@ def batch_index_select_dim0_tensor_forward_cuda_impl_abstract(
     input_rows: torch.Tensor,
     input_columns: torch.Tensor,
     permute_output_dim_0_1: bool,
-) ->
+) -> list[torch.Tensor]:
     num_inputs: int = input_rows.size(0)
     torch._check(input_num_indices.size(0) == input_rows.size(0))
     torch._check(input_num_indices.size(0) == input_columns.size(0))
@@ -704,7 +823,7 @@ def keyed_jagged_index_select_dim1_abstract(
     batch_size: torch.SymInt,
     weights: Optional[torch.Tensor] = None,
     selected_lengths_sum: Optional[torch.SymInt] = None,
-) ->
+) -> list[torch.Tensor]:
     """
     This meta function is used to calculate the shape of output tensors
     from the original function `fbgemm::keyed_jagged_index_select_dim1` without the actual data.
@@ -729,7 +848,7 @@ def keyed_jagged_index_select_dim1_abstract(
             torch.index_select(lengths, 0, length_indices).sum().item()
         )
 
-    ret:
+    ret: list[torch.Tensor] = [
         # pyre-ignore
         values.new_empty([selected_lengths_sum]),
         lengths.new_empty([indices.shape[0] * num_batches]),
@@ -761,17 +880,17 @@ def batch_index_select_dim0_backward_cuda_impl_abstract(
 def batch_index_select_dim0_forward_cpu_impl_abstract(
     inputs: torch.Tensor,
     indices: torch.Tensor,
-    input_num_indices:
-    input_rows:
-    input_columns:
+    input_num_indices: list[int],
+    input_rows: list[int],
+    input_columns: list[int],
     permute_output_dim_0_1: bool,
-) ->
+) -> list[torch.Tensor]:
     # input lists must have the same length
     num_inputs = len(input_num_indices)
     torch._check(num_inputs == len(input_rows))
     torch._check(num_inputs == len(input_columns))
 
-    if permute_output_dim_0_1 and
+    if permute_output_dim_0_1 and guard_or_true(len(input_num_indices) > 0):
         # All num_indices must be the same if permute_output_dim_0_1 is True
         for x in input_num_indices:
             torch._check(x == input_num_indices[0])
@@ -795,7 +914,7 @@ def batch_index_select_dim0_tensor_forward_cpu_impl_abstract(
     input_rows: torch.Tensor,
     input_columns: torch.Tensor,
     permute_output_dim_0_1: bool,
-) ->
+) -> list[torch.Tensor]:
     # input lists must have the same length
     num_inputs = len(input_num_indices)
     torch._check(num_inputs == len(input_rows))
@@ -845,8 +964,8 @@ def bounds_check_indices_abstract(
 
 
 def group_index_select_dim0_gpu_impl_abstract(
-    inputs:
-) ->
+    inputs: list[torch.Tensor], group_size: int
+) -> list[torch.Tensor]:
     """
     Calculate output shapes for group_index_select_dim0_gpu_impl
    without the actual data.
@@ -876,8 +995,8 @@ def group_index_select_dim0_gpu_impl_abstract(
 
 
 def group_index_select_dim0_gpu_backward_abstract(
-    all_inputs:
-) ->
+    all_inputs: list[torch.Tensor], output_shape_group_ref: list[torch.SymInt]
+) -> list[torch.Tensor]:
     """
     Calculate output shapes for group_index_select_dim0_gpu_backward
     without the actual data.
@@ -910,7 +1029,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
     batch_size: torch.SymInt,
     weights: Optional[torch.Tensor] = None,
     selected_lengths_sum: Optional[torch.SymInt] = None,
-) ->
+) -> list[torch.Tensor]:
     num_batches = lengths.size(0) // batch_size
     torch._check(lengths.size(0) + 1 == offsets.size(0))
     # pyre-ignore
@@ -924,7 +1043,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
         selected_lengths_sum = torch.library.get_ctx().new_dynamic_size()
 
     torch._check_is_size(selected_lengths_sum)
-    vlw:
+    vlw: list[torch.Tensor] = [
         values.new_empty([selected_lengths_sum]),  # output
         lengths.new_empty([indices.shape[0] * num_batches]),  # output_lengths
     ]
@@ -967,7 +1086,7 @@ def histogram_binning_calibration_abstract(
     upper_bound: float,
     bin_ctr_in_use_after: int,
     bin_ctr_weight_value: float,
-) ->
+) -> tuple[Tensor, Tensor]:
     return torch.empty_like(logit), torch.empty([logit.numel()], dtype=torch.int64)
 
 
@@ -1118,7 +1237,7 @@ def generic_histogram_binning_calibration_by_feature(
     positive_weight: float,
     bin_ctr_in_use_after: int,
     bin_ctr_weight_value: float,
-) ->
+) -> tuple[Tensor, Tensor]:
     torch._check(bin_num_examples.numel() == bin_num_positives.numel())
     torch._check(
         bin_num_examples.numel() == (num_segments + 1) * (bin_boundaries.numel() + 1)
@@ -1129,13 +1248,13 @@ def generic_histogram_binning_calibration_by_feature(
 
 
 def permute_multi_embedding_function_impl_abstract(
-    pooled_embs:
+    pooled_embs: list[Tensor],
     permutes: Tensor,
     in_shapes: Tensor,
     out_shapes: Tensor,
-    out_lengths:
+    out_lengths: list[int],
     reverse: bool = False,
-) ->
+) -> list[Tensor]:
     out_dtype = pooled_embs[0].dtype
     bs = pooled_embs[0].shape[0]
     torch._check(permutes.shape[1] == 6, lambda: "permutes must have 6 columns")
@@ -1161,15 +1280,25 @@ def lengths_range_abstract(
 
 
 def all_to_one_device(
-    input_tensors:
+    input_tensors: list[Tensor],
     target_device: torch.device,
-) ->
+) -> list[Tensor]:
     return [
         torch.empty_like(input_tensor, device=torch.device("meta"))
         for input_tensor in input_tensors
     ]
 
 
+def sum_reduce_to_one(
+    input_tensors: list[Tensor],
+    target_device: torch.device,
+) -> Tensor:
+    torch._check(len(input_tensors) > 0, lambda: "reducing no tensor is undefined")
+    # All tensors should have the same shape
+    first_tensor = input_tensors[0]
+    return torch.empty_like(first_tensor, device=torch.device("meta"))
+
+
 def _setup() -> None:
     # pyre-ignore[16]
     _setup.done = getattr(_setup, "done", False)
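Note: for `sum_reduce_to_one`, the meta above encodes only that all inputs share one shape and that a single tensor of that shape comes back. A hedged eager-mode reading of the op's name (illustrative reference, not the shipped kernel):

import torch

def sum_reduce_to_one_reference(input_tensors, target_device):
    # Elementwise-sum a list of same-shaped tensors onto one device.
    assert len(input_tensors) > 0, "reducing no tensor is undefined"
    return torch.stack([t.to(target_device) for t in input_tensors]).sum(dim=0)

out = sum_reduce_to_one_reference(
    [torch.ones(2, 3), torch.ones(2, 3)], torch.device("cpu")
)
print(out.shape)  # torch.Size([2, 3]) -- matches empty_like(first_tensor) in the meta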
@@ -1202,6 +1331,8 @@ def _setup() -> None:
     )
 
     impl_abstract("fbgemm::permute_2D_sparse_data", permute_2D_sparse_data_meta)
+    impl_abstract("fbgemm::get_source_mask", get_source_mask_meta)
+    impl_abstract("fbgemm::repeat_arange", repeat_arange_meta)
     impl_abstract(
         "fbgemm::permute_2D_sparse_data_input1D",
         permute_2D_sparse_data_input1D_meta,
@@ -1234,6 +1365,10 @@ def _setup() -> None:
         "fbgemm::block_bucketize_sparse_features",
         block_bucketize_sparse_features_meta,
     )
+    impl_abstract(
+        "fbgemm::block_bucketize_sparse_features_2d_weights",
+        block_bucketize_sparse_features_2d_weights_meta,
+    )
     impl_abstract("fbgemm::merge_pooled_embeddings", merge_pooled_embeddings)
     impl_abstract(
         "fbgemm::permute_sparse_features", permute_sparse_features_abstract
@@ -1241,6 +1376,7 @@ def _setup() -> None:
     impl_abstract("fbgemm::segment_sum_csr", segment_sum_csr_abstract)
     impl_abstract("fbgemm::dense_to_jagged_forward", dense_to_jagged_forward)
     impl_abstract("fbgemm::all_to_one_device", all_to_one_device)
+    impl_abstract("fbgemm::sum_reduce_to_one", sum_reduce_to_one)
     impl_abstract(
         "fbgemm::batch_index_select_dim0", batch_index_select_dim0_abstract
     )
fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py
CHANGED
@@ -152,6 +152,18 @@ except:
         DeprecationWarning,
     )
 
+try:
+    # Import is placed under a try-except bc the op is experimental and can be
+    # removed/updated in the future
+    import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam_ssd as lookup_adam_ssd  # noqa: F401
+except:
+    warnings.warn(
+        f"""\033[93m
+        Failed to import: fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam_ssd
+        \033[0m""",
+        DeprecationWarning,
+    )
+
 try:
     # Import is placed under a try-except bc the op is experimental and can be
     # removed/updated in the future
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py
CHANGED
@@ -56,14 +56,15 @@ def invoke(
         "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
         "lxu_cache_locations": common_args.lxu_cache_locations,
         "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
     }
 
     dict_aux_int: Dict[str, int] = {
-        "iter": iter,
-        "info_B_num_bits": common_args.info_B_num_bits,
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
         "info_B_mask": common_args.info_B_mask,
     }
-
+
     dict_aux_float: Dict[str, float] = {
         "gwd_lower_bound": gwd_lower_bound,
     }
@@ -81,7 +82,7 @@ def invoke(
 
     # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
     dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
-
+
 
     # optimizer_args # if optimizer == none
     dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -132,6 +133,11 @@ def invoke(
         "Please check the frontend and backend version. "
     )
     aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
 
     aux_int: List[int] = []
     assert "iter" in dict_aux_int, (
@@ -204,7 +210,7 @@ def invoke(
     # ['momentum1', 'learning_rate_tensor', 'optim_float']
     optim_float: List[float] = []
     optim_float.append(dict_optim_float["eps"])
-    # optim_bool
+    # optim_bool
 
     return torch.ops.fbgemm.split_embedding_codegen_lookup_adagrad_function_pt2(
         # common_args
@@ -226,6 +232,7 @@ def invoke(
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
+        vbe_output=vbe_metadata.vbe_output,
         # aux_tensor
         aux_tensor=aux_tensor,
         # aux_int
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py
CHANGED
@@ -67,14 +67,15 @@ def invoke(
         "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
         "lxu_cache_locations": common_args.lxu_cache_locations,
         "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
     }
 
     dict_aux_int: Dict[str, int] = {
-        "iter": iter,
-        "info_B_num_bits": common_args.info_B_num_bits,
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
         "info_B_mask": common_args.info_B_mask,
     }
-
+
     dict_aux_float: Dict[str, float] = {
         "gwd_lower_bound": gwd_lower_bound,
     }
@@ -92,7 +93,7 @@ def invoke(
 
     # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
     dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
-
+
 
     # optimizer_args # if optimizer == none
     dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -125,13 +126,13 @@ def invoke(
         momentum2.placements,
         momentum2.offsets,
     ] if momentum2 is not None else None
-
+
     if optimizer_args.use_rowwise_bias_correction and row_counter is not None:
         row_counter_host = None  # not supported on CPU
         row_counter_dev = row_counter.dev
         row_counter_uvm = row_counter.uvm
         row_counter_offsets = row_counter.offsets
-        row_counter_placements = row_counter.placements
+        row_counter_placements = row_counter.placements
     elif optimizer_args.use_rowwise_bias_correction:
         assert False, "`use_rowwise_bias_correction` is set, `row_counter` cannot be None"
     else:
@@ -173,6 +174,11 @@ def invoke(
         "Please check the frontend and backend version. "
     )
     aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
 
     aux_int: List[int] = []
     assert "iter" in dict_aux_int, (
@@ -271,7 +277,7 @@ def invoke(
     optim_float.append(dict_optim_float["weight_decay"])
     # optim_bool
     optim_bool: List[bool] = []
-    optim_bool.append(dict_optim_bool["use_rowwise_bias_correction"])
+    optim_bool.append(dict_optim_bool["use_rowwise_bias_correction"])
 
     return torch.ops.fbgemm.split_embedding_codegen_lookup_adam_function_pt2(
         # common_args
@@ -293,6 +299,7 @@ def invoke(
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
+        vbe_output=vbe_metadata.vbe_output,
         # aux_tensor
         aux_tensor=aux_tensor,
         # aux_int