fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,259 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # [fbgemm-gpu.autogen.docs.examples.docstring.start]
+ import torch
+
+ from .common import add_docs
+
+
+ add_docs(
+ torch.ops.fbgemm.jagged_2d_to_dense,
+ """
+ jagged_2d_to_dense(values, x_offsets, max_sequence_length) -> Tensor
+
+ Converts a jagged tensor with a 2D values array into a dense tensor, padding with zeros.
+
+ Args:
+ values (Tensor): 2D tensor containing the values of the jagged tensor.
+
+ x_offsets (Tensor): 1D tensor containing the starting point of each jagged row in the values tensor.
+
+ max_sequence_length (int): Maximum length of any row in the jagged dimension.
+
+ Returns:
+ Tensor: The padded dense tensor
+
+ Example:
+ >>> values = torch.tensor([[1,1],[2,2],[3,3],[4,4]])
+ >>> x_offsets = torch.tensor([0, 1, 3])
+ >>> torch.ops.fbgemm.jagged_2d_to_dense(values, x_offsets, 3)
+ tensor([[[1, 1],
+ [0, 0],
+ [0, 0]],
+ [[2, 2],
+ [3, 3],
+ [0, 0]]])
+
+ """,
+ )
+ # [fbgemm-gpu.autogen.docs.examples.docstring.end]
+
+ add_docs(
+ torch.ops.fbgemm.jagged_1d_to_dense,
+ """
+ jagged_1d_to_dense(values, offsets, max_sequence_length, padding_value) -> Tensor
+
+ Converts a jagged tensor with a 1D values array into a dense tensor, padding with a specified padding value.
+
+ Args:
+ values (Tensor): 1D tensor containing the values of the jagged tensor.
+
+ offsets (Tensor): 1D tensor containing the starting point of each jagged row in the values tensor.
+
+ max_sequence_length (int): Maximum length of any row in the jagged dimension.
+
+ padding_value (int): Value to set in the empty areas of the dense output, outside of the jagged tensor coverage.
+
+ Returns:
+ Tensor: The padded dense tensor
+
+ Example:
+ >>> values = torch.tensor([1,2,3,4])
+ >>> offsets = torch.tensor([0, 1, 3])
+ >>> torch.ops.fbgemm.jagged_1d_to_dense(values, offsets, 3, 0)
+ tensor([[1, 0, 0],
+ [2, 3, 0]])
+
+ """,
+ )
+
+ add_docs(
+ torch.ops.fbgemm.dense_to_jagged,
+ """
+ dense_to_jagged(dense, x_offsets, total_L) -> (Tensor, Tensor[])
+
+ Converts a dense tensor into a jagged tensor, given the desired offsets of the resulting jagged tensor.
+
+ Args:
+ dense (Tensor): A dense input tensor to be converted
+
+ x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+ total_L (int, Optional): Total number of values in the resulting jagged tensor.
+
+ Returns:
+ (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+ Example:
+ >>> dense = torch.tensor([[[1, 1], [0, 0], [0, 0]], [[2, 2], [3, 3], [0, 0]]])
+ >>> x_offsets = torch.tensor([0, 1, 3])
+ >>> torch.ops.fbgemm.dense_to_jagged(dense, [x_offsets])
+ (tensor([[1, 1],
+ [2, 2],
+ [3, 3]]), [tensor([0, 1, 3])])
+
+ """,
+ )
+
+
+ add_docs(
+ torch.ops.fbgemm.jagged_to_padded_dense,
+ """
+ jagged_to_padded_dense(values, offsets, max_lengths, padding_value=0) -> Tensor
+
+ Converts a jagged tensor into a dense tensor, padding with a specified padding value.
+
+ Args:
+ values (Tensor): Jagged tensor values
+
+ offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+ max_lengths (int[]): A list with max_length for each jagged dimension.
+
+ padding_value (float): Value to set in the empty areas of the dense output, outside of the jagged tensor coverage.
+
+ Returns:
+ Tensor: The padded dense tensor
+
+ Example:
+ >>> values = torch.tensor([[1,1],[2,2],[3,3],[4,4]])
+ >>> offsets = torch.tensor([0, 1, 3])
+ >>> torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], [3], 7)
+ tensor([[[1, 1],
+ [7, 7],
+ [7, 7]],
+ [[2, 2],
+ [3, 3],
+ [7, 7]]])
+ """,
+ )
+
+
+ add_docs(
+ torch.ops.fbgemm.jagged_dense_elementwise_add,
+ """
+ jagged_dense_elementwise_add(x_values, x_offsets, y) -> Tensor
+
+ Adds a jagged tensor to a dense tensor, resulting in a dense tensor. The jagged
+ tensor input will be padded with zeros for the purposes of the addition.
+
+ Args:
+ x_values (Tensor): Jagged tensor values
+
+ x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+ y (Tensor): A dense tensor
+
+ Returns:
+ Tensor: The sum of the jagged input tensor and y
+
+ """,
+ )
+
+
+ add_docs(
+ torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output,
+ """
+ jagged_dense_elementwise_add_jagged_output(x_values, x_offsets, y) -> (Tensor, Tensor[])
+
+ Adds a jagged tensor to a dense tensor, resulting in a jagged tensor with the same structure as the input jagged tensor.
+
+ Args:
+ x_values (Tensor): Jagged tensor values
+
+ x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+ y (Tensor): A dense tensor
+
+ Returns:
+ (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+ """,
+ )
+
+
+ add_docs(
+ torch.ops.fbgemm.jagged_dense_dense_elementwise_add_jagged_output,
+ """
+ jagged_dense_dense_elementwise_add_jagged_output(x_values, x_offsets, y_0, y_1) -> (Tensor, Tensor[])
+
+ Adds a jagged tensor to the sum of two dense tensors, resulting in a jagged tensor with the same structure as the input jagged tensor.
+
+ Args:
+ x_values (Tensor): Jagged tensor values
+
+ x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+ y_0 (Tensor): A dense tensor
+
+ y_1 (Tensor): A dense tensor
+
+ Returns:
+ (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+ """,
+ )
+
+
+ add_docs(
+ torch.ops.fbgemm.jagged_dense_elementwise_mul,
+ """
+ jagged_dense_elementwise_mul(x_values, x_offsets, y) -> (Tensor, Tensor[])
+
+ Elementwise-multiplies a jagged tensor by a dense tensor, resulting in a jagged tensor with the same structure as the input jagged tensor.
+
+ Args:
+ x_values (Tensor): Jagged tensor values
+
+ x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+ y (Tensor): A dense tensor
+
+ Returns:
+ (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+ """,
+ )
+
+ add_docs(
+ torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul,
+ """
+ batched_dense_vec_jagged_2d_mul(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor
+
+ Batched vector-matrix multiplication of a batched dense vector with a jagged tensor. The dense vector
+ has size (B * H, max_N) and the jagged tensor has size (B, max_N, H * D), where max_N is the maximum size of
+ the jagged dimension. B * H is the batch size, and each multiplication is of a vector of length max_N with a matrix of size [max_N, D].
+
+ Args:
+ v (Tensor): Dense vector tensor
+
+ a_values (Tensor): Jagged tensor values
+
+ a_offsets (Tensor []): A list of jagged offset tensors, one for each jagged dimension.
+
+ Returns:
+ Tensor: Output of the batched matmul, of size (B * H, D)
+
+ """,
+ )
+
+ # add_docs(
+ # torch.ops.fbgemm.stacked_jagged_1d_to_dense,
+ # """Args:
+ # {input}
+ # Keyword args:
+ # {out}""",
+ # )
+ #
+ #
+ # add_docs(
+ # torch.ops.fbgemm.stacked_jagged_2d_to_dense,
+ # """Args:
+ # {input}
+ # Keyword args:
+ # {out}""",
+ # )
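
For reference, a minimal sketch that composes the jagged-tensor ops documented in this file. Op names, signatures, and example values are taken from the docstrings above; the sketch assumes a CUDA-capable environment with this wheel installed and that importing fbgemm_gpu registers the torch.ops.fbgemm operators.

import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.*

# Two jagged rows of lengths 1 and 2; each value is a 2-vector.
values = torch.tensor([[1, 1], [2, 2], [3, 3]], dtype=torch.float32, device="cuda")
offsets = torch.tensor([0, 1, 3], device="cuda")

# Pad to a dense (B, max_len, D) tensor, then convert back to jagged.
dense = torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], [3], 0.0)
rt_values, rt_offsets = torch.ops.fbgemm.dense_to_jagged(dense, [offsets])

# Elementwise add against a dense tensor of the padded shape, keeping
# the jagged structure in the output.
y = torch.ones_like(dense)
out_values, out_offsets = torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
    values, [offsets], y
)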
@@ -0,0 +1,36 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ from .common import add_docs
+
+ add_docs(
+ torch.ops.fbgemm.merge_pooled_embeddings,
+ """
+ merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim=1) -> Tensor
+
+ Concatenate embedding outputs from different devices (on the same host)
+ onto the target device.
+
+ Args:
+ pooled_embeddings (List[Tensor]): A list of embedding outputs from
+ different devices on the same host. Each output has 2
+ dimensions.
+
+ uncat_dim_size (int): The size of the dimension that is not
+ concatenated, i.e., if `cat_dim=0`, `uncat_dim_size` is the size
+ of dim 1 and vice versa.
+
+ target_device (torch.device): The target device that aggregates all
+ the embedding outputs.
+
+ cat_dim (int = 1): The dimension along which the tensors are concatenated
+
+ Returns:
+ The concatenated embedding output (2D) on the target device
+ """,
+ )
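
A minimal sketch of merge_pooled_embeddings as documented above; the device ids and shapes are illustrative only, and the sketch assumes a host with at least two CUDA devices and this wheel installed.

import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.*

B = 4  # local batch size; the un-concatenated dimension when cat_dim=1
pooled_embeddings = [
    torch.randn(B, 8, device="cuda:0"),
    torch.randn(B, 16, device="cuda:1"),
]

# Concatenate the per-device outputs along dim 1 onto cuda:0,
# producing a (B, 8 + 16) tensor on the target device.
merged = torch.ops.fbgemm.merge_pooled_embeddings(
    pooled_embeddings, B, torch.device("cuda:0"), 1
)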
@@ -0,0 +1,108 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ from .common import add_docs
+
+ add_docs(
+ torch.ops.fbgemm.permute_pooled_embs,
+ """
+ permute_pooled_embs(pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list, inv_permute_list) -> Tensor
+
+ Permute embedding outputs along the feature dimension.
+
+ The embedding output tensor `pooled_embs` contains the embedding outputs
+ for all features in a batch. It is represented in a 2D format, where the
+ rows are the batch size dimension and the columns are the feature *
+ embedding dimension. Permuting along the feature dimension is
+ essentially permuting along the second dimension (dim 1).
+
+ Args:
+ pooled_embs (Tensor): The embedding outputs to permute. Shape is
+ `(B_local, total_global_D)`, where `B_local` = a local batch size
+ and `total_global_D` is the total embedding dimension across all
+ features (global)
+
+ offset_dim_list (Tensor): The complete cumulative sum of embedding
+ dimensions of all features. Shape is `T + 1` where `T` is the
+ total number of features
+
+ permute_list (Tensor): A tensor that describes how each feature is
+ permuted. `permute_list[i]` indicates that the feature
+ `permute_list[i]` is permuted to position `i`
+
+ inv_offset_dim_list (Tensor): The complete cumulative sum of inverse
+ embedding dimensions, which are the permuted embedding dimensions.
+ `inv_offset_dim_list[i]` represents the starting embedding position of
+ feature `permute_list[i]`
+
+ inv_permute_list (Tensor): The inverse permute list, which contains the
+ permuted positions of each feature. `inv_permute_list[i]` represents
+ the permuted position of feature `i`
+
+ Returns:
+ Permuted embedding outputs (Tensor). Same shape as `pooled_embs`
+
+ **Example:**
+
+ >>> import torch
+ >>> from itertools import accumulate
+ >>>
+ >>> # Suppose batch size = 3 and there are 3 features
+ >>> batch_size = 3
+ >>>
+ >>> # Embedding dimensions for each feature
+ >>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
+ >>>
+ >>> # Permute list, i.e., move feature 2 to position 0, move feature 0
+ >>> # to position 1, so on
+ >>> permute = torch.tensor([2, 0, 1], dtype=torch.int64, device="cuda")
+ >>>
+ >>> # Compute embedding dim offsets
+ >>> offset_dim_list = torch.tensor([0] + list(accumulate(embs_dims)), dtype=torch.int64, device="cuda")
+ >>> print(offset_dim_list)
+ >>>
+ tensor([ 0, 4, 8, 16], device='cuda:0')
+ >>>
+ >>> # Compute inverse embedding dims
+ >>> inv_embs_dims = [embs_dims[p] for p in permute]
+ >>> # Compute complete cumulative sum of inverse embedding dims
+ >>> inv_offset_dim_list = torch.tensor([0] + list(accumulate(inv_embs_dims)), dtype=torch.int64, device="cuda")
+ >>> print(inv_offset_dim_list)
+ >>>
+ tensor([ 0, 8, 12, 16], device='cuda:0')
+ >>>
+ >>> # Compute inverse permutes
+ >>> inv_permute = [0] * len(permute)
+ >>> for i, p in enumerate(permute):
+ >>> inv_permute[p] = i
+ >>> inv_permute_list = torch.tensor([inv_permute], dtype=torch.int64, device="cuda")
+ >>> print(inv_permute_list)
+ >>>
+ tensor([[1, 2, 0]], device='cuda:0')
+ >>>
+ >>> # Generate an example input
+ >>> pooled_embs = torch.arange(embs_dims.sum().item() * batch_size, dtype=torch.float32, device="cuda").reshape(batch_size, -1)
+ >>> print(pooled_embs)
+ >>>
+ tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
+ 14., 15.],
+ [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
+ 30., 31.],
+ [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
+ 46., 47.]], device='cuda:0')
+ >>>
+ >>> torch.ops.fbgemm.permute_pooled_embs_auto_grad(pooled_embs, offset_dim_list, permute, inv_offset_dim_list, inv_permute_list)
+ >>>
+ tensor([[ 8., 9., 10., 11., 12., 13., 14., 15., 0., 1., 2., 3., 4., 5.,
+ 6., 7.],
+ [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
+ 22., 23.],
+ [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
+ 38., 39.]], device='cuda:0')
+ """,
+ )
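
The docstring example above builds the four auxiliary tensors by hand. The hypothetical helper below (not part of the package) packages that same construction, mirroring the example step for step, including the extra outer list it places around inv_permute.

from itertools import accumulate
from typing import List, Tuple

import torch


def build_permute_metadata(
    embs_dims: List[int], permute: List[int], device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Cumulative sum of embedding dims: where each feature starts in dim 1.
    offset_dim_list = torch.tensor(
        [0] + list(accumulate(embs_dims)), dtype=torch.int64, device=device
    )
    # Embedding dims reordered by the permutation, and their cumulative sum.
    inv_embs_dims = [embs_dims[p] for p in permute]
    inv_offset_dim_list = torch.tensor(
        [0] + list(accumulate(inv_embs_dims)), dtype=torch.int64, device=device
    )
    # Inverse permutation: inv_permute[p] is the new position of feature p.
    inv_permute = [0] * len(permute)
    for i, p in enumerate(permute):
        inv_permute[p] = i
    permute_list = torch.tensor(permute, dtype=torch.int64, device=device)
    # Wrapped in an outer list, mirroring the docstring example above.
    inv_permute_list = torch.tensor([inv_permute], dtype=torch.int64, device=device)
    return offset_dim_list, permute_list, inv_offset_dim_list, inv_permute_list

With the docstring's values, build_permute_metadata([4, 4, 8], [2, 0, 1]) reproduces the offset_dim_list, inv_offset_dim_list, and inv_permute_list tensors printed in the example.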
@@ -0,0 +1,41 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ from .common import add_docs
+
+ add_docs(
+ torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf,
+ """
+ FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(input, bit_rate) -> Tensor
+
+ Convert FP32/16 to INT8/4/2 using rowwise quantization.
+
+ Args:
+ input (Tensor): An input tensor. Must be either FP32 (`torch.float`)
+ or FP16 (`torch.half`) and must have 2 dimensions.
+
+ bit_rate (int): Quantized bit rate (2 for INT2, 4 for INT4, or 8 for
+ INT8)
+
+ Returns:
+ Quantized output (Tensor). Data type is `torch.uint8` (byte type)
+
+ **Example:**
+
+ >>> # Randomize input
+ >>> input = torch.randn(2, 4, dtype=torch.float32, device="cuda")
+ >>> print(input)
+ tensor([[ 0.8247, 0.0031, -1.0068, -1.2081],
+ [ 0.5427, 1.5772, 1.0291, -0.7626]], device='cuda:0')
+ >>> # Quantize
+ >>> output = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(input, bit_rate=4)
+ >>> print(output)
+ tensor([[159, 1, 86, 48, 213, 188],
+ [248, 11, 254, 48, 26, 186]], device='cuda:0', dtype=torch.uint8)
+ """,
+ )
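
A short sketch of the quantization op documented above, assuming a CUDA environment with this wheel installed. The 6-column uint8 output in the docstring example is consistent with each row packing the 4-bit values (2 bytes for 4 columns) followed by an FP16 scale and FP16 bias (4 bytes).

import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.*

x = torch.randn(2, 4, dtype=torch.float32, device="cuda")

# Row-wise 4-bit quantization; per-row scale/bias are stored at the
# end of each output row.
q = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(x, bit_rate=4)
print(q.shape, q.dtype)  # torch.Size([2, 6]) torch.uint8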