fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/sparse_ops.py
ADDED
|
@@ -0,0 +1,1455 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
from collections.abc import Sequence
|
|
11
|
+
from typing import Callable, Optional
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
|
|
15
|
+
from fbgemm_gpu.split_embedding_configs import SparseType
|
|
16
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
|
|
17
|
+
from fbgemm_gpu.utils.loader import load_torch_module
|
|
18
|
+
|
|
19
|
+
# Load the native fbgemm operator libraries. Code in this module calls ops under
# ``torch.ops.fbgemm`` (e.g. ``invert_permute``), which must be registered first.
try:
    # Presumably the open-source package's import path registers the native ops
    # itself, so no explicit loading is needed here -- TODO confirm.
    # pyre-ignore
    from fbgemm_gpu import open_source  # noqa: F401
except Exception:
    # Fallback path: load each operator library by its build-target path.
    # NOTE(review): the "//deeplearning/..." paths look like internal build targets
    # and are unlikely to resolve outside that build system -- confirm.
    load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings")
    load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")

    # ROCm builds of PyTorch set ``torch.version.hip``; they get the HIP variant of
    # the embedding ops, everything else gets the default variant.
    if torch.version.hip:
        torch.ops.load_library(
            "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip"
        )

    else:
        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops")

    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine")

    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu"
    )
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu"
    )
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
    )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
import torch.utils._pytree as pytree
|
|
51
|
+
from torch import SymInt, Tensor
|
|
52
|
+
from torch.fx.experimental.symbolic_shapes import guard_or_true
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Resolve the decorator used to register fake (abstract) kernels, preferring the
# newest ``torch.library`` spelling available in the installed PyTorch.
if hasattr(torch.library, "register_fake"):
    # pyre-ignore[9]
    impl_abstract = torch.library.register_fake
elif hasattr(torch.library, "impl_abstract"):
    # Older name of the fake-kernel registration API.
    impl_abstract = torch.library.impl_abstract
else:
    # pyre-ignore
    def impl_abstract(schema: str) -> Callable[[Callable], Callable]:
        """Fallback when neither API exists: a decorator factory that does nothing.

        ``schema`` is accepted only for signature compatibility; the returned
        decorator hands back the decorated function unchanged.
        """
        # pyre-ignore
        return lambda fn: fn
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def permute_2D_sparse_data_input1D_meta(
|
|
72
|
+
permute: Tensor,
|
|
73
|
+
lengths: Tensor,
|
|
74
|
+
values: Tensor,
|
|
75
|
+
stride: int,
|
|
76
|
+
weights: Optional[Tensor] = None,
|
|
77
|
+
permuted_lengths_sum: Optional[int] = None,
|
|
78
|
+
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
|
|
79
|
+
torch._check(
|
|
80
|
+
lengths.dim() == 1, lambda: f"expected lengths.dim() == 1, got {lengths.dim()}"
|
|
81
|
+
)
|
|
82
|
+
T = permute.numel()
|
|
83
|
+
B = stride
|
|
84
|
+
indices = values
|
|
85
|
+
permuted_lengths = lengths.new_empty([T * B])
|
|
86
|
+
permuted_indices_size = 0
|
|
87
|
+
if permuted_lengths_sum is not None:
|
|
88
|
+
permuted_indices_size = permuted_lengths_sum
|
|
89
|
+
else:
|
|
90
|
+
ctx = torch.library.get_ctx()
|
|
91
|
+
permuted_indices_size = ctx.new_dynamic_size()
|
|
92
|
+
# pyre-fixme
|
|
93
|
+
permuted_indices = indices.new_empty(permuted_indices_size)
|
|
94
|
+
permuted_weights = None
|
|
95
|
+
if weights is not None:
|
|
96
|
+
# pyre-fixme
|
|
97
|
+
permuted_weights = weights.new_empty(permuted_indices_size)
|
|
98
|
+
return permuted_lengths, permuted_indices, permuted_weights
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# pyre-ignore
|
|
102
|
+
def permute_2D_sparse_data_input1D_setup_context(ctx, inputs, output):
    """Save what the backward pass needs on ``ctx``.

    Stores the permutation, the permuted lengths (forward output 0) and the
    stride. The remaining inputs/outputs are not needed.
    """
    permute, _lengths, _values, stride, _weights, _lengths_sum = inputs
    permuted_lengths, _permuted_values, _permuted_weights = output
    ctx.permute = permute
    ctx.permuted_lengths = permuted_lengths
    ctx.stride = stride
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def permute_2D_sparse_data_input1D_backward(
    ctx,  # pyre-ignore
    grad_lengths: torch.Tensor,
    grad_values: torch.Tensor,
    grad_weights: torch.Tensor,
) -> tuple[None, Tensor, Tensor, None, Tensor, None]:
    """Backward of ``permute_2D_sparse_data_input1D``.

    Applies the inverse of the saved permutation to the saved permuted lengths
    and to the incoming value/weight gradients. ``grad_lengths`` is unused.

    Returns one entry per forward input
    ``(permute, lengths, values, stride, weights, permuted_lengths_sum)``.
    The permutation, stride and size-hint slots are ``None``.
    """
    fbgemm_ops = torch.ops.fbgemm
    inverse = fbgemm_ops.invert_permute(ctx.permute)
    lengths_back, values_grad, weights_grad = fbgemm_ops.permute_2D_sparse_data_input1D(
        inverse,
        ctx.permuted_lengths,
        grad_values,
        ctx.stride,
        grad_weights,
        None,
    )
    return (None, lengths_back, values_grad, None, weights_grad, None)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def permute_2D_sparse_data_meta(
|
|
138
|
+
permute: Tensor,
|
|
139
|
+
lengths: Tensor,
|
|
140
|
+
values: Tensor,
|
|
141
|
+
weights: Optional[Tensor] = None,
|
|
142
|
+
permuted_lengths_sum: Optional[int] = None,
|
|
143
|
+
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
|
|
144
|
+
torch._check(
|
|
145
|
+
lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}"
|
|
146
|
+
)
|
|
147
|
+
T = permute.numel()
|
|
148
|
+
B = lengths.size(1)
|
|
149
|
+
indices = values
|
|
150
|
+
permuted_lengths = lengths.new_empty([T, B])
|
|
151
|
+
permuted_indices_size = 0
|
|
152
|
+
if permuted_lengths_sum is not None:
|
|
153
|
+
permuted_indices_size = permuted_lengths_sum
|
|
154
|
+
else:
|
|
155
|
+
ctx = torch.library.get_ctx()
|
|
156
|
+
permuted_indices_size = ctx.new_dynamic_size()
|
|
157
|
+
# pyre-fixme
|
|
158
|
+
permuted_indices = indices.new_empty(permuted_indices_size)
|
|
159
|
+
permuted_weights = None
|
|
160
|
+
if weights is not None:
|
|
161
|
+
# pyre-fixme
|
|
162
|
+
permuted_weights = weights.new_empty(permuted_indices_size)
|
|
163
|
+
return permuted_lengths, permuted_indices, permuted_weights
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def invert_permute_abstract(permute: Tensor) -> Tensor:
    """Fake kernel for ``invert_permute``.

    The inverse permutation has the same shape, dtype and device as ``permute``.
    """
    inverse = torch.empty_like(permute, memory_format=torch.preserve_format)
    return inverse
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def get_source_mask_meta(
    num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
) -> Tensor:
    """Fake/meta kernel for ``get_source_mask``.

    Args:
        num_sources: per-batch source counts.
        num_targets: per-batch target counts.
        output_size: mask length, if known. When ``None``, a fresh dynamic
            size is taken from the fake-kernel context.

    Returns:
        An uninitialized 1-D ``torch.bool`` tensor of length ``output_size``,
        on the same device as ``num_sources``.
    """
    if output_size is None:
        # The mask length depends on tensor values, so use an unbacked size.
        output_size = torch.library.get_ctx().new_dynamic_size()
    # Allocate through ``new_empty`` so the fake output follows the inputs'
    # device. A bare ``torch.empty`` always uses the default device, which
    # presumably mismatches the real kernel's output for device inputs.
    return num_sources.new_empty([output_size], dtype=torch.bool)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def get_source_mask(
    num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
) -> Tensor:
    """
    Build a boolean mask marking which elements are sources and which are targets.

    Python wrapper around ``torch.ops.fbgemm.get_source_mask``. When
    ``output_size`` is omitted, it is computed eagerly as the total of
    ``num_sources + num_targets``. That requires reading tensor values, so
    callers working with meta/fake tensors should pass ``output_size``
    explicitly.

    Args:
        num_sources: 1D tensor of source counts per batch element
        num_targets: 1D tensor of target counts per batch element
        output_size: Optional pre-computed output size.

    Returns:
        A 1D boolean tensor where True indicates source elements and False
        indicates target elements

    Example:
        >>> num_sources = torch.tensor([2, 3])
        >>> num_targets = torch.tensor([1, 2])
        >>> get_source_mask(num_sources, num_targets)
        tensor([True, True, False, True, True, True, False, False])
    """
    size = output_size
    if size is None:
        # Every batch element contributes its sources followed by its targets.
        size = int((num_sources + num_targets).sum().item())
    return torch.ops.fbgemm.get_source_mask(num_sources, num_targets, size)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
# pyre-ignore
|
|
212
|
+
def permute_2D_sparse_data_setup_context(ctx, inputs, output):
    """Save the permutation and the permuted lengths (forward output 0) on ``ctx``."""
    permute, _lengths, _values, _weights, _lengths_sum = inputs
    permuted_lengths, _permuted_values, _permuted_weights = output
    ctx.permute = permute
    ctx.permuted_lengths = permuted_lengths
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# pyre-ignore
|
|
220
|
+
def permute_2D_sparse_data_backward(ctx, grad_lengths, grad_values, grad_weights):
    """Backward of ``permute_2D_sparse_data``.

    Applies the inverse of the saved permutation to the saved permuted lengths
    and to the incoming value/weight gradients. ``grad_lengths`` is unused.

    Returns one entry per forward input
    ``(permute, lengths, values, weights, permuted_lengths_sum)``. The
    permutation and size-hint slots are ``None``.
    """
    fbgemm_ops = torch.ops.fbgemm
    inverse = fbgemm_ops.invert_permute(ctx.permute)
    lengths_back, values_grad, weights_grad = fbgemm_ops.permute_2D_sparse_data(
        inverse, ctx.permuted_lengths, grad_values, grad_weights
    )
    return None, lengths_back, values_grad, weights_grad, None
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def permute_1D_sparse_data_meta(
    permute: Tensor,
    lengths: Tensor,
    values: Tensor,
    weights: Optional[Tensor] = None,
    permuted_lengths_sum: Optional[int] = None,
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
    """Fake/meta kernel for ``permute_1D_sparse_data``.

    The permuted lengths have one entry per element of ``permute``. The
    permuted values (and weights) have ``permuted_lengths_sum`` entries when
    given. Otherwise they get a fresh dynamic size from the fake-kernel
    context.

    Returns:
        ``(permuted_lengths, permuted_values, permuted_weights)``, all
        uninitialized. ``permuted_weights`` is ``None`` when ``weights`` is
        ``None``.
    """
    out_lengths = lengths.new_empty([permute.numel()])

    if permuted_lengths_sum is None:
        out_numel = torch.library.get_ctx().new_dynamic_size()
    else:
        out_numel = permuted_lengths_sum

    # pyre-fixme
    out_values = values.new_empty(out_numel)
    # pyre-fixme
    out_weights = weights.new_empty(out_numel) if weights is not None else None
    return out_lengths, out_values, out_weights
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def masked_select_jagged_1d(
    values: Tensor, lengths: Tensor, mask: Tensor
) -> tuple[Tensor, Tensor]:
    """Fake/meta kernel for ``masked_select_jagged_1d``.

    Requires 1-D ``values`` and ``lengths``, and all three tensors on one
    device.

    Returns:
        ``(masked_values, masked_lengths)``. ``masked_values`` has a fresh
        dynamic length from the fake-kernel context. ``masked_lengths`` is
        shaped like ``lengths``.
    """
    torch._check(values.dim() == 1)
    torch._check(lengths.dim() == 1)
    device = values.device
    torch._check(device == lengths.device)
    torch._check(device == mask.device)

    num_kept = torch.library.get_ctx().new_dynamic_size()
    return values.new_empty([num_kept]), torch.empty_like(lengths)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def tbe_input_combine_abstract(
    indices_list: list[Tensor],
    offsets_list: list[Tensor],
    per_sample_weights: list[Tensor],
    include_last_offsets: Tensor,
) -> tuple[Tensor, Tensor, Tensor]:
    """Fake/meta kernel for ``tbe_input_combine``.

    Validates the per-table inputs: one entry per table in each list and in
    ``include_last_offsets``, 1-D contiguous int32/int64 indices and offsets,
    and non-empty weights that match their indices.

    Returns:
        ``(combined_indices, combined_offsets, combined_weights)``:

        * ``combined_indices``: int32, length = total number of indices.
        * ``combined_offsets``: int32, with a fresh dynamic length from the
          fake-kernel context.
        * ``combined_weights``: float32 of the same length as
          ``combined_indices`` if any table has non-empty weights, otherwise
          an empty tensor on the indices' device.
    """
    torch._check(len(indices_list) > 0)
    torch._check(len(indices_list) == len(offsets_list))
    torch._check(len(indices_list) == len(per_sample_weights))
    torch._check(len(indices_list) == include_last_offsets.numel())
    total_indices = 0
    need_weight = False
    for index, offset, weight in zip(indices_list, offsets_list, per_sample_weights):
        torch._check(index.dtype == torch.int or index.dtype == torch.long)
        torch._check(offset.dtype == torch.int or offset.dtype == torch.long)
        torch._check(index.dim() == 1)
        torch._check(offset.dim() == 1)
        torch._check(index.is_contiguous())
        torch._check(offset.is_contiguous())
        total_indices = total_indices + index.numel()
        # An empty weight tensor means "no weights for this table".
        if guard_or_true(weight.numel() > 0):
            torch._check(weight.dim() == 1)
            torch._check(weight.numel() == index.numel())
            torch._check(weight.is_contiguous())
            need_weight = True
    total_offsets = torch.library.get_ctx().new_dynamic_size()
    combined_indices = indices_list[0].new_empty([total_indices], dtype=torch.int)
    combined_offsets = offsets_list[0].new_empty([total_offsets], dtype=torch.int)
    if need_weight:
        combined_weights = per_sample_weights[0].new_empty(
            [total_indices], dtype=torch.float
        )
    else:
        # Keep the empty weights on the inputs' device, as
        # tbe_input_combine_with_length_abstract does; a bare torch.empty(0)
        # would always land on the default device.
        combined_weights = torch.empty(0, device=indices_list[0].device)
    return combined_indices, combined_offsets, combined_weights
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def tbe_input_combine_with_length_abstract(
|
|
313
|
+
indices_list: list[Tensor],
|
|
314
|
+
offsets_list: list[Tensor],
|
|
315
|
+
per_sample_weights: list[Tensor],
|
|
316
|
+
) -> tuple[Tensor, Tensor, Tensor]:
|
|
317
|
+
torch._check(len(indices_list) > 0)
|
|
318
|
+
torch._check(len(indices_list) == len(offsets_list))
|
|
319
|
+
torch._check(len(indices_list) == len(per_sample_weights))
|
|
320
|
+
total_indices = 0
|
|
321
|
+
total_offsets = 0
|
|
322
|
+
need_weight = False
|
|
323
|
+
for index, offset, weight in zip(indices_list, offsets_list, per_sample_weights):
|
|
324
|
+
torch._check(index.dtype == torch.int or index.dtype == torch.long)
|
|
325
|
+
torch._check(offset.dtype == torch.int or offset.dtype == torch.long)
|
|
326
|
+
torch._check(index.dim() == 1)
|
|
327
|
+
torch._check(offset.dim() == 1)
|
|
328
|
+
torch._check(index.is_contiguous())
|
|
329
|
+
torch._check(offset.is_contiguous())
|
|
330
|
+
total_indices = total_indices + index.numel()
|
|
331
|
+
total_offsets = total_offsets + offset.numel()
|
|
332
|
+
if guard_or_true(weight.numel() > 0):
|
|
333
|
+
torch._check(weight.dim() == 1)
|
|
334
|
+
torch._check(weight.numel() == index.numel())
|
|
335
|
+
torch._check(weight.is_contiguous())
|
|
336
|
+
need_weight = True
|
|
337
|
+
combined_indices = indices_list[0].new_empty([total_indices], dtype=torch.int)
|
|
338
|
+
combined_offsets = offsets_list[0].new_empty([total_offsets], dtype=torch.int)
|
|
339
|
+
if need_weight:
|
|
340
|
+
combined_weights = per_sample_weights[0].new_empty(
|
|
341
|
+
[total_indices], dtype=torch.float
|
|
342
|
+
)
|
|
343
|
+
else:
|
|
344
|
+
combined_weights = torch.empty(0, device=indices_list[0].device)
|
|
345
|
+
return combined_indices, combined_offsets, combined_weights
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def jagged_index_select_2d_forward_v2_abstract(
|
|
349
|
+
values: Tensor,
|
|
350
|
+
indices: Tensor,
|
|
351
|
+
input_offsets: Tensor,
|
|
352
|
+
output_offsets: Tensor,
|
|
353
|
+
num_dense_output_rows: Optional[int] = None,
|
|
354
|
+
) -> Tensor:
|
|
355
|
+
torch._check(values.device == indices.device)
|
|
356
|
+
torch._check(values.device == input_offsets.device)
|
|
357
|
+
torch._check(values.device == output_offsets.device)
|
|
358
|
+
torch._check(values.dim() == 2)
|
|
359
|
+
dynamic_num_dense_output_rows = torch.library.get_ctx().new_dynamic_size()
|
|
360
|
+
num_cols = values.size(1)
|
|
361
|
+
return values.new_empty([dynamic_num_dense_output_rows, num_cols])
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def jagged_index_add_2d_forward_v2_abstract(
|
|
365
|
+
values: Tensor,
|
|
366
|
+
indices: Tensor,
|
|
367
|
+
input_offsets: Tensor,
|
|
368
|
+
output_offsets: Tensor,
|
|
369
|
+
num_output_rows: int,
|
|
370
|
+
num_dense_input_rows: Optional[int] = None,
|
|
371
|
+
) -> Tensor:
|
|
372
|
+
torch._check(values.device == indices.device)
|
|
373
|
+
torch._check(values.device == input_offsets.device)
|
|
374
|
+
torch._check(values.device == output_offsets.device)
|
|
375
|
+
torch._check(values.dim() == 2)
|
|
376
|
+
num_cols = values.size(1)
|
|
377
|
+
return values.new_empty([num_output_rows, num_cols])
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def expand_into_jagged_permute_meta(
|
|
381
|
+
permute: Tensor,
|
|
382
|
+
input_offsets: Tensor,
|
|
383
|
+
output_offsets: Tensor,
|
|
384
|
+
output_size: tuple[int, ...],
|
|
385
|
+
) -> Tensor:
|
|
386
|
+
torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0")
|
|
387
|
+
torch._check(
|
|
388
|
+
permute.numel() == input_offsets.numel() - 1,
|
|
389
|
+
lambda: f"expected {permute.numel()} == {input_offsets.numel()} - 1",
|
|
390
|
+
)
|
|
391
|
+
torch._check(
|
|
392
|
+
permute.numel() == output_offsets.numel() - 1,
|
|
393
|
+
lambda: f"expected {permute.numel()} == {output_offsets.numel()} - 1",
|
|
394
|
+
)
|
|
395
|
+
output_permute = input_offsets.new_empty(output_size)
|
|
396
|
+
return output_permute
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def check_all_same_device(*tensors: Optional[Tensor]) -> None:
    """``torch._check`` that every non-``None`` tensor argument shares one device.

    Arguments are flattened with ``torch.utils._pytree`` first. When every
    tensor is on a "cpu" or "meta" device, no check is performed.
    """
    # pyre-ignore[9]
    flat, _ = pytree.tree_flatten(tensors)
    present = [t for t in flat if t is not None]
    if not present or all(t.device.type in ["cpu", "meta"] for t in present):
        return
    reference = present[0]
    for t in present:
        torch._check(t.device == reference.device)
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def pruned_array_lookup_meta(
    indices: Tensor,
    offsets: Tensor,
    index_remappings: Tensor,
    index_remappings_offsets: Tensor,
) -> Tensor:
    """Fake/meta kernel for ``pruned_array_lookup``.

    Verifies all inputs share a device, then returns an uninitialized tensor
    with the shape, dtype and device of ``indices``.
    """
    check_all_same_device(indices, offsets, index_remappings, index_remappings_offsets)
    remapped = indices.new_empty(indices.size())
    return remapped
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def int_nbit_split_embedding_codegen_lookup_function_meta(
    dev_weights: torch.Tensor,
    uvm_weights: torch.Tensor,
    weights_placements: torch.Tensor,
    weights_offsets: torch.Tensor,
    weights_tys: torch.Tensor,
    D_offsets: torch.Tensor,
    total_D: int,
    max_int2_D: int,
    max_int4_D: int,
    max_int8_D: int,
    max_float16_D: int,
    max_float32_D: int,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    pooling_mode: int,
    indice_weights: Optional[torch.Tensor] = None,
    output_dtype_int: int = 1,
    lxu_cache_weights: Optional[torch.Tensor] = None,
    lxu_cache_locations: Optional[torch.Tensor] = None,
    row_alignment: Optional[int] = None,
    max_float8_D: Optional[int] = None,
    fp8_exponent_bits: Optional[int] = None,
    fp8_exponent_bias: Optional[int] = None,
) -> Tensor:
    """Fake/meta kernel for the int-nbit split-embedding lookup.

    Checks device consistency of the tensor arguments, then returns an
    uninitialized output of dtype ``SparseType.from_int(output_dtype_int)``
    on ``dev_weights``' device:

    * ``pooling_mode == PoolingMode.NONE``: shape ``[indices.numel(), D]``,
      where ``D`` is the largest of the ``max_*_D`` arguments
      (``max_float8_D`` counts as 0 when ``None``).
    * otherwise: shape ``[B, total_D]`` with ``T = D_offsets.numel() - 1`` and
      ``B = (offsets.size(0) - 1) // T``.

    For INT8 output, extra bytes are added per row: 4 for NONE pooling, and
    8 per table otherwise. These are presumably the per-row quantization
    parameters -- TODO confirm.
    """
    check_all_same_device(
        dev_weights,
        uvm_weights,
        weights_placements,
        weights_offsets,
        weights_tys,
        D_offsets,
        indices,
        offsets,
        indice_weights,
    )
    # Parse the output type once; it drives both the dtype and the INT8 padding.
    output_sparse_type = SparseType.from_int(output_dtype_int)
    output_dtype = output_sparse_type.as_dtype()
    is_int8_output = output_sparse_type == SparseType.INT8

    if pooling_mode == PoolingMode.NONE:
        kINT8QparamsBytes = 4
        D = max(
            [
                max_int2_D,
                max_int4_D,
                max_int8_D,
                max_float16_D,
                max_float32_D,
                max_float8_D if max_float8_D is not None else 0,
            ]
        )
        total_L = indices.numel()
        torch._check(D > 0)
        adjusted_D = D + kINT8QparamsBytes if is_int8_output else D
        return dev_weights.new_empty([total_L, adjusted_D], dtype=output_dtype)

    kINT8QparamsBytes = 8
    T = D_offsets.numel() - 1
    torch._check(T > 0)
    torch._check(total_D > 0)
    B = (offsets.size(0) - 1) // T
    total_adjusted_D = total_D
    if is_int8_output:
        total_adjusted_D += T * kINT8QparamsBytes
    return dev_weights.new_empty([B, total_adjusted_D], dtype=output_dtype)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def block_bucketize_sparse_features_meta(
    lengths: torch.Tensor,
    indices: torch.Tensor,
    bucketize_pos: bool,
    sequence: bool,
    block_sizes: torch.Tensor,
    my_size: int,
    weights: Optional[torch.Tensor] = None,
    batch_size_per_feature: Optional[torch.Tensor] = None,
    max_B: int = -1,
    block_bucketize_pos: Optional[torch.Tensor] = None,
    keep_orig_idx: bool = False,
    total_num_blocks: Optional[torch.Tensor] = None,
    keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    """Fake/meta kernel for ``block_bucketize_sparse_features``.

    Returns uninitialized ``(lengths, indices, weights, pos, unbucketize_permute)``:

    * lengths: ``my_size * lengths.size(0)`` entries, like ``lengths``.
    * indices, unbucketize_permute: ``indices.size(0)`` entries, like ``indices``.
    * weights: shaped like ``weights``, or ``None`` when no weights are given.
    * pos: like ``indices`` when ``bucketize_pos`` is set, otherwise ``None``.

    The remaining arguments do not affect the output shapes here.
    """
    num_values = indices.size(0)
    new_lengths = lengths.new_empty([my_size * lengths.size(0)])
    new_indices = indices.new_empty([num_values])
    new_weights = None if weights is None else weights.new_empty(weights.shape)
    new_pos = indices.new_empty([num_values]) if bucketize_pos else None
    unbucketize_permute = indices.new_empty([num_values])
    return new_lengths, new_indices, new_weights, new_pos, unbucketize_permute
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def block_bucketize_sparse_features_2d_weights_meta(
    lengths: torch.Tensor,
    indices: torch.Tensor,
    bucketize_pos: bool,
    sequence: bool,
    block_sizes: torch.Tensor,
    my_size: int,
    weights: torch.Tensor,
    weights_dim: int = 1,
    batch_size_per_feature: Optional[torch.Tensor] = None,
    max_B: int = -1,
    block_bucketize_pos: Optional[torch.Tensor] = None,
    keep_orig_idx: bool = False,
    total_num_blocks: Optional[torch.Tensor] = None,
    keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    """Fake/meta kernel for ``block_bucketize_sparse_features_2d_weights``.

    Returns uninitialized ``(lengths, indices, weights, pos, unbucketize_permute)``:

    * lengths: ``my_size * lengths.size(0)`` entries, like ``lengths``.
    * indices, unbucketize_permute: ``indices.size(0)`` entries, like ``indices``.
    * weights: ``[indices.size(0), weights_dim]``, like ``weights``.
    * pos: like ``indices`` when ``bucketize_pos`` is set, otherwise ``None``.
    """
    num_values = indices.size(0)
    new_lengths = lengths.new_empty([my_size * lengths.size(0)])
    new_indices = indices.new_empty([num_values])
    new_weights = weights.new_empty([num_values, weights_dim])
    new_pos = indices.new_empty([num_values]) if bucketize_pos else None
    unbucketize_permute = indices.new_empty([num_values])
    return new_lengths, new_indices, new_weights, new_pos, unbucketize_permute
|
|
562
|
+
)
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
def merge_pooled_embeddings(
|
|
566
|
+
pooled_embeddings: list[torch.Tensor],
|
|
567
|
+
uncat_dim_size: int,
|
|
568
|
+
target_device: torch.device,
|
|
569
|
+
cat_dim: int = 1,
|
|
570
|
+
) -> torch.Tensor:
|
|
571
|
+
if len(pooled_embeddings) == 0:
|
|
572
|
+
return torch.empty([], device=target_device)
|
|
573
|
+
torch._check_is_size(cat_dim)
|
|
574
|
+
torch._check(cat_dim >= 0)
|
|
575
|
+
torch._check(cat_dim <= 1)
|
|
576
|
+
total_cat_dim_size = 0
|
|
577
|
+
for e in pooled_embeddings:
|
|
578
|
+
torch._check(e.dim() == 2)
|
|
579
|
+
torch._check(e.size(1 - cat_dim) == uncat_dim_size)
|
|
580
|
+
total_cat_dim_size += e.size(cat_dim)
|
|
581
|
+
torch._check_is_size(total_cat_dim_size)
|
|
582
|
+
e = pooled_embeddings[0]
|
|
583
|
+
if cat_dim == 0:
|
|
584
|
+
return e.new_empty(
|
|
585
|
+
[total_cat_dim_size, e.size(1)],
|
|
586
|
+
device=target_device,
|
|
587
|
+
)
|
|
588
|
+
|
|
589
|
+
return e.new_empty(
|
|
590
|
+
[e.size(0), total_cat_dim_size],
|
|
591
|
+
device=target_device,
|
|
592
|
+
)
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def permute_sparse_features_abstract(
    permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
    """
    Fake implementation of ``fbgemm::permute_sparse_features``.

    The permuted lengths have one row per element of ``permute`` and keep
    ``lengths.size(1)`` columns. The number of permuted indices (and weights)
    cannot be derived from shapes alone, so a fresh dynamic size is requested
    from the fake-impl context.
    """
    torch._check(lengths.dtype == indices.dtype)
    device = permute.device
    torch._check(device == lengths.device)
    torch._check(device == indices.device)
    if weights is not None:
        torch._check(device == weights.device)

    batch_size = lengths.size(1)
    out_lengths = lengths.new_empty(permute.numel(), batch_size)

    nnz = torch.library.get_ctx().new_dynamic_size()
    # pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument,
    # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]`
    out_indices = indices.new_empty(nnz)
    out_weights = None
    if weights is not None:
        # pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument,
        # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]`
        out_weights = weights.new_empty(nnz)
    return (out_lengths, out_indices, out_weights)
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
def segment_sum_csr_abstract(
    batch_size: int, csr_seg: Tensor, values: Tensor
) -> Tensor:
    """
    Fake implementation of ``fbgemm::segment_sum_csr``.

    Produces one output element per adjacent pair of entries in ``csr_seg``
    (``csr_seg.numel() - 1`` elements), with the dtype/device of ``values``.
    """
    num_segments = csr_seg.numel() - 1
    return values.new_empty(num_segments)
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
def dense_to_jagged_forward(
    dense: torch.Tensor,
    offsets: list[torch.Tensor],
    total_L: Optional[torch.SymInt] = None,
) -> torch.Tensor:
    """
    Fake implementation of ``fbgemm::dense_to_jagged_forward``.

    Returns a zero-filled ``[total_L, dense.size()[-1]]`` tensor with the same
    dtype, device and layout as ``dense``. When ``total_L`` is not given, it
    is obtained as a new dynamic size from the fake-impl context. ``offsets``
    is not inspected here.
    """
    num_rows = (
        torch.library.get_ctx().new_dynamic_size() if total_L is None else total_L
    )
    inner_dim = dense.size()[-1]
    return dense.new_zeros(
        [num_rows, inner_dim],
        dtype=dense.dtype,
        device=dense.device,
        layout=dense.layout,
    )
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def dense_to_jagged(
    dense: torch.Tensor,
    offsets: list[torch.Tensor],
    total_L: Optional[torch.SymInt] = None,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """
    Fake implementation of ``fbgemm::dense_to_jagged``.

    Delegates the values shape to ``dense_to_jagged_forward``. The input
    ``offsets`` list is returned unchanged as the second element. A dynamic
    size is allocated for ``total_L`` when the caller does not supply one.
    """
    num_rows = total_L
    if num_rows is None:
        num_rows = torch.library.get_ctx().new_dynamic_size()
    jagged_values = dense_to_jagged_forward(dense, offsets, num_rows)
    return (jagged_values, offsets)
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def batch_index_select_dim0_abstract(
|
|
652
|
+
inputs: torch.Tensor,
|
|
653
|
+
indices: torch.Tensor,
|
|
654
|
+
input_num_indices: list[int],
|
|
655
|
+
input_rows: list[int],
|
|
656
|
+
input_columns: list[int],
|
|
657
|
+
permute_output_dim_0_1: bool,
|
|
658
|
+
) -> torch.Tensor:
|
|
659
|
+
"""
|
|
660
|
+
This meta function is used to calculate the shape of output tensor
|
|
661
|
+
from the original function `fbgemm::batch_index_select_dim0` without the actual data.
|
|
662
|
+
"""
|
|
663
|
+
# input lists must have the same length
|
|
664
|
+
torch._check(len(input_num_indices) == len(input_rows))
|
|
665
|
+
torch._check(len(input_num_indices) == len(input_columns))
|
|
666
|
+
|
|
667
|
+
if permute_output_dim_0_1 and len(input_num_indices) > 0:
|
|
668
|
+
# All num_indices must be the same if permute_output_dim_0_1 is True
|
|
669
|
+
for x in input_num_indices:
|
|
670
|
+
torch._check(x == input_num_indices[0])
|
|
671
|
+
|
|
672
|
+
size = sum([row * col for row, col in zip(input_rows, input_columns)])
|
|
673
|
+
torch._check(inputs.size(0) == size)
|
|
674
|
+
|
|
675
|
+
output_numel = 0
|
|
676
|
+
for i, cols in enumerate(input_columns):
|
|
677
|
+
output_numel += input_num_indices[i] * cols
|
|
678
|
+
return inputs.new_empty([output_numel])
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+
def batch_index_select_dim0_tensor_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: torch.Tensor,
    input_rows: torch.Tensor,
    input_columns: torch.Tensor,
    permute_output_dim_0_1: bool,
) -> torch.Tensor:
    """
    Fake implementation of ``fbgemm::batch_index_select_dim0_tensor``.

    The metadata arrive as tensors, so the output element count depends on
    their contents. It is therefore a new dynamic size. Only the leading
    dimensions of the metadata tensors are cross-checked.
    """
    num_inputs = input_num_indices.size(0)
    torch._check(num_inputs == input_rows.size(0))
    torch._check(num_inputs == input_columns.size(0))
    dynamic_numel = torch.library.get_ctx().new_dynamic_size()
    return inputs.new_empty([dynamic_numel])
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def batch_index_select_dim0_forward_cuda_impl_abstract(
|
|
696
|
+
inputs: torch.Tensor,
|
|
697
|
+
indices: torch.Tensor,
|
|
698
|
+
input_num_indices: list[int],
|
|
699
|
+
input_rows: list[int],
|
|
700
|
+
input_columns: list[int],
|
|
701
|
+
permute_output_dim_0_1: bool,
|
|
702
|
+
) -> list[torch.Tensor]:
|
|
703
|
+
num_inputs = len(input_rows)
|
|
704
|
+
torch._check(len(input_num_indices) == len(input_rows))
|
|
705
|
+
torch._check(len(input_num_indices) == len(input_columns))
|
|
706
|
+
|
|
707
|
+
output_numel = 0
|
|
708
|
+
for i, cols in enumerate(input_columns):
|
|
709
|
+
output_numel += input_num_indices[i] * cols
|
|
710
|
+
|
|
711
|
+
output_offsets = (
|
|
712
|
+
inputs.new_empty([0], dtype=torch.int64)
|
|
713
|
+
if permute_output_dim_0_1
|
|
714
|
+
else inputs.new_empty([num_inputs + 1], dtype=torch.int64)
|
|
715
|
+
)
|
|
716
|
+
|
|
717
|
+
if permute_output_dim_0_1:
|
|
718
|
+
for i in range(num_inputs):
|
|
719
|
+
torch._check(input_num_indices[0] == input_num_indices[i])
|
|
720
|
+
|
|
721
|
+
return [
|
|
722
|
+
inputs.new_empty([output_numel]),
|
|
723
|
+
inputs.new_empty([num_inputs], dtype=torch.int64),
|
|
724
|
+
inputs.new_empty([num_inputs + 1], dtype=torch.int64),
|
|
725
|
+
inputs.new_empty([num_inputs + 1], dtype=torch.int32), # D_offsets
|
|
726
|
+
output_offsets,
|
|
727
|
+
inputs.new_empty([num_inputs + 1], dtype=torch.int64),
|
|
728
|
+
inputs.new_empty([4], dtype=torch.int64, device="cpu"),
|
|
729
|
+
]
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
def batch_index_select_dim0_tensor_forward_cuda_impl_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: torch.Tensor,
    input_rows: torch.Tensor,
    input_columns: torch.Tensor,
    permute_output_dim_0_1: bool,
) -> list[torch.Tensor]:
    """
    Fake implementation of the tensor-metadata CUDA forward of
    ``batch_index_select_dim0``.

    Returns seven tensors, mirroring the list-metadata variant, except that
    the flat output size is a new dynamic size (it depends on tensor contents).
    """
    num_inputs: int = input_rows.size(0)
    torch._check(input_num_indices.size(0) == num_inputs)
    torch._check(input_num_indices.size(0) == input_columns.size(0))
    output_numel = torch.library.get_ctx().new_dynamic_size()

    i64 = torch.int64
    offsets_len = 0 if permute_output_dim_0_1 else num_inputs + 1
    output_offsets = inputs.new_empty([offsets_len], dtype=i64)

    return [
        inputs.new_empty([output_numel]),
        inputs.new_empty([num_inputs], dtype=i64),
        inputs.new_empty([num_inputs + 1], dtype=i64),
        inputs.new_empty([num_inputs + 1], dtype=torch.int32),  # D_offsets
        output_offsets,
        inputs.new_empty([num_inputs + 1], dtype=i64),  # total_L_offsets
        inputs.new_empty([4], dtype=i64, device="cpu"),
    ]
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def batch_index_select_dim0_tensor_backward_cuda_impl_abstract(
    grad_output: torch.Tensor,
    dev_weights: torch.Tensor,
    weights_offsets: torch.Tensor,
    D_offsets: torch.Tensor,
    hash_size_cumsum: torch.Tensor,
    indices: torch.Tensor,
    max_segment_length_per_warp: int,
    grad_offsets: torch.Tensor,
    total_L_offsets: torch.Tensor,
    permute_output_dim_0_1: bool,
    saved_tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Fake CUDA backward for the tensor-metadata ``batch_index_select_dim0``.

    The gradient is shaped like ``dev_weights`` and takes ``grad_output``'s
    dtype and device. All other arguments are ignored.
    """
    grad_shape = dev_weights.shape
    return grad_output.new_empty(grad_shape)
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
def keyed_jagged_index_select_dim1_abstract(
|
|
777
|
+
values: torch.Tensor,
|
|
778
|
+
lengths: torch.Tensor,
|
|
779
|
+
offsets: torch.Tensor,
|
|
780
|
+
indices: torch.Tensor,
|
|
781
|
+
batch_size: torch.SymInt,
|
|
782
|
+
weights: Optional[torch.Tensor] = None,
|
|
783
|
+
selected_lengths_sum: Optional[torch.SymInt] = None,
|
|
784
|
+
) -> list[torch.Tensor]:
|
|
785
|
+
"""
|
|
786
|
+
This meta function is used to calculate the shape of output tensors
|
|
787
|
+
from the original function `fbgemm::keyed_jagged_index_select_dim1` without the actual data.
|
|
788
|
+
"""
|
|
789
|
+
# pyre-ignore
|
|
790
|
+
num_batches = len(lengths) // batch_size
|
|
791
|
+
# offsets = [0] + lengths.cumsum(0)
|
|
792
|
+
torch._check(len(lengths) + 1 == len(offsets))
|
|
793
|
+
# len(lengths) == batch_size * num_batches
|
|
794
|
+
# pyre-ignore
|
|
795
|
+
torch._check(len(lengths) % batch_size == 0)
|
|
796
|
+
if weights is not None:
|
|
797
|
+
# weights must have the same shape as values
|
|
798
|
+
torch._check(values.shape == weights.shape)
|
|
799
|
+
|
|
800
|
+
if selected_lengths_sum is None:
|
|
801
|
+
length_indices = torch.cat(
|
|
802
|
+
# pyre-ignore
|
|
803
|
+
[indices + i * batch_size for i in range(num_batches)]
|
|
804
|
+
)
|
|
805
|
+
selected_lengths_sum = (
|
|
806
|
+
torch.index_select(lengths, 0, length_indices).sum().item()
|
|
807
|
+
)
|
|
808
|
+
|
|
809
|
+
ret: list[torch.Tensor] = [
|
|
810
|
+
# pyre-ignore
|
|
811
|
+
values.new_empty([selected_lengths_sum]),
|
|
812
|
+
lengths.new_empty([indices.shape[0] * num_batches]),
|
|
813
|
+
]
|
|
814
|
+
|
|
815
|
+
if weights is not None:
|
|
816
|
+
# pyre-ignore
|
|
817
|
+
ret.append(weights.new_empty([selected_lengths_sum]))
|
|
818
|
+
|
|
819
|
+
return ret
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
def batch_index_select_dim0_backward_cuda_impl_abstract(
    grad_output: torch.Tensor,
    dev_weights: torch.Tensor,
    weights_offsets: torch.Tensor,
    D_offsets: torch.Tensor,
    hash_size_cumsum: torch.Tensor,
    indices: torch.Tensor,
    max_segment_length_per_warp: int,
    grad_offsets: torch.Tensor,
    total_L_offsets: torch.Tensor,
    permute_output_dim_0_1: bool,
    saved_tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Fake CUDA backward for ``batch_index_select_dim0``.

    Allocates a gradient with the shape of ``dev_weights`` and the dtype and
    device of ``grad_output``. The remaining arguments do not affect the shape.
    """
    weights_shape = dev_weights.shape
    return grad_output.new_empty(weights_shape)
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
def batch_index_select_dim0_forward_cpu_impl_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: list[int],
    input_rows: list[int],
    input_columns: list[int],
    permute_output_dim_0_1: bool,
) -> list[torch.Tensor]:
    """
    Fake implementation of the CPU forward of ``batch_index_select_dim0``.

    Returns the flat output (``sum(num_indices * cols)`` elements), four int64
    tensors with one element per input (the fourth is indices_numels), and
    a 1-element int64 saved_tensor.
    """
    num_inputs = len(input_num_indices)
    # The three per-input metadata lists must have the same length.
    torch._check(num_inputs == len(input_rows))
    torch._check(num_inputs == len(input_columns))

    # guard_or_true is only evaluated when permute mode is requested.
    if permute_output_dim_0_1 and guard_or_true(len(input_num_indices) > 0):
        # All num_indices must be the same if permute_output_dim_0_1 is True
        reference = input_num_indices[0]
        for count in input_num_indices:
            torch._check(count == reference)

    output_numel: int = sum(
        n * c for n, c in zip(input_num_indices, input_columns)
    )

    i64 = torch.int64
    return [
        inputs.new_empty([output_numel]),
        inputs.new_empty([len(input_num_indices)], dtype=i64),
        inputs.new_empty([len(input_rows)], dtype=i64),
        inputs.new_empty([len(input_columns)], dtype=i64),
        inputs.new_empty([num_inputs], dtype=i64),  # indices_numels
        inputs.new_empty([1], dtype=i64),  # saved_tensor
    ]
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
def batch_index_select_dim0_tensor_forward_cpu_impl_abstract(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    input_num_indices: torch.Tensor,
    input_rows: torch.Tensor,
    input_columns: torch.Tensor,
    permute_output_dim_0_1: bool,
) -> list[torch.Tensor]:
    """
    Fake implementation of the tensor-metadata CPU forward of
    ``batch_index_select_dim0``.

    Returns a flat output whose size is a new dynamic size (it depends on the
    metadata contents) and a 1-element int64 tensor.
    """
    num_inputs = len(input_num_indices)
    # The metadata tensors must agree in length.
    torch._check(num_inputs == len(input_rows))
    torch._check(num_inputs == len(input_columns))

    dynamic_numel = torch.library.get_ctx().new_dynamic_size()
    flat_output = inputs.new_empty([dynamic_numel])
    return [flat_output, inputs.new_empty([1], dtype=torch.int64)]
|
|
887
|
+
|
|
888
|
+
|
|
889
|
+
def batch_index_select_dim0_backward_cpu_impl_abstract(
    grad_output: torch.Tensor,
    indices: torch.Tensor,
    indices_numels: torch.Tensor,
    input_num_indices: torch.Tensor,
    input_rows: torch.Tensor,
    input_columns: torch.Tensor,
    permute_output_dim_0_1: bool,
    saved_tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Fake CPU backward for ``batch_index_select_dim0``.

    The gradient is flat and its length is a new dynamic size, since it
    depends on the contents of the metadata tensors.
    """
    grad_numel = torch.library.get_ctx().new_dynamic_size()
    return grad_output.new_empty([grad_numel])
|
|
900
|
+
|
|
901
|
+
|
|
902
|
+
def bounds_check_indices_abstract(
    rows_per_table: torch.Tensor,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    bounds_check_mode_int: int,
    bounds_check_warning: torch.Tensor,
    per_sample_weights: Optional[torch.Tensor] = None,
    B_offsets: Optional[torch.Tensor] = None,
    max_B: Optional[SymInt] = None,
    b_t_map: Optional[torch.Tensor] = None,
    info_B_num_bits: int = -1,
    info_B_mask: int = -1,
    bounds_check_version: int = 1,
    prefetch_pipeline: bool = False,
) -> None:
    """
    Fake stand-in for ``fbgemm::bounds_check_indices``.

    Performs no checks and produces no outputs; it only lets the op be traced.
    """
    return None
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
def group_index_select_dim0_gpu_impl_abstract(
    inputs: list[torch.Tensor], group_size: int
) -> list[torch.Tensor]:
    """
    Calculate output shapes for group_index_select_dim0_gpu_impl without data.

    ``inputs`` holds ``group_size`` index tensors followed by ``group_size``
    source tensors. For each pair, the output takes its leading dimension from
    the index tensor and keeps the source's trailing dimensions.

    Two extra tensors are appended after the per-pair outputs:
      * a pinned uint8 buffer of
        ``8 * (4 * group_size + 1 + ceil(group_size / 2))`` bytes;
      * a zeroed 5-element int64 CPU tensor.
    """
    indices_group = inputs[:group_size]
    input_group = inputs[group_size:]
    torch._check(len(input_group) == group_size)

    outputs = [
        src.new_empty([idx.size(0), *src.size()[1:]])
        for idx, src in zip(indices_group, input_group)
    ]

    # sizeof(int64_t) / sizeof(int32_t) == 2, so ceil(group_size / 2) int64
    # slots are reserved (presumably for group_size packed 32-bit values).
    num_int64_slots = 4 * group_size + 1 + (group_size + 1) // 2
    # 8 bytes per int64 slot; torch.uint8 corresponds to at::kByte.
    outputs.append(
        input_group[0].new_empty(
            num_int64_slots * 8, dtype=torch.uint8, pin_memory=True
        )
    )
    outputs.append(torch.zeros(5, dtype=torch.int64, device="cpu"))
    return outputs
|
|
953
|
+
|
|
954
|
+
|
|
955
|
+
def group_index_select_dim0_gpu_backward_abstract(
|
|
956
|
+
all_inputs: list[torch.Tensor], output_shape_group_ref: list[torch.SymInt]
|
|
957
|
+
) -> list[torch.Tensor]:
|
|
958
|
+
"""
|
|
959
|
+
Calculate output shapes for group_index_select_dim0_gpu_backward
|
|
960
|
+
without the actual data.
|
|
961
|
+
"""
|
|
962
|
+
torch._check(len(all_inputs) > 3)
|
|
963
|
+
group_size = (len(all_inputs) - 3) // 2
|
|
964
|
+
ret = []
|
|
965
|
+
|
|
966
|
+
# indices
|
|
967
|
+
for _ in range(group_size):
|
|
968
|
+
ret.append(all_inputs[0].new_empty(0))
|
|
969
|
+
|
|
970
|
+
# inputs
|
|
971
|
+
output_dim = len(output_shape_group_ref) // group_size
|
|
972
|
+
for i in range(group_size):
|
|
973
|
+
ret.append(
|
|
974
|
+
all_inputs[0].new_empty(
|
|
975
|
+
output_shape_group_ref[i * output_dim : (i + 1) * output_dim]
|
|
976
|
+
)
|
|
977
|
+
)
|
|
978
|
+
|
|
979
|
+
return ret
|
|
980
|
+
|
|
981
|
+
|
|
982
|
+
def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
|
|
983
|
+
values: torch.Tensor,
|
|
984
|
+
lengths: torch.Tensor,
|
|
985
|
+
offsets: torch.Tensor,
|
|
986
|
+
indices: torch.Tensor,
|
|
987
|
+
batch_size: torch.SymInt,
|
|
988
|
+
weights: Optional[torch.Tensor] = None,
|
|
989
|
+
selected_lengths_sum: Optional[torch.SymInt] = None,
|
|
990
|
+
) -> list[torch.Tensor]:
|
|
991
|
+
num_batches = lengths.size(0) // batch_size
|
|
992
|
+
torch._check(lengths.size(0) + 1 == offsets.size(0))
|
|
993
|
+
# pyre-ignore
|
|
994
|
+
torch._check(lengths.size(0) % batch_size == 0)
|
|
995
|
+
|
|
996
|
+
if weights is not None:
|
|
997
|
+
# weights must have the same shape as values
|
|
998
|
+
torch._check(values.shape == weights.shape)
|
|
999
|
+
|
|
1000
|
+
if selected_lengths_sum is None:
|
|
1001
|
+
selected_lengths_sum = torch.library.get_ctx().new_dynamic_size()
|
|
1002
|
+
|
|
1003
|
+
torch._check_is_size(selected_lengths_sum)
|
|
1004
|
+
vlw: list[torch.Tensor] = [
|
|
1005
|
+
values.new_empty([selected_lengths_sum]), # output
|
|
1006
|
+
lengths.new_empty([indices.shape[0] * num_batches]), # output_lengths
|
|
1007
|
+
]
|
|
1008
|
+
if weights is not None:
|
|
1009
|
+
vlw.append(weights.new_empty([selected_lengths_sum])) # output_weights
|
|
1010
|
+
|
|
1011
|
+
return [
|
|
1012
|
+
*vlw,
|
|
1013
|
+
offsets.new_empty([indices.shape[0] * num_batches]), # output_offsets
|
|
1014
|
+
torch.empty([4], dtype=torch.int64, device="cpu"), # saved_data_tensor
|
|
1015
|
+
]
|
|
1016
|
+
|
|
1017
|
+
|
|
1018
|
+
def keyed_jagged_index_select_dim1_backward_cuda_impl_abstract(
    grad: torch.Tensor,
    indices: torch.Tensor,
    grad_offsets: torch.Tensor,
    output_offsets: torch.Tensor,
    saved_tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Fake CUDA backward for ``keyed_jagged_index_select_dim1``.

    The gradient is flat with a new dynamic size, and uses ``grad``'s dtype
    and device.
    """
    grad_numel = torch.library.get_ctx().new_dynamic_size()
    return grad.new_empty([grad_numel])
|
|
1026
|
+
|
|
1027
|
+
|
|
1028
|
+
def permute_pooled_embs_split_abstract(
    pooled_embs: Tensor,
    offset_dim_list: Tensor,
    permute_list: Tensor,
    inv_offset_dim_list: Tensor,
    inv_permute_list: Tensor,
) -> Tensor:
    """
    Fake implementation of ``fbgemm::permute_pooled_embs_split``.

    A permutation does not change the shape, so the result is an
    uninitialized tensor like ``pooled_embs``.
    """
    permuted = torch.empty_like(pooled_embs)
    return permuted
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
def histogram_binning_calibration_abstract(
    logit: Tensor,
    bin_num_examples: Tensor,
    bin_num_positives: Tensor,
    positive_weight: float,
    lower_bound: float,
    upper_bound: float,
    bin_ctr_in_use_after: int,
    bin_ctr_weight_value: float,
) -> tuple[Tensor, Tensor]:
    """
    Fake implementation of ``fbgemm::histogram_binning_calibration``.

    Returns:
        A tensor shaped like ``logit`` (the calibrated output) and a flat
        int64 tensor with one entry per element of ``logit`` (presumably the
        per-example bin ids). Both live on ``logit``'s device.
    """
    calibrated = torch.empty_like(logit)
    # Pass the device explicitly. Without it, torch.empty uses the default
    # device and the fake output would not match logit's device.
    # generic_histogram_binning_calibration_by_feature already does this.
    bin_ids = torch.empty([logit.numel()], dtype=torch.int64, device=logit.device)
    return calibrated, bin_ids
|
|
1049
|
+
|
|
1050
|
+
|
|
1051
|
+
def float_to_hfp8_quantized(
    input: Tensor, ebits: int, exponent_bias: int, max_pos: float
) -> Tensor:
    """Fake quantization to HFP8: one uint8 per input element, same shape."""
    quantized_dtype = torch.uint8
    return torch.empty_like(input, dtype=quantized_dtype)
|
|
1055
|
+
|
|
1056
|
+
|
|
1057
|
+
def hfp8_quantized_to_float(input: Tensor, ebits: int, exponent_bias: int) -> Tensor:
    """Fake HFP8 dequantization: one float32 per input element, same shape."""
    dequantized_dtype = torch.float32
    return torch.empty_like(input, dtype=dequantized_dtype)
|
|
1059
|
+
|
|
1060
|
+
|
|
1061
|
+
def float_or_half_to_fused_nbit_rowwise_quantized_sbhalf(
|
|
1062
|
+
input_t: Tensor,
|
|
1063
|
+
bit_rate: int,
|
|
1064
|
+
) -> Tensor:
|
|
1065
|
+
input_sizes = input_t.size()
|
|
1066
|
+
torch._check(len(input_sizes) == 2)
|
|
1067
|
+
nrows = input_sizes[0]
|
|
1068
|
+
ncols = input_sizes[1]
|
|
1069
|
+
num_elem_per_byte = 8 // bit_rate
|
|
1070
|
+
|
|
1071
|
+
torch._check(ncols % (2 * num_elem_per_byte) == 0)
|
|
1072
|
+
output_columns = (ncols + num_elem_per_byte - 1) // num_elem_per_byte + 2 * 2
|
|
1073
|
+
output = torch.empty(
|
|
1074
|
+
(nrows, output_columns), device=input_t.device, dtype=torch.uint8
|
|
1075
|
+
)
|
|
1076
|
+
return output
|
|
1077
|
+
|
|
1078
|
+
|
|
1079
|
+
def fused_nbit_rowwise_quantized_sb_half_to_float_or_half(
    input_t: Tensor,
    bit_rate: int,
    output_dtype: int = 0,
) -> Tensor:
    """
    Fake n-bit row-wise dequantization to FP32 or FP16.

    For ``torch.quint2x4`` / ``torch.quint4x2`` inputs, the column count is
    first divided (rounding up) by 4 / 2 respectively. Then ``2 * 2`` columns
    are removed (presumably the fp16 scale and bias), and the rest is
    multiplied by ``8 // bit_rate``.
    """
    torch._check(output_dtype in [SparseType.FP32.as_int(), SparseType.FP16.as_int()])
    rows = input_t.size(0)
    cols = input_t.size(1)

    pack_factor = {torch.quint2x4: 4, torch.quint4x2: 2}.get(input_t.dtype)
    if pack_factor is not None:
        cols = (cols + pack_factor - 1) // pack_factor

    elems_per_byte = 8 // bit_rate
    out_cols = (cols - 2 * 2) * elems_per_byte
    # output_dtype is either SparseType.FP32 or SparseType.FP16 at this point.
    out_dtype = (
        torch.float32 if output_dtype == SparseType.FP32.as_int() else torch.float16
    )
    return torch.empty((rows, out_cols), dtype=out_dtype, device=input_t.device)
|
|
1101
|
+
|
|
1102
|
+
|
|
1103
|
+
def fused_8_bit_rowwise_quantized_to_float_or_half(
    input_t: Tensor,
    output_dtype: int = 0,
    scale_bias_last: bool = True,
    quant_padding_float_type: bool = True,
) -> Tensor:
    """
    Fake 8-bit row-wise dequantization to FP32, FP16 or BF16.

    The last dimension is rounded up to a multiple of the padding size (4 if
    ``quant_padding_float_type`` else 2). Then two padding units are removed
    (presumably the scale and bias). Leading dimensions are preserved.
    """
    torch._check(
        output_dtype
        in [
            SparseType.FP32.as_int(),
            SparseType.FP16.as_int(),
            SparseType.BF16.as_int(),
        ]
    )
    torch._check(quant_padding_float_type or not scale_bias_last)
    torch._check(input_t.dim() >= 2)

    *leading, cols = list(input_t.shape)
    pad = 4 if quant_padding_float_type else 2
    aligned_cols = (cols + pad - 1) // pad * pad
    out_shape = [*leading, aligned_cols - 2 * pad]

    if output_dtype == SparseType.FP32.as_int():
        out_dtype = torch.float32
    elif output_dtype == SparseType.FP16.as_int():
        out_dtype = torch.float16
    else:  # output_dtype is SparseType.BF16
        out_dtype = torch.bfloat16
    return torch.empty(out_shape, dtype=out_dtype, device=input_t.device)
|
|
1134
|
+
|
|
1135
|
+
|
|
1136
|
+
def float_or_half_to_fused_8_bit_rowwise(
|
|
1137
|
+
input_t: Tensor,
|
|
1138
|
+
) -> Tensor:
|
|
1139
|
+
torch._check(input_t.dim() >= 2)
|
|
1140
|
+
last_dim = input_t.dim() - 1
|
|
1141
|
+
output_shape = list(input_t.shape)
|
|
1142
|
+
ncols = input_t.size(last_dim)
|
|
1143
|
+
ncols_aligned = (ncols + 4 - 1) // 4 * 4
|
|
1144
|
+
output_columns = ncols_aligned + 2 * 4
|
|
1145
|
+
output_shape[last_dim] = output_columns
|
|
1146
|
+
return torch.empty(output_shape, dtype=torch.uint8, device=input_t.device)
|
|
1147
|
+
|
|
1148
|
+
|
|
1149
|
+
def fused_8_bit_rowwise_quantized_to_float(
|
|
1150
|
+
input_t: Tensor,
|
|
1151
|
+
scale_bias_last: bool = True,
|
|
1152
|
+
quant_padding_float_type: bool = True,
|
|
1153
|
+
) -> Tensor:
|
|
1154
|
+
torch._check(quant_padding_float_type or not scale_bias_last)
|
|
1155
|
+
torch._check(input_t.dim() >= 2)
|
|
1156
|
+
last_dim = input_t.dim() - 1
|
|
1157
|
+
output_shape = list(input_t.shape)
|
|
1158
|
+
ncols = input_t.size(last_dim)
|
|
1159
|
+
quant_padding_size = 4 if quant_padding_float_type else 2
|
|
1160
|
+
ncols_aligned = (
|
|
1161
|
+
(ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
|
|
1162
|
+
)
|
|
1163
|
+
output_columns = ncols_aligned - 2 * quant_padding_size
|
|
1164
|
+
output_shape[last_dim] = output_columns
|
|
1165
|
+
return torch.empty(output_shape, dtype=torch.float32, device=input_t.device)
|
|
1166
|
+
|
|
1167
|
+
|
|
1168
|
+
def fused_8_bit_rowwise_quantized_to_half(
|
|
1169
|
+
input_t: Tensor,
|
|
1170
|
+
scale_bias_last: bool = True,
|
|
1171
|
+
quant_padding_float_type: bool = True,
|
|
1172
|
+
) -> Tensor:
|
|
1173
|
+
torch._check(quant_padding_float_type or not scale_bias_last)
|
|
1174
|
+
torch._check(input_t.dim() >= 2)
|
|
1175
|
+
last_dim = input_t.dim() - 1
|
|
1176
|
+
output_shape = list(input_t.shape)
|
|
1177
|
+
ncols = input_t.size(last_dim)
|
|
1178
|
+
quant_padding_size = 4 if quant_padding_float_type else 2
|
|
1179
|
+
ncols_aligned = (
|
|
1180
|
+
(ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
|
|
1181
|
+
)
|
|
1182
|
+
output_columns = ncols_aligned - 2 * quant_padding_size
|
|
1183
|
+
output_shape[last_dim] = output_columns
|
|
1184
|
+
return torch.empty(output_shape, dtype=torch.float16, device=input_t.device)
|
|
1185
|
+
|
|
1186
|
+
|
|
1187
|
+
def generic_histogram_binning_calibration_by_feature(
|
|
1188
|
+
logit: Tensor,
|
|
1189
|
+
segment_value: Tensor,
|
|
1190
|
+
segment_lengths: Tensor,
|
|
1191
|
+
num_segments: int,
|
|
1192
|
+
bin_num_examples: Tensor,
|
|
1193
|
+
bin_num_positives: Tensor,
|
|
1194
|
+
bin_boundaries: Tensor,
|
|
1195
|
+
positive_weight: float,
|
|
1196
|
+
bin_ctr_in_use_after: int,
|
|
1197
|
+
bin_ctr_weight_value: float,
|
|
1198
|
+
) -> tuple[Tensor, Tensor]:
|
|
1199
|
+
torch._check(bin_num_examples.numel() == bin_num_positives.numel())
|
|
1200
|
+
torch._check(
|
|
1201
|
+
bin_num_examples.numel() == (num_segments + 1) * (bin_boundaries.numel() + 1)
|
|
1202
|
+
)
|
|
1203
|
+
return torch.empty_like(logit), torch.empty(
|
|
1204
|
+
[logit.numel()], dtype=torch.int64, device=logit.device
|
|
1205
|
+
)
|
|
1206
|
+
|
|
1207
|
+
|
|
1208
|
+
def permute_multi_embedding_function_impl_abstract(
|
|
1209
|
+
pooled_embs: list[Tensor],
|
|
1210
|
+
permutes: Tensor,
|
|
1211
|
+
in_shapes: Tensor,
|
|
1212
|
+
out_shapes: Tensor,
|
|
1213
|
+
out_lengths: list[int],
|
|
1214
|
+
reverse: bool = False,
|
|
1215
|
+
) -> list[Tensor]:
|
|
1216
|
+
out_dtype = pooled_embs[0].dtype
|
|
1217
|
+
bs = pooled_embs[0].shape[0]
|
|
1218
|
+
torch._check(permutes.shape[1] == 6, lambda: "permutes must have 6 columns")
|
|
1219
|
+
|
|
1220
|
+
output = []
|
|
1221
|
+
for i in range(len(out_lengths)):
|
|
1222
|
+
output.append(torch.empty([bs, out_lengths[i]], dtype=out_dtype))
|
|
1223
|
+
return output
|
|
1224
|
+
|
|
1225
|
+
|
|
1226
|
+
def lengths_range_abstract(
|
|
1227
|
+
lengths: Tensor,
|
|
1228
|
+
output_shape: Optional[Sequence[int]] = None,
|
|
1229
|
+
) -> Tensor:
|
|
1230
|
+
torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
|
|
1231
|
+
output_size = 0
|
|
1232
|
+
if output_shape is not None:
|
|
1233
|
+
output_size = math.prod(output_shape)
|
|
1234
|
+
else:
|
|
1235
|
+
ctx = torch.library.get_ctx()
|
|
1236
|
+
output_size = ctx.new_dynamic_size()
|
|
1237
|
+
return lengths.new_empty([output_size], dtype=lengths.dtype)
|
|
1238
|
+
|
|
1239
|
+
|
|
1240
|
+
def all_to_one_device(
    input_tensors: list[Tensor],
    target_device: torch.device,
) -> list[Tensor]:
    """
    Fake implementation of ``all_to_one_device``.

    Returns one uninitialized tensor per input, matching its shape and dtype,
    allocated on the ``meta`` device (``target_device`` is not consulted).
    """
    meta_device = torch.device("meta")
    return [torch.empty_like(tensor, device=meta_device) for tensor in input_tensors]
|
|
1248
|
+
|
|
1249
|
+
|
|
1250
|
+
def sum_reduce_to_one(
|
|
1251
|
+
input_tensors: list[Tensor],
|
|
1252
|
+
target_device: torch.device,
|
|
1253
|
+
) -> Tensor:
|
|
1254
|
+
torch._check(len(input_tensors) > 0, lambda: "reducing no tensor is undefined")
|
|
1255
|
+
# All tensors should have the same shape
|
|
1256
|
+
first_tensor = input_tensors[0]
|
|
1257
|
+
return torch.empty_like(first_tensor, device=torch.device("meta"))
|
|
1258
|
+
|
|
1259
|
+
|
|
1260
|
+
def _setup() -> None:
    """Register the Python fake (meta) and autograd implementations of fbgemm ops.

    ``_setup.done`` is a function attribute that starts out ``False``. It is
    set to ``True`` only after every registration below has completed, so
    later calls return without registering anything.
    """
    # pyre-ignore[16]
    _setup.done = getattr(_setup, "done", False)

    # pyre-ignore[2]
    def _register_fake_if_missing(op_name, fn) -> None:
        # NOTE: Failures have occasionally been observed with register_fake,
        # where the error signatures can be found in:
        # https://github.com/pytorch/pytorch/blob/main/torch/_library/fake_impl.py
        #
        # To work around this, the registration is skipped whenever the
        # operator already has a kernel for one of these dispatch keys.
        already_registered = any(
            torch._C._dispatch_has_kernel_for_dispatch_key(op_name, dispatch_key)
            for dispatch_key in ("CompositeImplicitAutograd", "Meta")
        )
        if not already_registered:
            torch.library.register_fake(op_name, fn)

    # pyre-ignore[2,24]
    def _register_autograd_if_missing(
        op_name, fn, setup_context: Optional[Callable] = None
    ) -> None:
        parts = op_name.split("::")
        # "<namespace>/<op>/Autograd" mirrors the key format of entries in
        # torch.library's private ``_impls`` set; a hit means an Autograd
        # kernel was already registered through torch.library.
        impl_key = "/".join((parts[0], parts[-1], "Autograd"))
        if impl_key in torch.library._impls:
            return
        torch.library.register_autograd(op_name, fn, setup_context=setup_context)

    if _setup.done:
        return

    _register_autograd_if_missing(
        "fbgemm::permute_2D_sparse_data",
        permute_2D_sparse_data_backward,
        setup_context=permute_2D_sparse_data_setup_context,
    )

    # (operator name, fake implementation) pairs, registered in this order.
    fake_impls = (
        ("fbgemm::permute_2D_sparse_data", permute_2D_sparse_data_meta),
        ("fbgemm::get_source_mask", get_source_mask_meta),
        (
            "fbgemm::permute_2D_sparse_data_input1D",
            permute_2D_sparse_data_input1D_meta,
        ),
        ("fbgemm::invert_permute", invert_permute_abstract),
        ("fbgemm::permute_1D_sparse_data", permute_1D_sparse_data_meta),
        ("fbgemm::masked_select_jagged_1d", masked_select_jagged_1d),
        ("fbgemm::tbe_input_combine", tbe_input_combine_abstract),
        (
            "fbgemm::tbe_input_combine_with_length",
            tbe_input_combine_with_length_abstract,
        ),
        (
            "fbgemm::jagged_index_select_2d_forward_v2",
            jagged_index_select_2d_forward_v2_abstract,
        ),
        (
            "fbgemm::jagged_index_add_2d_forward_v2",
            jagged_index_add_2d_forward_v2_abstract,
        ),
        (
            "fbgemm::expand_into_jagged_permute",
            expand_into_jagged_permute_meta,
        ),
        ("fbgemm::pruned_array_lookup", pruned_array_lookup_meta),
        (
            "fbgemm::int_nbit_split_embedding_codegen_lookup_function",
            int_nbit_split_embedding_codegen_lookup_function_meta,
        ),
        (
            "fbgemm::block_bucketize_sparse_features",
            block_bucketize_sparse_features_meta,
        ),
        (
            "fbgemm::block_bucketize_sparse_features_2d_weights",
            block_bucketize_sparse_features_2d_weights_meta,
        ),
        ("fbgemm::merge_pooled_embeddings", merge_pooled_embeddings),
        (
            "fbgemm::permute_sparse_features",
            permute_sparse_features_abstract,
        ),
        ("fbgemm::segment_sum_csr", segment_sum_csr_abstract),
        ("fbgemm::dense_to_jagged_forward", dense_to_jagged_forward),
        ("fbgemm::all_to_one_device", all_to_one_device),
        ("fbgemm::sum_reduce_to_one", sum_reduce_to_one),
        (
            "fbgemm::batch_index_select_dim0",
            batch_index_select_dim0_abstract,
        ),
        (
            "fbgemm::batch_index_select_dim0_tensor",
            batch_index_select_dim0_tensor_abstract,
        ),
        (
            "fbgemm::batch_index_select_dim0_forward_cuda_impl",
            batch_index_select_dim0_forward_cuda_impl_abstract,
        ),
        (
            "fbgemm::batch_index_select_dim0_tensor_forward_cuda_impl",
            batch_index_select_dim0_tensor_forward_cuda_impl_abstract,
        ),
        (
            "fbgemm::batch_index_select_dim0_tensor_backward_cuda_impl",
            batch_index_select_dim0_tensor_backward_cuda_impl_abstract,
        ),
        (
            "fbgemm::batch_index_select_dim0_backward_cuda_impl",
            batch_index_select_dim0_backward_cuda_impl_abstract,
        ),
        (
            "fbgemm::keyed_jagged_index_select_dim1",
            keyed_jagged_index_select_dim1_abstract,
        ),
        (
            "fbgemm::batch_index_select_dim0_forward_cpu_impl",
            batch_index_select_dim0_forward_cpu_impl_abstract,
        ),
        (
            "fbgemm::batch_index_select_dim0_tensor_forward_cpu_impl",
            batch_index_select_dim0_tensor_forward_cpu_impl_abstract,
        ),
        (
            "fbgemm::batch_index_select_dim0_backward_cpu_impl",
            batch_index_select_dim0_backward_cpu_impl_abstract,
        ),
        ("fbgemm::bounds_check_indices", bounds_check_indices_abstract),
        (
            "fbgemm::group_index_select_dim0_gpu_impl",
            group_index_select_dim0_gpu_impl_abstract,
        ),
        (
            "fbgemm::group_index_select_dim0_gpu_backward",
            group_index_select_dim0_gpu_backward_abstract,
        ),
        (
            "fbgemm::keyed_jagged_index_select_dim1_forward",
            keyed_jagged_index_select_dim1_forward_cuda_impl_abstract,
        ),
        (
            "fbgemm::keyed_jagged_index_select_dim1_backward",
            keyed_jagged_index_select_dim1_backward_cuda_impl_abstract,
        ),
        (
            "fbgemm::permute_pooled_embs_split",
            permute_pooled_embs_split_abstract,
        ),
        (
            "fbgemm::histogram_binning_calibration",
            histogram_binning_calibration_abstract,
        ),
        (
            "fbgemm::generic_histogram_binning_calibration_by_feature",
            generic_histogram_binning_calibration_by_feature,
        ),
        ("fbgemm::lengths_range", lengths_range_abstract),
        (
            "fbgemm::permute_multi_embedding_function",
            permute_multi_embedding_function_impl_abstract,
        ),
        ("fbgemm::FloatToHFP8Quantized", float_to_hfp8_quantized),
        ("fbgemm::HFP8QuantizedToFloat", hfp8_quantized_to_float),
        (
            "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf",
            float_or_half_to_fused_nbit_rowwise_quantized_sbhalf,
        ),
        (
            "fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
            fused_nbit_rowwise_quantized_sb_half_to_float_or_half,
        ),
        (
            "fbgemm::Fused8BitRowwiseQuantizedToFloatOrHalf",
            fused_8_bit_rowwise_quantized_to_float_or_half,
        ),
        # The three 8-bit rowwise quantize ops share one fake implementation.
        (
            "fbgemm::FloatToFused8BitRowwiseQuantized",
            float_or_half_to_fused_8_bit_rowwise,
        ),
        (
            "fbgemm::FloatOrHalfToFused8BitRowwiseQuantized",
            float_or_half_to_fused_8_bit_rowwise,
        ),
        (
            "fbgemm::HalfToFused8BitRowwiseQuantized",
            float_or_half_to_fused_8_bit_rowwise,
        ),
        (
            "fbgemm::Fused8BitRowwiseQuantizedToFloat",
            fused_8_bit_rowwise_quantized_to_float,
        ),
        (
            "fbgemm::Fused8BitRowwiseQuantizedToHalf",
            fused_8_bit_rowwise_quantized_to_half,
        ),
    )
    for op_name, fake_fn in fake_impls:
        _register_fake_if_missing(op_name, fake_fn)

    _setup.done = True
|
|
1453
|
+
|
|
1454
|
+
|
|
1455
|
+
# Perform the fake/autograd operator registrations as soon as this module is
# imported; repeated calls are no-ops once `_setup.done` is True.
_setup()
|