fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/quantize.py ADDED
@@ -0,0 +1,307 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ # Helper functions for using FBGEMM quantized operators.
+
+
+ import torch
+
+ from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
+
+
+ def pack_int4(x: torch.Tensor) -> torch.Tensor:
+     # Given int8 x, pack adjacent int4 values into a single int8.
+     low_x = x[:, ::2]
+     high_x = x[:, 1::2]
+
+     # High bits need to left shift, this also masks off extra bits.
+     high_x = torch.bitwise_left_shift(high_x, 4)
+     # Low bits need to have sign bits removed.
+     low_x = torch.bitwise_and(low_x, 0xF)
+
+     # Recombine into a single value with bitwise or.
+     return torch.bitwise_or(low_x, high_x).contiguous()
+
+
+ def int4_row_quantize_zp(
+     x: torch.Tensor,
+     group_size: int = 128,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     n_bit = 4  # Number of target bits.
+     # Split input into chunks of group_size. This approach allows K that isn't divisible by group_size.
+     to_quant = torch.split(x.to(torch.float), group_size, dim=-1)
+
+     max_val = [chunk.amax(dim=1, keepdim=True) for chunk in to_quant]
+     min_val = [chunk.amin(dim=1, keepdim=True) for chunk in to_quant]
+     max_int = 2**n_bit - 1
+     min_int = 0
+     scales = [
+         (max_chunk - min_chunk).clamp(min=1e-6) / max_int
+         for max_chunk, min_chunk in zip(max_val, min_val)
+     ]
+
+     zeros = [
+         min_chunk + scale_chunk * (2 ** (n_bit - 1))
+         for min_chunk, scale_chunk in zip(min_val, scales)
+     ]
+
+     out = [
+         chunk.sub(min_chunk).div(scale_chunk).round().clamp_(min_int, max_int)
+         for chunk, min_chunk, scale_chunk in zip(to_quant, min_val, scales)
+     ]
+
+     # Recenter output and move to int8.
+     out = [(chunk - 2 ** (n_bit - 1)).to(dtype=torch.int8) for chunk in out]
+
+     # Recombine chunks.
+     out = torch.cat(out, dim=-1)
+
+     # Cutlass expects column major layout for scale and zero point,
+     # so we transpose here and make them contiguous.
+     scales = torch.cat(scales, dim=-1).t().contiguous()
+     zeros = torch.cat(zeros, dim=-1).t().contiguous()
+
+     return out, scales, zeros
+
+
+ def int4_row_quantize(
+     x: torch.Tensor,
+     group_size: int = 128,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Helper function to quantize a tensor to int4 with groupwise scales.
+
+     Args:
+         x (Tensor): [N, K] Higher precision weight tensor to quantize.
+         group_size (int): Number of elements to calculate group scale for.
+     Returns:
+         wq (Tensor): [N, K] Quantized int4 tensor stored in int8 elements.
+         group_scale (Tensor): [K / group_size, N] FP32 Scale per group.
+     """
+     n_bit = 4  # Number of target bits.
+     # Split input into chunks of group_size. This approach allows K that isn't divisible by group_size.
+     to_quant = torch.split(x.to(torch.float), group_size, dim=-1)
+
+     max_val = [torch.abs(chunk).amax(dim=-1, keepdim=True) for chunk in to_quant]
+     max_int = 2 ** (n_bit - 1)
+     min_int = -(2 ** (n_bit - 1))
+     scales = [chunk.clamp(min=1e-6) / max_int for chunk in max_val]
+
+     out = [
+         chunk.div(chunk_scale).round().clamp_(min_int, max_int - 1)
+         for chunk, chunk_scale in zip(to_quant, scales)
+     ]
+     # Recombine chunks.
+     out = torch.cat(out, dim=-1)
+
+     # Cast to int8 and restore shape.
+     out = out.to(dtype=torch.int8)
+
+     # Scales should be in [num_groups, N] layout.
+     scales = torch.cat(scales, dim=-1).t().contiguous()
+
+     return out, scales
+
+
+ def quantize_int4_preshuffle(
+     w: torch.Tensor, group_size: int = 128, dtype: str = "fp8", use_zp: bool = True
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+     """
+     Quantizes an input weight tensor to int4 using preshuffling and scale packing.
+     This function is intended to be used with FBGEMM's mixed dtype kernels and is expected
+     to be applied to weights ahead of time. As such, it is not perfectly optimized.
+
+     Args:
+         w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension.
+         group_size (int): Number of elements to calculate group scale for, must be at least 128.
+         dtype (str): Type of corresponding activations. Must be fp8 or bf16.
+         use_zp (bool): If true, uses zero points during weight quantization. Only relevant for bf16 currently.
+     Returns:
+         wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements.
+         scales (Tuple[Tensor]): Scale tensors for the specified activation type. When FP8 is used,
+         scales is a tuple of row_scale ([N]) and group_scale ([K / group_size, 8, N]). When BF16 is
+         used, scales is a tuple of group_scale ([K / group_size, N]) and group_zero ([K / group_size, N]).
+     """
+
+     def _quantize(
+         w: torch.Tensor, dtype: str = "fp8"
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+
+         if dtype == "fp8":
+             # Start by lowering weights to FP8 and producing row scales.
+             wq, row_scale = quantize_fp8_row(w)
+
+             # Now reduce to INT4.
+             wq, group_scale = int4_row_quantize(wq, group_size)
+             # Reduce group scale to FP8.
+             group_scale = group_scale.to(torch.float8_e4m3fn)
+             # Take quantized weights and pack them efficiently.
+             wq = pack_int4(wq)
+             # Finally pack weights and scales into efficient preshuffled format.
+             wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
+             return wq, (group_scale, row_scale)
+
+         elif dtype == "bf16":
+             if use_zp:
+                 wq, group_scale, group_zero = int4_row_quantize_zp(w, group_size)
+             else:
+                 wq, group_scale = int4_row_quantize(w, group_size)
+                 group_zero = torch.zeros_like(group_scale)
+             # Set scales to activation type.
+             group_scale = group_scale.to(torch.bfloat16)
+             group_zero = group_zero.to(torch.bfloat16)
+             # Take quantized weights and pack them efficiently.
+             wq = pack_int4(wq)
+             # Finally pack weights and scales into efficient preshuffled format.
+             wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
+             return wq, (group_scale, group_zero)
+         else:
+             raise NotImplementedError("Only fp8 and bf16 activations supported.")
+
+     if w.ndim >= 3:
+         orig_shape = w.shape
+         # Flatten to 3 dimensions then iterate over batches.
+         wq, scales = zip(*[_quantize(i, dtype=dtype) for i in w])
+         wq = torch.stack(wq).view(*orig_shape[:-2], *wq[0].shape)
+         # Decompose then stack scales back into a tuple.
+         a_scales, b_scales = zip(*scales)
+         scales = (
+             torch.stack(a_scales).view(*orig_shape[:-2], *a_scales[0].shape),
+             torch.stack(b_scales).view(*orig_shape[:-2], *b_scales[0].shape),
+         )
+     else:
+         wq, scales = _quantize(w, dtype=dtype)
+
+     return wq, scales
+
+
+ def shuffle_slice(
+     x: torch.Tensor, dim: int, start: int, length: int, dtype: str = "fp8"
+ ) -> torch.Tensor:
+     """
+     Helper function to slice a preshuffled int4 tensor. This is needed since the shuffling
+     reorders rows based on the size of the input. Slicing a tensor shuffled for a larger input
+     is no longer valid. We must reorder the tensor to the appropriate size then slice.
+     Args:
+         x (Tensor): [N, K // 2] Preshuffled int4 tensor.
+         dim (int): Dimension to slice.
+         start (int): Start of slice.
+         length (int): Number of elements to slice in the original [N, K] dimension.
+         dtype (str): Type of corresponding activations. Must be fp8 or bf16.
+     Returns:
+         sliced (Tensor): [length, K // 2] Sliced tensor.
+     """
+     # Get the size of the input tensor.
+     assert dim in [x.ndim - 2, x.ndim - 1], "Only slicing along N or K is supported."
+     assert length % 16 == 0, "Slicing must be a multiple of 16."
+     orig_shape = x.shape
+     N = x.shape[-2]
+     K = x.shape[-1]
+     # Tile shape is based on the activation dtype.
+     assert dtype in ("fp8", "bf16"), "Only fp8 and bf16 activations supported."
+     # Handle slice along N
+     if dim == x.ndim - 2:
+         tile_shape = 8 if dtype == "fp8" else 16
+         block_size = N // length
+         # View the shape in terms of shuffled tiles then permute to allow slicing.
+         x_s = x.view(-1, tile_shape, block_size, length // tile_shape, K)
+         x_s = x_s.permute(0, 2, 1, 3, 4).contiguous().view(-1, N, K)
+         out_slice = x_s.narrow(1, start, length)
+         # Reshape back to original shape.
+         return out_slice.view(*orig_shape[:-2], length, K)
+     # Handle slice along K
+     else:
+         outer_dim = x.view(-1, N, K).shape[0]
+         x_s = x.view(outer_dim, -1, length // 2)
+         row_factor = x_s.shape[1] * (length // 2) // K
+         # Take slices of rows corresponding to column slice.
+         return x_s.narrow(1, start * 2 * K // length, row_factor).view(
+             *orig_shape[:-2], N, length // 2
+         )
+
+
+ def scale_nvfp4_quant(
+     input: torch.Tensor, input_global_scale: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Quantize input tensor to FP4 and return quantized tensor and scale.
+     This function quantizes the last dimension of the given tensor `input`. For
+     every 16 consecutive elements, a single dynamically computed scaling factor
+     is shared. This scaling factor is quantized using the `input_global_scale`
+     and is stored in a swizzled layout (see
+     https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
+     Args:
+         input: The input tensor to be quantized to FP4
+         input_global_scale: A scalar scaling factor for the entire tensor.
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4, with every
+         two values packed into a uint8, and the float8_e4m3 scaling factors
+         in the swizzled layout.
+     """
+     assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
+     other_dims = 1 if input.ndim == 1 else -1
+     input = input.reshape(other_dims, input.shape[-1])
+     m, n = input.shape
+     block_size = 16
+     device = input.device
+
+     assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
+     assert input.dtype in (
+         torch.float16,
+         torch.bfloat16,
+     ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
+
+     # Two fp4 values will be packed into an uint8.
+     output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
+
+     # We use the rounded values to store the swizzled values. Due to the
+     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
+     # So, we first pad the scales to multiples of 128 and 4. Then, the scales
+     # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
+     # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
+     def round_up(x: int, y: int) -> int:
+         return (x + y - 1) // y * y
+
+     rounded_m = round_up(m, 128)
+     scale_n = n // block_size
+     rounded_n = round_up(scale_n, 4)
+     output_scale = torch.empty(
+         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+     )
+
+     torch.ops.fbgemm.scaled_fp4_quant(output, input, output_scale, input_global_scale)
+     output_scale = output_scale.view(torch.float8_e4m3fn)
+     return output, output_scale
+
+
+ def ck_preshuffle(src: torch.Tensor, NXdl: int = 16) -> torch.Tensor:
+     """
+     Applies shuffling to make weights more efficient for use with CK kernels.
+     Args:
+         src (torch.Tensor): Input tensor with dtype float8_e4m3fnuz.
+         NXdl (int): Wave tile size along N.
+     Returns:
+         torch.Tensor: The shuffled tensor.
+     """
+     # Check input datatype
+     if src.dtype != torch.float8_e4m3fnuz:
+         raise TypeError("Input must be type float8_e4m3fnuz.")
+     N, K = src.shape
+     KPack = 16
+     NLane = NXdl
+     KLane = 64 // NLane
+     K0 = K // (KLane * KPack)
+     # Reshape src to enable the required permutation
+     # Original shape: (N, K)
+     # Desired intermediate shape for permutation: (N0, NLane, K0, KLane, KPack)
+     src = src.reshape(N // NLane, NLane, K0, KLane, KPack)
+     # Apply permutation: (N0, NLane, K0, KLane, KPack) -> (N0, K0, KLane, NLane, KPack)
+     dst = src.permute(0, 2, 3, 1, 4).contiguous()
+     # Reshape to original input shape.
+     dst = dst.reshape(N, K)
+     return dst
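
For reference, a minimal usage sketch of the int4 helpers in the file above, assuming the wheel installs cleanly and its native libraries load. It exercises only the pure-PyTorch helpers (int4_row_quantize and pack_int4); quantize_int4_preshuffle additionally calls quantize_fp8_row and the torch.ops.fbgemm.preshuffle_i4 operator, which need the bundled GPU kernels. Shapes and values here are illustrative only.

    import torch
    from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize, pack_int4

    # Illustrative [N, K] weight; K is split into group_size-wide chunks for scaling.
    w = torch.randn(64, 256, dtype=torch.bfloat16)

    wq, group_scale = int4_row_quantize(w, group_size=128)
    print(wq.shape, wq.dtype)    # torch.Size([64, 256]) torch.int8 (values in [-8, 7])
    print(group_scale.shape)     # torch.Size([2, 64]) == [K / group_size, N]

    packed = pack_int4(wq)       # two int4 values packed per int8 byte
    print(packed.shape)          # torch.Size([64, 128])
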
fbgemm_gpu/fbgemm.so ADDED
Binary file
fbgemm_gpu/metrics.py ADDED
@@ -0,0 +1,160 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from typing import Any, Callable
+
+ import torch
+
+
+ class BatchAuc(torch.nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(
+         self,
+         n_tasks: int,
+         predictions: torch.Tensor,
+         labels: torch.Tensor,
+         weights: torch.Tensor,
+     ) -> torch.Tensor:
+         _, sorted_indices = torch.sort(predictions, descending=True, dim=-1)
+         sorted_labels = torch.gather(labels, 1, sorted_indices)
+         sorted_weights = torch.gather(weights, 1, sorted_indices)
+         cum_fp = torch.cumsum(sorted_weights * (1.0 - sorted_labels), dim=-1)
+         cum_tp = torch.cumsum(sorted_weights * sorted_labels, dim=-1)
+         fac = cum_fp[:, -1] * cum_tp[:, -1]
+         auc = torch.where(fac == 0, 0.5, torch.trapz(cum_tp, cum_fp, dim=-1) / fac)
+         return auc
+
+
+ class Auc(torch.nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(
+         self,
+         n_tasks: int,
+         predictions: torch.Tensor,
+         labels: torch.Tensor,
+         weights: torch.Tensor,
+     ) -> torch.Tensor:
+         _, sorted_indices = torch.sort(predictions, descending=True, dim=-1)
+         aucs = []
+         for sorted_indices_i, labels_i, weights_i in zip(
+             sorted_indices, labels, weights
+         ):
+             sorted_labels = torch.index_select(labels_i, dim=0, index=sorted_indices_i)
+             sorted_weights = torch.index_select(
+                 weights_i, dim=0, index=sorted_indices_i
+             )
+             cum_fp = torch.cumsum(sorted_weights * (1.0 - sorted_labels), dim=0)
+             cum_tp = torch.cumsum(sorted_weights * sorted_labels, dim=0)
+             auc = torch.where(
+                 cum_fp[-1] * cum_tp[-1] == 0,
+                 0.5,  # 0.5 is the no-signal default value for auc.
+                 torch.trapz(cum_tp, cum_fp) / cum_fp[-1] / cum_tp[-1],
+             )
+             aucs.append(auc.view(1))
+         return torch.cat(aucs)
+
+
+ class AucJiterator(torch.nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         # Jiterator only works with elementwise kernels
+         fp_code_string = """
+         template <typename T> T fp(T weights, T labels) {
+             return weights * (1.0 - labels);
+         }"""
+
+         tp_code_string = """
+         template <typename T> T tp(T weights, T labels) {
+             return weights * labels;
+         }"""
+
+         # pyre-ignore [4]
+         self.jitted_fp: Callable[..., Any] = torch.cuda.jiterator._create_jit_fn(
+             fp_code_string
+         )
+         # pyre-ignore [4]
+         self.jitted_tp: Callable[..., Any] = torch.cuda.jiterator._create_jit_fn(
+             tp_code_string
+         )
+
+     def forward(
+         self,
+         n_tasks: int,
+         predictions: torch.Tensor,
+         labels: torch.Tensor,
+         weights: torch.Tensor,
+     ) -> torch.Tensor:
+         _, sorted_indices = torch.sort(predictions, descending=True, dim=-1)
+         aucs = []
+         for sorted_indices_i, labels_i, weights_i in zip(
+             sorted_indices, labels, weights
+         ):
+             sorted_labels = torch.index_select(labels_i, dim=0, index=sorted_indices_i)
+             sorted_weights = torch.index_select(
+                 weights_i, dim=0, index=sorted_indices_i
+             )
+             cum_fp = torch.cumsum(self.jitted_fp(sorted_weights, sorted_labels), dim=0)
+             cum_tp = torch.cumsum(self.jitted_tp(sorted_weights, sorted_labels), dim=0)
+             auc = torch.where(
+                 cum_fp[-1] * cum_tp[-1] == 0,
+                 0.5,  # 0.5 is the no-signal default value for auc.
+                 torch.trapz(cum_tp, cum_fp) / cum_fp[-1] / cum_tp[-1],
+             )
+             aucs.append(auc.view(1))
+         return torch.cat(aucs)
+
+
+ class BatchAucJiterator(torch.nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         # Jiterator only works with elementwise kernels
+         fp_code_string = """
+         template <typename T> T fp(T weights, T labels) {
+             return weights * (1.0 - labels);
+         }"""
+
+         tp_code_string = """
+         template <typename T> T tp(T weights, T labels) {
+             return weights * labels;
+         }"""
+
+         # pyre-ignore [4]
+         self.jitted_fp: Callable[..., Any] = torch.cuda.jiterator._create_jit_fn(
+             fp_code_string
+         )
+         # pyre-ignore [4]
+         self.jitted_tp: Callable[..., Any] = torch.cuda.jiterator._create_jit_fn(
+             tp_code_string
+         )
+
+     def forward(
+         self,
+         n_tasks: int,
+         predictions: torch.Tensor,
+         labels: torch.Tensor,
+         weights: torch.Tensor,
+     ) -> torch.Tensor:
+         _, sorted_indices = torch.sort(predictions, descending=True, dim=-1)
+         sorted_labels = torch.gather(labels, 1, sorted_indices)
+         sorted_weights = torch.gather(weights, 1, sorted_indices)
+         cum_fp = torch.cumsum(self.jitted_fp(sorted_weights, sorted_labels), dim=-1)
+         cum_tp = torch.cumsum(self.jitted_tp(sorted_weights, sorted_labels), dim=-1)
+         fac = cum_fp[:, -1] * cum_tp[:, -1]
+         auc = torch.where(fac == 0, 0.5, torch.trapz(cum_tp, cum_fp, dim=-1) / fac)
+         return auc
+
+
+ def auc(
+     n_tasks: int, predictions: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
+ ) -> torch.Tensor:
+     _, sorted_indices = torch.sort(predictions, descending=True, dim=-1)
+     return torch.ops.fbgemm.batch_auc(n_tasks, sorted_indices, labels, weights)
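
A small CPU sketch of the Auc module above, with illustrative shapes; the Jiterator variants require CUDA, and the auc() free function additionally requires the bundled torch.ops.fbgemm.batch_auc operator.

    import torch
    from fbgemm_gpu.metrics import Auc

    n_tasks, batch_size = 2, 8                      # illustrative sizes
    predictions = torch.rand(n_tasks, batch_size)
    labels = torch.randint(0, 2, (n_tasks, batch_size)).float()
    weights = torch.ones(n_tasks, batch_size)

    auc_values = Auc()(n_tasks, predictions, labels, weights)
    print(auc_values.shape)                         # torch.Size([2]), one AUC per task
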
fbgemm_gpu/permute_pooled_embedding_modules.py ADDED
@@ -0,0 +1,142 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from itertools import accumulate
+ from typing import Optional
+
+ import torch
+
+ from fbgemm_gpu.utils.loader import load_torch_module
+
+ try:
+     # pyre-ignore[21]
+     from fbgemm_gpu import open_source  # noqa: F401
+ except Exception:
+     torch.ops.load_library(
+         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
+     )
+     load_torch_module(
+         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
+     )
+
+
+ class PermutePooledEmbeddings:
+     """
+     A module for permuting embedding outputs along the feature dimension
+
+     An embedding output tensor contains the embedding outputs for all features
+     in a batch. It is represented in a 2D format, where the rows are the batch
+     size dimension and the columns are the feature * embedding dimension.
+     Permuting along the feature dimension is essentially permuting along the
+     second dimension (dim 1).
+
+     **Example:**
+
+         >>> import torch
+         >>> import fbgemm_gpu
+         >>> from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+         >>>
+         >>> # Suppose batch size = 3 and there are 3 features
+         >>> batch_size = 3
+         >>>
+         >>> # Embedding dimensions for each feature
+         >>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
+         >>>
+         >>> # Permute list, i.e., move feature 2 to position 0, move feature 0
+         >>> # to position 1, so on
+         >>> permute = [2, 0, 1]
+         >>>
+         >>> # Instantiate the module
+         >>> perm = PermutePooledEmbeddings(embs_dims, permute)
+         >>>
+         >>> # Generate an example input
+         >>> pooled_embs = torch.arange(
+         >>>     embs_dims.sum().item() * batch_size,
+         >>>     dtype=torch.float32, device="cuda"
+         >>> ).reshape(batch_size, -1)
+         >>> print(pooled_embs)
+         >>>
+         tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
+                  14., 15.],
+                 [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
+                  30., 31.],
+                 [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
+                  46., 47.]], device='cuda:0')
+         >>>
+         >>> # Invoke
+         >>> perm(pooled_embs)
+         >>>
+         tensor([[ 8.,  9., 10., 11., 12., 13., 14., 15.,  0.,  1.,  2.,  3.,  4.,  5.,
+                   6.,  7.],
+                 [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
+                  22., 23.],
+                 [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
+                  38., 39.]], device='cuda:0')
+
+     Args:
+         embs_dims (List[int]): A list of embedding dimensions for all features.
+             Length = the number of features
+
+         permute (List[int]): A list that describes how each feature is
+             permuted. `permute[i]` is to permute feature `permute[i]` to
+             position `i`.
+
+         device (Optional[torch.device] = None): The device to run this module
+             on
+     """
+
+     def __init__(
+         self,
+         embs_dims: list[int],
+         permute: list[int],
+         device: Optional[torch.device] = None,
+     ) -> None:
+         self._offset_dim_list: torch.Tensor = torch.tensor(
+             [0] + list(accumulate(embs_dims)), device=device, dtype=torch.int64
+         )
+
+         self._permute: torch.Tensor = torch.tensor(
+             permute, device=device, dtype=torch.int64
+         )
+
+         inv_permute: list[int] = [0] * len(permute)
+         for i, p in enumerate(permute):
+             inv_permute[p] = i
+
+         self._inv_permute: torch.Tensor = torch.tensor(
+             inv_permute, device=device, dtype=torch.int64
+         )
+
+         inv_embs_dims = [embs_dims[i] for i in permute]
+
+         self._inv_offset_dim_list: torch.Tensor = torch.tensor(
+             [0] + list(accumulate(inv_embs_dims)), device=device, dtype=torch.int64
+         )
+
+     def __call__(self, pooled_embs: torch.Tensor) -> torch.Tensor:
+         """
+         Performs pooled embedding output permutation along the feature dimension
+
+         Args:
+             pooled_embs (Tensor): The embedding outputs to permute. Shape is
+                 `(B_local, total_global_D)`, where `B_local` = a local batch
+                 size and `total_global_D` is the total embedding dimension
+                 across all features (global)
+
+         Returns:
+             Permuted embedding outputs (Tensor). Same shape as `pooled_embs`
+         """
+         result = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
+             pooled_embs,
+             self._offset_dim_list.to(device=pooled_embs.device),
+             self._permute.to(device=pooled_embs.device),
+             self._inv_offset_dim_list.to(device=pooled_embs.device),
+             self._inv_permute.to(device=pooled_embs.device),
+         )
+         return result
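
The docstring example above runs on CUDA through the fused permute_pooled_embs_auto_grad operator; the same permutation semantics can be reproduced in plain PyTorch for a quick sanity check (a sketch only, not how the operator is implemented):

    import torch

    embs_dims = [4, 4, 8]       # per-feature embedding dims, as in the docstring example
    permute = [2, 0, 1]         # output position i takes feature permute[i]
    batch_size = 3

    pooled_embs = torch.arange(
        batch_size * sum(embs_dims), dtype=torch.float32
    ).reshape(batch_size, -1)

    # Split the pooled output into per-feature column segments, then reorder them.
    segments = torch.split(pooled_embs, embs_dims, dim=1)
    reference = torch.cat([segments[p] for p in permute], dim=1)
    print(reference[0])         # tensor([ 8., ..., 15., 0., ..., 7.]), matching the docstring output
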
fbgemm_gpu/permute_pooled_embedding_modules_split.py ADDED
@@ -0,0 +1,85 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import logging
+ from itertools import accumulate
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ try:
+     # pyre-ignore[21]
+     from fbgemm_gpu import open_source  # noqa: F401
+ except Exception:
+     torch.ops.load_library(
+         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_gpu"
+     )
+     torch.ops.load_library(
+         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu"
+     )
+
+
+ @torch.fx.wrap
+ def _fx_wrap_tensor_to_device(t: torch.Tensor, device: torch.device) -> torch.Tensor:
+     return t.to(device=device)
+
+
+ class PermutePooledEmbeddingsSplit(nn.Module):
+     def __init__(
+         self,
+         embs_dims: list[int],
+         permute: list[int],
+         device: Optional[torch.device] = None,
+     ) -> None:
+         super(PermutePooledEmbeddingsSplit, self).__init__()
+         logging.info("Using Permute Pooled Embeddings")
+
+         self.register_buffer(
+             "_offset_dim_list",
+             torch.tensor(
+                 [0] + list(accumulate(embs_dims)), device=device, dtype=torch.int64
+             ),
+         )
+         self.register_buffer(
+             "_permute", torch.tensor(permute, device=device, dtype=torch.int64)
+         )
+
+         inv_permute: list[int] = [0] * len(permute)
+         for i, p in enumerate(permute):
+             inv_permute[p] = i
+
+         self.register_buffer(
+             "_inv_permute", torch.tensor(inv_permute, device=device, dtype=torch.int64)
+         )
+
+         # `Union[BoundMethod[typing.Callable(torch.Tensor.tolist)[[Named(self,
+         # torch.Tensor)], List[typing.Any]], torch.Tensor], nn.Module, torch.Tensor]`
+         # is not a function.
+
+         inv_embs_dims = [embs_dims[i] for i in permute]
+
+         self.register_buffer(
+             "_inv_offset_dim_list",
+             torch.tensor(
+                 [0] + list(accumulate(inv_embs_dims)), device=device, dtype=torch.int64
+             ),
+         )
+
+     def forward(self, pooled_embs: torch.Tensor) -> torch.Tensor:
+         result = torch.ops.fbgemm.permute_pooled_embs_auto_grad_split(
+             pooled_embs,
+             _fx_wrap_tensor_to_device(self._offset_dim_list, device=pooled_embs.device),
+             _fx_wrap_tensor_to_device(self._permute, device=pooled_embs.device),
+             _fx_wrap_tensor_to_device(
+                 self._inv_offset_dim_list, device=pooled_embs.device
+             ),
+             _fx_wrap_tensor_to_device(self._inv_permute, device=pooled_embs.device),
+         )
+         return result
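
PermutePooledEmbeddingsSplit is invoked the same way as PermutePooledEmbeddings, but as an nn.Module with registered buffers. A minimal sketch, assuming the permute_pooled_embs_auto_grad_split operator shipped with this wheel is available for the target device:

    import torch
    from fbgemm_gpu.permute_pooled_embedding_modules_split import (
        PermutePooledEmbeddingsSplit,
    )

    embs_dims = [4, 4, 8]
    permute = [2, 0, 1]
    module = PermutePooledEmbeddingsSplit(embs_dims, permute)

    pooled_embs = torch.arange(3 * sum(embs_dims), dtype=torch.float32).reshape(3, -1)
    out = module(pooled_embs)   # same feature reordering as above, with autograd support
    print(out.shape)            # torch.Size([3, 16])
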