fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/triton/quantize_ref.py
@@ -0,0 +1,286 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from typing import Union
+
+ import torch
+
+ from .common import get_mx4_exp_bias, get_mx4_lookup_table, RoundingMode
+
+
+ def _compute_exp(
+     group_max,
+     rounding_mode,
+     mbits,
+ ):
+     """Compute the shared exponent of a group using the specified rounding mode.
+
+     Args:
+         group_max (Tensor): Maximum absolute value of each group.
+         rounding_mode (int or RoundingMode): Which rounding mode to use.
+         mbits (int): Number of mantissa bits in the target mx4 format.
+
+     Returns:
+         Tensor: Shared exponent of each group.
+     """
+     # Helpful constants.
+     MBITS_FP32 = 23
+     RAND_MASK = (1 << (MBITS_FP32 - mbits)) - 1
+     # Nearest rounding mode.
+     if rounding_mode == 0:
+         return torch.floor(torch.log2(group_max) + 0.5)
+     # Floor rounding mode.
+     elif rounding_mode == 1:
+         return torch.floor(torch.log2(group_max))
+     # Even pre-rounding mode.
+     elif rounding_mode == 2:
+         # First round to the nearest even integer.
+         M_ROUND = (1 << (MBITS_FP32 - mbits - 1)) - 1
+         group_max = group_max.view(dtype=torch.int32) + M_ROUND
+         # Then perform floor rounding of the log.
+         return torch.floor(torch.log2(group_max.view(dtype=torch.float32)))
+     # Stochastic rounding mode.
+     elif rounding_mode == 3:
+         # Create random noise.
+         rand_bits = torch.randint_like(group_max, high=2**31 - 1, dtype=torch.int32)
+         # Add noise to the group max and round down.
+         group_max = group_max.view(dtype=torch.int32) + (RAND_MASK & rand_bits)
+         # Now compute the log and truncate.
+         return torch.floor(torch.log2(group_max.view(dtype=torch.float32)))
+     # Ceiling rounding mode (default).
+     else:
+         return torch.ceil(torch.log2(group_max))
+
+
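
For intuition, here is a small worked example (a hypothetical sketch, assuming this module is importable as fbgemm_gpu.triton.quantize_ref) showing how the rounding modes differ for the same group maximum:

    import torch

    from fbgemm_gpu.triton.quantize_ref import _compute_exp

    group_max = torch.tensor([0.75])  # log2(0.75) ~= -0.415
    # Nearest (mode 0): floor(-0.415 + 0.5) = 0
    print(_compute_exp(group_max, rounding_mode=0, mbits=1))  # tensor([0.])
    # Floor (mode 1): floor(-0.415) = -1
    print(_compute_exp(group_max, rounding_mode=1, mbits=1))  # tensor([-1.])
    # Any other mode falls through to ceil: ceil(-0.415) = 0
    print(_compute_exp(group_max, rounding_mode=4, mbits=1))  # tensor([-0.])
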
+ def py_quantize_mx4(
+     a: torch.Tensor,
+     group_size: int = 32,
+     ebits: int = 2,
+     mbits: int = 1,
+     rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil,
+     stochastic_casting: bool = False,
+ ) -> torch.Tensor:
+     """
+     Quantize a tensor to mx4 format.
+
+     Args:
+         a (Tensor): [M] higher precision input tensor.
+         group_size (int): Size of chunks that will share a common exponent.
+         ebits (int): Number of exponent bits in the target mx4 format.
+         mbits (int): Number of mantissa bits in the target mx4 format.
+         rounding_mode (int or RoundingMode): Which type of rounding to use when
+             calculating the shared exponent.
+         stochastic_casting (bool): Whether to use stochastic rounding when downcasting.
+
+     Returns:
+         torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into int8
+         with group exponents attached to each row.
+
+         For example, an input with shape [1, 8192] is quantized to [1, 4096 + 256]:
+         each int8 holds two packed fp4 elements, and each row contains
+         8192 / 32 = 256 groups, each of which appends one shared-exponent byte.
+     """
+     # Define helpful constants.
+     FP32_MIN_NORMAL = 2 ** (-126)
+     FP32_SIGN_OFFSET = 31
+     SIGN_MASK = 0x1
+     FP32_EXP_MASK = 0x7F800000
+     FP32_EXP_OFFSET = 23
+     FP32_MANTISSA_MASK = 0x007FFFFF
+     # Set the number of exponent bits and mantissa (plus implicit) bits.
+     EBITS = ebits
+     MBITS = mbits + 1
+     # FP32 and FP4 have very different exponent biases; adjust to fp4.
+     FP32_EXP_BIAS = 127
+     FP4_EXP_BIAS = get_mx4_exp_bias(EBITS)
+     MAX_FP32_MANTISSA_BITS = 24
+     RAND_MASK = (1 << (FP32_EXP_OFFSET - mbits)) - 1
+     MANTISSA_OVERFLOW_THRESHOLD = (1 << MBITS) - 1
+     EXPONENT_OVERFLOW_THRESHOLD = (1 << EBITS) - 1
+     IMPLICIT_1_MASK = (1 << (MBITS - 1)) - 1
+
+     # Make sure the input has a supported shape.
+     # If given an empty tensor, return an empty tensor.
+     if a.numel() == 0:
+         return torch.empty(a.shape, device=a.device, dtype=torch.uint8)
+     # If the last dim is not divisible by group_size, pad each row.
+     if a.shape[-1] % group_size != 0:
+         pad = group_size - (a.shape[-1] % group_size)
+         a = torch.nn.functional.pad(a, (0, pad))
+
+     # Keep track of the original shape.
+     orig_shape = a.shape
+     # Prepare for grouping by subdividing the last axis.
+     a = a.view(a.numel() // group_size, group_size)
+     # Now we can easily compute the shared exponent for each group.
+     shared_exp, _ = torch.max(torch.abs(a), dim=1, keepdim=True)
+     # Replace zero values with the minimum expressible normal value.
+     shared_exp = torch.where(shared_exp == 0, FP32_MIN_NORMAL, shared_exp)
+     # Convert the max into an integer exponent.
+     shared_exp = _compute_exp(shared_exp, rounding_mode, mbits)
+     # Offset the exponent by the largest exponent in the target datatype.
+     shared_exp = shared_exp - EBITS
+     # Restrict to the range expressible as int8.
+     shared_exp = torch.clamp(shared_exp, min=-127, max=125)
+     # Convert the exponent to a scale and apply it to the input.
+     # This calculation must be done on cpu for accuracy.
+     _shared_exp = shared_exp.cpu()
+     scale = (2**_shared_exp).to(device=a.device)
+     a = a / scale
+     # View as integer for bitwise ops.
+     a = a.view(torch.int32)
+
+     # When requested, apply stochastic rounding during the downcast.
+     if stochastic_casting:
+         rand_bits = torch.randint_like(a, high=2**31 - 1, dtype=torch.int32)
+         a = a + (rand_bits & RAND_MASK)
+
+     # Quantization step: convert fp32 values to fp4.
+     # Start by extracting the float components.
+     sign_bit = torch.bitwise_right_shift(a, FP32_SIGN_OFFSET).to(torch.int8)
+     # Torch does arithmetic shifts, so we need to isolate the sign bit.
+     sign_bit = torch.bitwise_and(sign_bit, SIGN_MASK)
+
+     # Next extract the exponent.
+     biased_exp = torch.bitwise_and(a, FP32_EXP_MASK)
+     # Shift the exponent over to the least significant bits.
+     biased_exp = torch.bitwise_right_shift(biased_exp, FP32_EXP_OFFSET).to(torch.int8)
+
+     # Finally extract the mantissa.
+     trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK)
+     new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS
+
+     # Compute the difference between the ideal exponent and what can be represented.
+     exp_diff = torch.where(new_biased_exp <= 0, 1 - new_biased_exp, 0)
+     # Clip this difference to the maximum number of fp32 mantissa bits (23 + implicit).
+     exp_diff = torch.clamp(exp_diff, max=MAX_FP32_MANTISSA_BITS)
+
+     # Now perform mantissa rounding down to fp4.
+     is_subnorm = biased_exp == 0
+     # Add the implied 1 to normal values.
+     mantissa = torch.where(is_subnorm, trailing_mantissa, trailing_mantissa + (1 << 23))
+     # Compute the base number of bits corresponding to the mantissa. We use a smaller
+     # value for subnorms since the implicit one is included in exp_diff above.
+     fp32_sig_bits = torch.where(is_subnorm, 23, 24).to(torch.int32)
+     # Shift down to the target bitwidth - 1.
+     mantissa = torch.bitwise_right_shift(
+         mantissa, fp32_sig_bits + exp_diff - MBITS - 1
+     ).to(torch.int8)
+     # Perform rounding by adding 1 then shifting down.
+     mantissa = mantissa + 1
+     mantissa = torch.bitwise_right_shift(mantissa, 1)
+
+     # Check for overflow and adjust the exponent accordingly.
+     overflow = mantissa > MANTISSA_OVERFLOW_THRESHOLD
+     # Allow subnorms to overflow into normals; otherwise shift off the overflow.
+     mantissa = torch.where(
+         torch.bitwise_and(overflow, torch.bitwise_not(is_subnorm)),
+         torch.bitwise_right_shift(mantissa, 1),
+         mantissa,
+     )
+     # Special case: a subnorm value with a large mantissa overflows into a normal.
+     new_biased_exp = torch.where(
+         torch.bitwise_and(new_biased_exp <= 0, mantissa == 2), 1, new_biased_exp
+     )
+     # Remove the implicit 1.
+     mantissa = torch.bitwise_and(mantissa, IMPLICIT_1_MASK)
+     # Add the overflow to the exponent.
+     new_biased_exp = torch.where(overflow, new_biased_exp + 1, new_biased_exp)
+     # If the exponent overflows, set the mantissa to the max representable value.
+     mantissa = torch.where(new_biased_exp > EXPONENT_OVERFLOW_THRESHOLD, 1, mantissa)
+
+     # Construct the fp4 value from its components.
+     new_biased_exp = torch.clamp(new_biased_exp, min=0, max=EXPONENT_OVERFLOW_THRESHOLD)
+     mx4_value = torch.bitwise_or(
+         torch.bitwise_left_shift(new_biased_exp, MBITS - 1), mantissa
+     )
+     mx4_value = torch.bitwise_or(
+         torch.bitwise_left_shift(sign_bit, EBITS + MBITS - 1), mx4_value
+     )
+
+     # Pack pairs of int4 values into single int8 outputs.
+     low_mx4 = mx4_value[:, ::2]
+     high_mx4 = mx4_value[:, 1::2]
+     high_mx4 = torch.bitwise_left_shift(high_mx4, 4)
+     packed_mx4 = torch.bitwise_or(low_mx4, high_mx4)
+
+     # Ravel packed values together with the shared exponent.
+     packed_mx4 = torch.concat(
+         [
+             packed_mx4.view(-1, group_size // 2),
+             (shared_exp + FP32_EXP_BIAS).to(torch.int8).view(-1, 1),
+         ],
+         dim=1,
+     )
+
+     # Inputs are now fully quantized and ready to return.
+     # Try to return in the original shape if possible.
+     if orig_shape[-1] % group_size == 0:
+         output_shape = list(orig_shape[:-1]) + [-1]
+         return packed_mx4.view(output_shape).view(torch.uint8)
+     # If we can't, return a flat array.
+     else:
+         return packed_mx4.view(-1).view(torch.uint8)
+
+
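
A quick check of the packing arithmetic from the docstring (a hypothetical usage sketch; the import path is assumed from this wheel's file layout):

    import torch

    from fbgemm_gpu.triton.quantize_ref import py_quantize_mx4

    x = torch.randn(1, 8192)
    q = py_quantize_mx4(x, group_size=32)
    # 8192 fp4 values pack into 4096 bytes, plus one shared-exponent byte
    # for each of the 8192 / 32 = 256 groups: 4096 + 256 = 4352.
    print(q.shape, q.dtype)  # torch.Size([1, 4352]) torch.uint8
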
+ def py_dequantize_mx4(
+     a: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1
+ ) -> torch.Tensor:
+     """
+     Dequantize a tensor from mx4 format to fp32.
+
+     Args:
+         a (Tensor): [M / 2 + M / group_size] mx4 tensor packed into int8 values
+             with a group exponent attached to the end of each row.
+         group_size (int): Size of chunks that share a common exponent.
+         ebits (int): Number of exponent bits in the source mx4 format.
+         mbits (int): Number of mantissa bits in the source mx4 format.
+
+     Returns:
+         torch.Tensor: [M] dequantized fp32 tensor.
+     """
+     # If given an empty tensor, return an empty tensor.
+     if a.numel() == 0:
+         return torch.empty(a.shape, device=a.device, dtype=torch.float32)
+     # Keep track of the starting shape.
+     orig_shape = a.shape
+     device = a.device
+     # Unravel the packed inputs and pull the shared exponent off the end of each row.
+     a = a.view(-1, (group_size // 2) + 1).view(torch.int8)
+     num_groups = a.numel() // ((group_size // 2) + 1)
+     packed_input = a[:, :-1]
+     shared_exp = a[:, -1:]
+     # Remove the fp32 exponent bias.
+     FP32_EXP_BIAS = 127
+     shared_exp = shared_exp - FP32_EXP_BIAS
+
+     # Pull out the high and low mx4 values.
+     FP4_BIT_MASK = 0xF
+     low_mx4 = torch.bitwise_and(packed_input, FP4_BIT_MASK)
+     high_mx4 = torch.bitwise_right_shift(packed_input, 4)
+     # Remove the sign extension from high values since the shift was arithmetic.
+     high_mx4 = torch.bitwise_and(high_mx4, FP4_BIT_MASK)
+     # Recombine into a single tensor.
+     a = torch.stack([low_mx4, high_mx4], dim=0).view(2, -1).t().contiguous()
+
+     # Use a lookup table to convert values into their float32 equivalents.
+     mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, device)
+     out = torch.index_select(mx4_to_fp_values, 0, a.to(torch.int32).view(-1))
+
+     # The exponent must be computed on cpu for perfect precision.
+     _shared_exp = shared_exp.cpu().to(torch.float)
+     scale = (2**_shared_exp).to(device)
+
+     # Finally, apply the shared exponent to restore the full value.
+     out = out.view(-1, num_groups, group_size) * scale.view(1, num_groups, 1)
+     # Restore the original shape and return.
+     out_shape = list(orig_shape[:-1]) + [-1]
+     return out.view(out_shape)
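
A round trip through both reference functions (again a hypothetical sketch; fp4 has very coarse precision, so only an approximate reconstruction is expected):

    import torch

    from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4

    x = torch.randn(2, 64)
    q = py_quantize_mx4(x, group_size=32)    # [2, 64 // 2 + 64 // 32] = [2, 34]
    y = py_dequantize_mx4(q, group_size=32)  # [2, 64] fp32 approximation
    print(torch.mean(torch.abs(x - y)))      # small but nonzero error
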
fbgemm_gpu/utils/__init__.py
@@ -0,0 +1,11 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ from .filestore import FileStore  # noqa F401
+ from .torch_library import TorchLibraryFragment  # noqa F401
fbgemm_gpu/utils/filestore.py
@@ -0,0 +1,211 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ # pyre-ignore-all-errors[56]
+
+ import io
+ import logging
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import BinaryIO, Union
+
+ import torch
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ @dataclass(frozen=True)
+ class FileStore:
+     """
+     A basic file store implementation for easy data reads / writes / deletes.
+
+     This class is intended to be used as a utility inside the FBGEMM_GPU codebase
+     for consistent writing of tensors and other objects to the filesystem.
+
+     Attributes:
+         bucket (str): A directory in the filesystem.
+     """
+
+     bucket: str
+
+     def __post_init__(self) -> None:
+         if not os.path.isdir(self.bucket):
+             raise ValueError(f"Directory {self.bucket} does not exist")
+
+     def write(
+         self,
+         path: str,
+         raw_input: Union[BinaryIO, torch.Tensor, Path],
+         ttls: int = 864000,
+     ) -> "FileStore":
+         """
+         Writes a binary stream, a torch.Tensor, or the contents of a file
+         (given as a Path) to the file located at `path` (relative to
+         `self.bucket`).
+
+         Args:
+             path (str): The destination path, relative to `self.bucket`.
+             raw_input (BinaryIO | torch.Tensor | Path): The data to write.
+             ttls (int): The time to live for the data in seconds. Defaults to
+                 10 days.
+
+         Returns:
+             self. This allows for method-chaining.
+         """
+         filepath = f"{self.bucket}/{path}"
+         event = f"writing to {filepath}"
+         logger.info(f"FileStore: {event}")
+
+         try:
+             if os.path.isfile(filepath):
+                 raise FileExistsError(
+                     f"File {filepath} already exists in the filesystem"
+                 )
+
+             if isinstance(raw_input, torch.Tensor):
+                 torch.save(raw_input, filepath)
+
+             elif isinstance(raw_input, Path):
+                 if not os.path.exists(raw_input):
+                     raise FileNotFoundError(f"File {raw_input} does not exist")
+                 # Open the source and destination files and copy the contents.
+                 with open(raw_input, "rb") as src_file, open(
+                     filepath, "wb"
+                 ) as dst_file:
+                     while chunk := src_file.read(4096):  # Read 4 KB at a time
+                         dst_file.write(chunk)
+
+             elif isinstance(raw_input, io.BytesIO) or isinstance(raw_input, BinaryIO):
+                 with open(filepath, "wb") as file:
+                     raw_input.seek(0)
+                     while chunk := raw_input.read(4096):  # Read 4 KB at a time
+                         file.write(chunk)
+             else:
+                 raise TypeError(f"Unsupported input type: {type(raw_input)}")
+
+         except Exception as e:
+             logger.error(f"FileStore: exception occurred when {event}: {e}")
+             raise e
+
+         return self
+
+     def read(self, path: str) -> io.BytesIO:
+         """
+         Reads a file into a BytesIO object.
+
+         Args:
+             path (str): The path of the file (relative to `self.bucket`) to
+                 be read.
+
+         Returns:
+             The file contents as a BytesIO object.
+         """
+         filepath = f"{self.bucket}/{path}"
+         event = f"reading from {filepath}"
+         logger.info(f"FileStore: {event}")
+
+         try:
+             if not os.path.isfile(filepath):
+                 raise FileNotFoundError(
+                     f"File {filepath} does not exist in the FileStore"
+                 )
+
+             with open(filepath, "rb") as file:
+                 return io.BytesIO(file.read())
+
+         except Exception as e:
+             logger.error(f"FileStore: exception occurred when {event}: {e}")
+             raise e
+
+     def remove(self, path: str) -> "FileStore":
+         """
+         Removes a file from the file store.
+
+         Args:
+             path (str): The path of the file (relative to `self.bucket`) to
+                 be removed.
+
+         Returns:
+             self. This allows for method-chaining.
+         """
+         filepath = f"{self.bucket}/{path}"
+         event = f"deleting {filepath}"
+         logger.info(f"FileStore: {event}")
+
+         try:
+             if os.path.isfile(filepath):
+                 os.remove(filepath)
+
+         except Exception as e:
+             logger.error(f"FileStore: exception occurred when {event}: {e}")
+             raise e
+
+         return self
+
+     def exists(self, path: str) -> bool:
+         """
+         Checks for the existence of a file in the file store.
+
+         Args:
+             path (str): The target path, relative to `self.bucket`.
+
+         Returns:
+             True if the file exists, False otherwise.
+         """
+         filepath = f"{self.bucket}/{path}"
+         return os.path.exists(filepath)
+
+     def create_directory(self, path: str) -> "FileStore":
+         """
+         Creates a directory in the file store.
+
+         Args:
+             path (str): The path of the directory (relative to `self.bucket`)
+                 to be created.
+
+         Returns:
+             self. This allows for method-chaining.
+         """
+         filepath = f"{self.bucket}/{path}"
+         event = f"creating directory {filepath}"
+         logger.info(f"FileStore: {event}")
+
+         try:
+             if not os.path.exists(filepath):
+                 os.makedirs(filepath, exist_ok=True)
+         except Exception as e:
+             logger.error(f"FileStore: exception occurred when {event}: {e}")
+             raise e
+
+         return self
+
+     def remove_directory(self, path: str) -> "FileStore":
+         """
+         Removes a directory from the file store.
+
+         Args:
+             path (str): The path of the directory (relative to `self.bucket`)
+                 to be removed.
+
+         Returns:
+             self. This allows for method-chaining.
+         """
+         filepath = f"{self.bucket}/{path}"
+         event = f"deleting {filepath}"
+         logger.info(f"FileStore: {event}")
+
+         try:
+             if os.path.isdir(filepath):
+                 os.rmdir(filepath)
+
+         except Exception as e:
+             logger.error(f"FileStore: exception occurred when {event}: {e}")
+             raise e
+
+         return self
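
A minimal usage sketch of FileStore (a hypothetical illustration; the directory and file names below are made up, and the import path assumes the package's utils re-exports shown above):

    import tempfile

    import torch

    from fbgemm_gpu.utils import FileStore

    store = FileStore(tempfile.mkdtemp())
    store.create_directory("tensors")
    # write() dispatches on the input type; here we save a tensor.
    store.write("tensors/weights.pt", torch.ones(4))
    restored = torch.load(store.read("tensors/weights.pt"))
    # Methods return self, so calls can be chained.
    store.remove("tensors/weights.pt").remove_directory("tensors")
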
fbgemm_gpu/utils/loader.py
@@ -0,0 +1,36 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ # pyre-ignore-all-errors[56]
+
+ from typing import Optional
+
+ import torch
+
+
+ def load_torch_module(
+     unified_path: str, cuda_path: Optional[str] = None, hip_path: Optional[str] = None
+ ) -> None:
+     try:
+         torch.ops.load_library(unified_path)
+     except Exception:
+         if torch.version.hip:
+             if not hip_path:
+                 hip_path = f"{unified_path}_hip"
+             torch.ops.load_library(hip_path)
+         else:
+             if not cuda_path:
+                 cuda_path = f"{unified_path}_cuda"
+             torch.ops.load_library(cuda_path)
+
+
+ def load_torch_module_bc(new_path: str, old_path: str) -> None:
+     try:
+         torch.ops.load_library(new_path)
+     except Exception:
+         torch.ops.load_library(old_path)
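
A small sketch of the fallback behavior (the library path here is hypothetical, so this call would only succeed where such a library actually exists):

    from fbgemm_gpu.utils.loader import load_torch_module

    # If loading "my_ops" fails, the helper retries "my_ops_hip" on ROCm
    # builds and "my_ops_cuda" otherwise.
    load_torch_module("my_ops")
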
fbgemm_gpu/utils/torch_library.py
@@ -0,0 +1,132 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import re
+ from typing import Callable
+
+ import torch
+
+
+ class TorchLibraryFragment:
+     """
+     A wrapper class around PyTorch library fragments, which are used to define
+     and register PyTorch operators. Handles duplicate operator definitions and
+     registrations under the hood.
+     """
+
+     def __init__(self, namespace: str) -> None:
+         """
+         Constructs the TorchLibraryFragment class.
+
+         Args:
+             namespace: The namespace for the operators.
+
+         Returns:
+             None
+
+         Example:
+             lib = TorchLibraryFragment("fbgemm")
+         """
+         self.namespace = namespace
+         self.lib = torch.library.Library(namespace, "FRAGMENT")
+
+     def define(self, schema: str) -> None:
+         """
+         Defines an operator schema. This function handles the case where the
+         operator name has already been defined.
+
+         Args:
+             schema: The schema of the operator to be defined. The operator name
+                 should NOT be prefixed with the operator namespace.
+
+         Returns:
+             None
+
+         Example:
+             lib = TorchLibraryFragment("fbgemm")
+             lib.define("sll_jagged_jagged_bmm(Tensor x, Tensor y, bool flag=True) -> Tensor")
+         """
+         pattern = re.compile(
+             r"""
+             (\w+)       # Match the function name (capturing group)
+             \s*\(       # Match the opening parenthesis with optional whitespace
+             ([^)]*)     # Match the params list (capturing group)
+             \s*\)       # Match the closing parenthesis with optional whitespace
+             \s*->\s*.+  # Match '-> <Return Type>'
+             """,
+             re.VERBOSE,
+         )
+
+         match = pattern.search(schema.strip())
+         if match:
+             name = match.group(1)
+             if f"{self.namespace}::{name}" not in torch.library._defs:
+                 self.lib.define(schema)
+         else:
+             raise ValueError(
+                 f"PyTorch operator schema appears to be ill-defined: '''{schema}'''"
+             )
+
+     # pyre-ignore[24]
+     def register_dispatch(self, op_name: str, dispatch_key: str, fn: Callable) -> None:
+         """
+         Registers a single dispatch for an operator with the given name and dispatch key.
+
+         Args:
+             op_name: The operator name.
+             dispatch_key: The dispatch key the function should be registered for (e.g., "CUDA").
+             fn: The function implementing the operator for the given dispatch key.
+
+         Returns:
+             None
+
+         Example:
+             lib = TorchLibraryFragment("fbgemm")
+             lib.define(...)
+             lib.register_dispatch("jagged_dense_bmm", "CUDA", jagged_dense_bmm)
+         """
+         valid_backends = [
+             "CUDA",
+             "AutogradCUDA",
+             "CPU",
+             "AutogradCPU",
+             "AutogradMeta",
+             "Meta",
+             "CompositeImplicitAutograd",
+         ]
+         assert dispatch_key in valid_backends
+
+         if not torch._C._dispatch_has_kernel_for_dispatch_key(
+             f"{self.namespace}::{op_name}", dispatch_key
+         ):
+             if dispatch_key == "Meta":
+                 self.lib._register_fake(op_name, fn)
+             else:
+                 self.lib.impl(op_name, fn, dispatch_key)
+
+     # pyre-ignore[24]
+     def register(self, op_name: str, functors: dict[str, Callable]) -> None:
+         """
+         Registers a set of dispatches for a defined operator.
+
+         Args:
+             op_name: The operator name.
+             functors: A dictionary mapping dispatch keys to implementations.
+
+         Returns:
+             None
+
+         Example:
+             lib = TorchLibraryFragment("fbgemm")
+             lib.define(...)
+             lib.register("jagged_dense_bmm", {"CUDA": jagged_dense_bmm, "Meta": jagged_dense_bmm_meta})
+         """
+         for dispatch, func in functors.items():
+             self.register_dispatch(op_name, dispatch, func)
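
Putting it together, a hypothetical end-to-end registration (the operator name and both implementations below are illustrative only, not part of this package):

    import torch

    from fbgemm_gpu.utils import TorchLibraryFragment

    lib = TorchLibraryFragment("fbgemm")
    lib.define("example_scale(Tensor x, float alpha=1.0) -> Tensor")

    def example_scale(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
        return x * alpha

    def example_scale_meta(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
        # Shape-only implementation used for tracing / fake tensors.
        return torch.empty_like(x)

    # Duplicate registrations are silently skipped by the wrapper.
    lib.register("example_scale", {"CPU": example_scale, "Meta": example_scale_meta})
    print(torch.ops.fbgemm.example_scale(torch.ones(2), alpha=3.0))  # tensor([3., 3.])
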