fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
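The diff hunks below are the new fbgemm_gpu/docs/* modules, which attach docstrings to operators in the torch.ops.fbgemm namespace. As a minimal setup sketch (assuming the wheel above is installed into a CUDA-enabled Python 3.10 environment, and that importing fbgemm_gpu loads the bundled native libraries and registers the operators), the documented ops become callable like this:

# Setup sketch; assumes the wheel from this release is installed (pip install <wheel file>)
# and that importing fbgemm_gpu registers torch.ops.fbgemm.* (an assumption, not shown in this diff).
import torch
import fbgemm_gpu  # noqa: F401

values = torch.tensor([[1, 1], [2, 2], [3, 3], [4, 4]])
x_offsets = torch.tensor([0, 1, 3])
dense = torch.ops.fbgemm.jagged_2d_to_dense(values, x_offsets, 3)  # documented below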
fbgemm_gpu/docs/jagged_tensor_ops.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# [fbgemm-gpu.autogen.docs.examples.docstring.start]
+import torch
+
+from .common import add_docs
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_2d_to_dense,
+    """
+jagged_2d_to_dense(values, x_offsets, max_sequence_length) -> Tensor
+
+Converts a jagged tensor, with a 2D values array, into a dense tensor, padding with zeros.
+
+Args:
+    values (Tensor): 2D tensor containing the values of the jagged tensor.
+
+    x_offsets (Tensor): 1D tensor containing the starting point of each jagged row in the values tensor.
+
+    max_sequence_length (int): Maximum length of any row in the jagged dimension.
+
+Returns:
+    Tensor: The padded dense tensor
+
+Example:
+    >>> values = torch.tensor([[1,1],[2,2],[3,3],[4,4]])
+    >>> x_offsets = torch.tensor([0, 1, 3])
+    >>> torch.ops.fbgemm.jagged_2d_to_dense(values, x_offsets, 3)
+    tensor([[[1, 1],
+             [0, 0],
+             [0, 0]],
+            [[2, 2],
+             [3, 3],
+             [0, 0]]])
+
+""",
+)
+# [fbgemm-gpu.autogen.docs.examples.docstring.end]
+
+add_docs(
+    torch.ops.fbgemm.jagged_1d_to_dense,
+    """
+jagged_1d_to_dense(values, offsets, max_sequence_length, padding_value) -> Tensor
+
+Converts a jagged tensor, with a 1D values array, into a dense tensor, padding with a specified padding value.
+
+Args:
+    values (Tensor): 1D tensor containing the values of the jagged tensor.
+
+    offsets (Tensor): 1D tensor containing the starting point of each jagged row in the values tensor.
+
+    max_sequence_length (int): Maximum length of any row in the jagged dimension.
+
+    padding_value (int): Value to set in the empty areas of the dense output, outside of the jagged tensor coverage.
+
+Returns:
+    Tensor: the padded dense tensor
+
+Example:
+    >>> values = torch.tensor([1,2,3,4])
+    >>> offsets = torch.tensor([0, 1, 3])
+    >>> torch.ops.fbgemm.jagged_1d_to_dense(values, offsets, 3, 0)
+    tensor([[1, 0, 0],
+            [2, 3, 0]])
+
+""",
+)
+
+add_docs(
+    torch.ops.fbgemm.dense_to_jagged,
+    """
+dense_to_jagged(dense, x_offsets, total_L) -> (Tensor, Tensor[])
+
+Converts a dense tensor into a jagged tensor, given the desired offsets of the resulting jagged tensor.
+
+Args:
+    dense (Tensor): A dense input tensor to be converted
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    total_L (int, Optional): Total number of values in the resulting jagged tensor.
+
+Returns:
+    (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+Example:
+    >>> dense = torch.tensor([[[1, 1], [0, 0], [0, 0]], [[2, 2], [3, 3], [0, 0]]])
+    >>> x_offsets = torch.tensor([0, 1, 3])
+    >>> torch.ops.fbgemm.dense_to_jagged(dense, [x_offsets])
+    (tensor([[1, 1],
+             [2, 2],
+             [3, 3]]), [tensor([0, 1, 3])])
+
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_to_padded_dense,
+    """
+jagged_to_padded_dense(values, offsets, max_lengths, padding_value=0) -> Tensor
+
+Converts a jagged tensor into a dense tensor, padding with a specified padding value.
+
+Args:
+    values (Tensor): Jagged tensor values
+
+    offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    max_lengths (int[]): A list with max_length for each jagged dimension.
+
+    padding_value (float): Value to set in the empty areas of the dense output, outside of the jagged tensor coverage.
+
+Returns:
+    Tensor: the padded dense tensor
+
+Example:
+    >>> values = torch.tensor([[1,1],[2,2],[3,3],[4,4]])
+    >>> offsets = torch.tensor([0, 1, 3])
+    >>> torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], [3], 7)
+    tensor([[[1, 1],
+             [7, 7],
+             [7, 7]],
+            [[2, 2],
+             [3, 3],
+             [7, 7]]])
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_dense_elementwise_add,
+    """
+jagged_dense_elementwise_add(x_values, x_offsets, y) -> Tensor
+
+Adds a jagged tensor to a dense tensor, resulting in a dense tensor. The jagged
+tensor input will be padded with zeros for the purposes of the addition.
+
+Args:
+    x_values (Tensor): Jagged tensor values
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    y (Tensor): A dense tensor
+
+Returns:
+    Tensor: The sum of the jagged input tensor and y
+
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output,
+    """
+jagged_dense_elementwise_add_jagged_output(x_values, x_offsets, y) -> (Tensor, Tensor[])
+
+Adds a jagged tensor to a dense tensor, resulting in a jagged tensor with the same structure as the input jagged tensor.
+
+Args:
+    x_values (Tensor): Jagged tensor values
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    y (Tensor): A dense tensor
+
+Returns:
+    (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_dense_dense_elementwise_add_jagged_output,
+    """
+jagged_dense_dense_elementwise_add_jagged_output(x_values, x_offsets, y_0, y_1) -> (Tensor, Tensor[])
+
+Adds a jagged tensor to the sum of two dense tensors, resulting in a jagged tensor with the same structure as the input jagged tensor.
+
+Args:
+    x_values (Tensor): Jagged tensor values
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    y_0 (Tensor): A dense tensor
+
+    y_1 (Tensor): A dense tensor
+
+Returns:
+    (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_dense_elementwise_mul,
+    """
+jagged_dense_elementwise_mul(x_values, x_offsets, y) -> (Tensor, Tensor[])
+
+Elementwise-multiplies a jagged tensor with a dense tensor, resulting in a jagged tensor with the same structure as the input jagged tensor.
+
+Args:
+    x_values (Tensor): Jagged tensor values
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    y (Tensor): A dense tensor
+
+Returns:
+    (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+""",
+)
+
+add_docs(
+    torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul,
+    """
+batched_dense_vec_jagged_2d_mul(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor
+
+Batched vector-matrix multiplication of a batched dense vector with a jagged tensor. The dense
+vector has size (B * H, max_N) and the jagged tensor has size (B, max_N, H * D), where max_N is
+the maximum size of the jagged dimension. B * H is the batch size, and each multiplication is of
+a max_N vector with a [max_N, D] matrix.
+
+Args:
+    v (Tensor): dense vector tensor
+
+    a_values (Tensor): Jagged tensor values
+
+    a_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+Returns:
+    Tensor: output of batch matmul in size (B * H, D)
+
+""",
+)
+
+# add_docs(
+#    torch.ops.fbgemm.stacked_jagged_1d_to_dense,
+#    """Args:
+#                {input}
+#            Keyword args:
+#                {out}""",
+# )
+#
+#
+# add_docs(
+#    torch.ops.fbgemm.stacked_jagged_2d_to_dense,
+#    """Args:
+#                {input}
+#            Keyword args:
+#                {out}""",
+# )
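The elementwise add/mul ops above ship without inline examples. As a point of reference only, here is a minimal pure-PyTorch sketch of the semantics described for jagged_dense_elementwise_add (the jagged input is zero-padded to the dense shape before adding); the helper name and shapes are illustrative, not part of the package:

import torch

def jagged_dense_add_reference(x_values, x_offsets, y):
    # Zero-pad the jagged rows to the dense shape of y, then add elementwise.
    # Reference sketch of the documented behaviour; not the fbgemm kernel.
    padded = torch.zeros_like(y)
    for b in range(y.shape[0]):
        start, end = int(x_offsets[b]), int(x_offsets[b + 1])
        padded[b, : end - start] = x_values[start:end]
    return padded + y

x_values = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
x_offsets = torch.tensor([0, 1, 3])
y = torch.ones(2, 3, 2)  # dense tensor of shape (B=2, max_len=3, D=2)
print(jagged_dense_add_reference(x_values, x_offsets, y))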
fbgemm_gpu/docs/merge_pooled_embedding_ops.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .common import add_docs
+
+add_docs(
+    torch.ops.fbgemm.merge_pooled_embeddings,
+    """
+merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim=1) -> Tensor
+
+Concatenate embedding outputs from different devices (on the same host)
+onto the target device.
+
+Args:
+    pooled_embeddings (List[Tensor]): A list of embedding outputs from
+        different devices on the same host. Each output has 2
+        dimensions.
+
+    uncat_dim_size (int): The size of the dimension that is not
+        concatenated, i.e., if `cat_dim=0`, `uncat_dim_size` is the size
+        of dim 1 and vice versa.
+
+    target_device (torch.device): The target device that aggregates all
+        the embedding outputs.
+
+    cat_dim (int = 1): The dimension along which the tensors are concatenated
+
+Returns:
+    The concatenated embedding output (2D) on the target device
+""",
+)
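A short usage sketch for merge_pooled_embeddings as documented above (assumes at least two CUDA devices are present; the shapes are illustrative, and uncat_dim_size is the batch size because cat_dim=1):

import torch
import fbgemm_gpu  # assumed to register torch.ops.fbgemm.*

batch_size = 2
e0 = torch.randn(batch_size, 4, device="cuda:0")  # pooled output from device 0
e1 = torch.randn(batch_size, 8, device="cuda:1")  # pooled output from device 1

# Concatenate along dim 1 (features) onto cuda:0.
merged = torch.ops.fbgemm.merge_pooled_embeddings(
    [e0, e1], batch_size, torch.device("cuda:0"), 1
)
print(merged.shape)  # expected torch.Size([2, 12]) on cuda:0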
fbgemm_gpu/docs/permute_pooled_embedding_ops.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .common import add_docs
+
+add_docs(
+    torch.ops.fbgemm.permute_pooled_embs,
+    """
+permute_pooled_embs(pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list, inv_permute_list) -> Tensor
+
+Permute embedding outputs along the feature dimension.
+
+The embedding output tensor `pooled_embs` contains the embedding outputs
+for all features in a batch. It is represented in a 2D format, where the
+rows are the batch size dimension and the columns are the feature *
+embedding dimension. Permuting along the feature dimension is
+essentially permuting along the second dimension (dim 1).
+
+Args:
+    pooled_embs (Tensor): The embedding outputs to permute. Shape is
+        `(B_local, total_global_D)`, where `B_local` = a local batch size
+        and `total_global_D` is the total embedding dimension across all
+        features (global)
+
+    offset_dim_list (Tensor): The complete cumulative sum of embedding
+        dimensions of all features. Shape is `T + 1` where `T` is the
+        total number of features
+
+    permute_list (Tensor): A tensor that describes how each feature is
+        permuted. `permute_list[i]` indicates that the feature
+        `permute_list[i]` is permuted to position `i`
+
+    inv_offset_dim_list (Tensor): The complete cumulative sum of inverse
+        embedding dimensions, which are the permuted embedding dimensions.
+        `inv_offset_dim_list[i]` represents the starting embedding position of
+        feature `permute_list[i]`
+
+    inv_permute_list (Tensor): The inverse permute list, which contains the
+        permuted positions of each feature. `inv_permute_list[i]` represents
+        the permuted position of feature `i`
+
+Returns:
+    Permuted embedding outputs (Tensor). Same shape as `pooled_embs`
+
+**Example:**
+
+    >>> import torch
+    >>> from itertools import accumulate
+    >>>
+    >>> # Suppose batch size = 3 and there are 3 features
+    >>> batch_size = 3
+    >>>
+    >>> # Embedding dimensions for each feature
+    >>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
+    >>>
+    >>> # Permute list, i.e., move feature 2 to position 0, move feature 0
+    >>> # to position 1, so on
+    >>> permute = torch.tensor([2, 0, 1], dtype=torch.int64, device="cuda")
+    >>>
+    >>> # Compute embedding dim offsets
+    >>> offset_dim_list = torch.tensor([0] + list(accumulate(embs_dims)), dtype=torch.int64, device="cuda")
+    >>> print(offset_dim_list)
+    >>>
+    tensor([ 0,  4,  8, 16], device='cuda:0')
+    >>>
+    >>> # Compute inverse embedding dims
+    >>> inv_embs_dims = [embs_dims[p] for p in permute]
+    >>> # Compute complete cumulative sum of inverse embedding dims
+    >>> inv_offset_dim_list = torch.tensor([0] + list(accumulate(inv_embs_dims)), dtype=torch.int64, device="cuda")
+    >>> print(inv_offset_dim_list)
+    >>>
+    tensor([ 0,  8, 12, 16], device='cuda:0')
+    >>>
+    >>> # Compute inverse permutes
+    >>> inv_permute = [0] * len(permute)
+    >>> for i, p in enumerate(permute):
+    >>>     inv_permute[p] = i
+    >>> inv_permute_list = torch.tensor([inv_permute], dtype=torch.int64, device="cuda")
+    >>> print(inv_permute_list)
+    >>>
+    tensor([[1, 2, 0]], device='cuda:0')
+    >>>
+    >>> # Generate an example input
+    >>> pooled_embs = torch.arange(embs_dims.sum().item() * batch_size, dtype=torch.float32, device="cuda").reshape(batch_size, -1)
+    >>> print(pooled_embs)
+    >>>
+    tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
+             14., 15.],
+            [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
+             30., 31.],
+            [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
+             46., 47.]], device='cuda:0')
+    >>>
+    >>> torch.ops.fbgemm.permute_pooled_embs_auto_grad(pooled_embs, offset_dim_list, permute, inv_offset_dim_list, inv_permute_list)
+    >>>
+    tensor([[ 8.,  9., 10., 11., 12., 13., 14., 15.,  0.,  1.,  2.,  3.,  4.,  5.,
+              6.,  7.],
+            [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
+             22., 23.],
+            [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
+             38., 39.]], device='cuda:0')
+""",
+)
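For intuition, the permutation in the example above can be reproduced with plain PyTorch by splitting the columns per feature and concatenating them back in permuted order. This is a sketch of the semantics on CPU, not the fbgemm kernel:

import torch

embs_dims = [4, 4, 8]  # per-feature embedding dims from the example
permute = [2, 0, 1]    # feature permute[i] moves to position i
pooled_embs = torch.arange(sum(embs_dims) * 3, dtype=torch.float32).reshape(3, -1)

# Split the columns into per-feature chunks, then re-concatenate them in permuted order.
chunks = torch.split(pooled_embs, embs_dims, dim=1)
permuted = torch.cat([chunks[p] for p in permute], dim=1)
# Row 0 now starts with feature 2's values (8., 9., ..., 15.), matching the documented output.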
fbgemm_gpu/docs/quantize_ops.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .common import add_docs
+
+add_docs(
+    torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf,
+    """
+FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(input, bit_rate) -> Tensor
+
+Convert FP32/16 to INT8/4/2 using rowwise quantization.
+
+Args:
+    input (Tensor): An input tensor. Must be either FP32 (`torch.float`)
+        or FP16 (`torch.half`) and must have 2 dimensions.
+
+    bit_rate (int): Quantized bit rate (2 for INT2, 4 for INT4, or 8 for
+        INT8)
+
+Returns:
+    Quantized output (Tensor). Data type is `torch.uint8` (byte type)
+
+**Example:**
+
+    >>> # Randomize input
+    >>> input = torch.randn(2, 4, dtype=torch.float32, device="cuda")
+    >>> print(input)
+    tensor([[ 0.8247,  0.0031, -1.0068, -1.2081],
+            [ 0.5427,  1.5772,  1.0291, -0.7626]], device='cuda:0')
+    >>> # Quantize
+    >>> output = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(input, bit_rate=4)
+    >>> print(output)
+    tensor([[159,   1,  86,  48, 213, 188],
+            [248,  11, 254,  48,  26, 186]], device='cuda:0', dtype=torch.uint8)
+""",
+)
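The 6-column uint8 output in the example above is consistent with a fused rowwise layout in which the packed quantized values of each row are followed by an FP16 scale and FP16 bias (the "SBHalf" in the op name). Under that assumed layout, the output width works out as sketched below:

def fused_nbit_rowwise_output_cols(num_cols: int, bit_rate: int) -> int:
    # Assumed layout: packed n-bit values, then 2 bytes of FP16 scale + 2 bytes of FP16 bias per row.
    data_bytes = (num_cols * bit_rate + 7) // 8
    return data_bytes + 4

print(fused_nbit_rowwise_output_cols(4, 4))  # 6, matching the example's output width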