fbgemm-gpu-genai-nightly 2025.12.19 cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of fbgemm-gpu-genai-nightly has been flagged as potentially problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/triton/quantize_ref.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import Union
+
+import torch
+
+from .common import get_mx4_exp_bias, get_mx4_lookup_table, RoundingMode
+
+
+def _compute_exp(
+    group_max,
+    rounding_mode,
+    mbits,
+):
+    """Compute shared exponent of group using specified rounding mode.
+
+    Args:
+        group_max (Tensor): Group of values to compute exponent of.
+        rounding_mode (int or RoundingMode): Which rounding mode to use.
+        mbits (int): Number of mantissa bits in target mx4 format.
+
+    Returns:
+        Tensor: Shared exponent of group.
+    """
+    # Helpful constants.
+    MBITS_FP32 = 23
+    RAND_MASK = (1 << (MBITS_FP32 - mbits)) - 1
+    # Nearest rounding mode.
+    if rounding_mode == 0:
+        return torch.floor(torch.log2(group_max) + 0.5)
+    # Floor rounding mode.
+    if rounding_mode == 1:
+        return torch.floor(torch.log2(group_max))
+    # Even pre-rounding mode.
+    elif rounding_mode == 2:
+        # First round to nearest even integer.
+        M_ROUND = (1 << (MBITS_FP32 - mbits - 1)) - 1
+        group_max = group_max.view(dtype=torch.int32) + M_ROUND
+        # Then perform floor rounding of log.
+        return torch.floor(torch.log2(group_max.view(dtype=torch.float32)))
+    # Stochastic rounding mode.
+    elif rounding_mode == 3:
+        # Create random noise.
+        rand_bits = torch.randint_like(group_max, high=2**31 - 1, dtype=torch.int32)
+        # Add noise to group max and round down.
+        group_max = group_max.view(dtype=torch.int32) + (RAND_MASK & rand_bits)
+        # Now compute log and truncate.
+        return torch.floor(torch.log2(group_max.view(dtype=torch.float32)))
+    else:
+        return torch.ceil(torch.log2(group_max))
+
+
+def py_quantize_mx4(
+    a: torch.Tensor,
+    group_size: int = 32,
+    ebits: int = 2,
+    mbits: int = 1,
+    rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil,
+    stochastic_casting: bool = False,
+) -> torch.Tensor:
+    """
+    Quantize a tensor to mx4 format.
+
+    Args:
+        a (Tensor): [M] higher precision input tensor.
+        group_size (int): Size of chunks that will use the same shared exponent.
+        ebits (int): Number of exponent bits in target mx4 format.
+        mbits (int): Number of mantissa bits in target mx4 format.
+        rounding_mode (int or RoundingMode): Which type of rounding to use when
+        calculating shared exponent.
+        stochastic_casting (bool): Whether to use stochastic rounding when downcasting.
+
+    Returns:
+        torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into in8
+        with group exponents attached to each row.
+
+    eg.
+    Input with shape [1, 8192] will be quantized to [1, 4096 + 256] as
+    each value contain two elements packed into an int8 and
+    there are 32 groups in each row.
+    """
+    # Define helpful constants.
+    FP32_MIN_NORMAL = 2 ** (-126)
+    FP32_SIGN_OFFSET = 31
+    SIGN_MASK = 0x1
+    FP32_EXP_MASK = 0x7F800000
+    FP32_EXP_OFFSET = 23
+    FP32_MANTISSA_MASK = 0x007FFFFF
+    # Set number of exponent bits and mantissa (plus implicit) bits.
+    EBITS = ebits
+    MBITS = mbits + 1
+    # FP32 and and FP4 have very different exponent biases, adjust to fp4.
+    FP32_EXP_BIAS = 127
+    FP4_EXP_BIAS = get_mx4_exp_bias(EBITS)
+    MAX_FP32_MANTISSA_BITS = 24
+    RAND_MASK = (1 << (FP32_EXP_OFFSET - mbits)) - 1
+    MANTISSA_OVERFLOW_THRESHOLD = (1 << MBITS) - 1
+    EXPONENT_OVERFLOW_THRESHOLD = (1 << EBITS) - 1
+    IMPLICIT_1_MASK = (1 << (MBITS - 1)) - 1
+
+    # Make sure input has a supported shape.
+    # If given an empty shape, return an empty tensor.
+    if a.numel() == 0:
+        return torch.empty(a.shape, device=a.device, dtype=torch.uint8)
+    # Make sure input has a supported shape, if not pad each row.
+    if a.shape[-1] % group_size != 0:
+        pad = group_size - (a.shape[-1] % group_size)
+        a = torch.nn.functional.pad(a, (0, pad))
+
+    # Keep track of original shape.
+    orig_shape = a.shape
+    # Prepare for grouping by subdiving the last axis.
+    a = a.view(a.numel() // group_size, group_size)
+    # Now we can easily compute the shared exponents for each group.
+    shared_exp, _ = torch.max(torch.abs(a), dim=1, keepdim=True)
+    # Replace zero values with the minimum expressible normal value.
+    shared_exp = torch.where(shared_exp == 0, FP32_MIN_NORMAL, shared_exp)
+    # Convert max into an integer exponent.
+    shared_exp = _compute_exp(shared_exp, rounding_mode, mbits)
+    # Offset exponent by largest exponent in target datatype.
+    shared_exp = shared_exp - EBITS
+    # Restrict to range expressible as int8.
+    shared_exp = torch.clamp(shared_exp, min=-127, max=125)
+    # Convert exponent to scale and apply to input.
+    # Need to do this calculation on cpu for accuracy.
+    _shared_exp = shared_exp.cpu()
+    scale = (2**_shared_exp).to(device=a.device)
+    a = a / scale
+    # View as integer for bitwise ops.
+    a = a.view(torch.int32)
+
+    # When doing ceiling rounding, we apply stochastic downcasting.
+    if stochastic_casting:
+        rand_bits = torch.randint_like(a, high=2**31 - 1, dtype=torch.int32)
+        a = a + (rand_bits & RAND_MASK)
+
+    # Quantization step: convert fp32 values to fp4.
+    # Start by extracting float components.
+    sign_bit = torch.bitwise_right_shift(a, FP32_SIGN_OFFSET).to(torch.int8)
+    # Torch does arithmetic shifts so we need to isolate sign bit.
+    sign_bit = torch.bitwise_and(sign_bit, SIGN_MASK)
+
+    # Next extract exponent.
+    biased_exp = torch.bitwise_and(a, FP32_EXP_MASK)
+    # Shift exponent over to least significant bits.
+    biased_exp = torch.bitwise_right_shift(biased_exp, FP32_EXP_OFFSET).to(torch.int8)
+
+    # Finally extract the mantissa.
+    trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK)
+    new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS
+
+    # Compute difference between ideal exponent and what can be represented.
+    exp_diff = torch.where(new_biased_exp <= 0, 1 - new_biased_exp, 0)
+    # Clip this difference to the maximum number of fp32 mantissa bits (23 + implicit).
+    exp_diff = torch.clamp(exp_diff, max=MAX_FP32_MANTISSA_BITS)
+
+    # Now perform mantissa rounding down to fp4.
+    is_subnorm = biased_exp == 0
+    # Add implied 1 to normal values.
+    mantissa = torch.where(is_subnorm, trailing_mantissa, trailing_mantissa + (1 << 23))
+    # Compute base number of bits corresponding to the mantissa. We use a smaller value
+    # for subnorms since implicit one is included in exp_diff above.
+    fp32_sig_bits = torch.where(is_subnorm, 23, 24).to(torch.int32)
+    # Shift down to target bitwidth - 1 and efficiently represent.
+    mantissa = torch.bitwise_right_shift(
+        mantissa, fp32_sig_bits + exp_diff - MBITS - 1
+    ).to(torch.int8)
+    # Perform rounding by adding 1 then shifting down.
+    mantissa = mantissa + 1
+    mantissa = torch.bitwise_right_shift(mantissa, 1)
+
+    # Check for overflow and adjust exponent accordingly.
+    overflow = mantissa > MANTISSA_OVERFLOW_THRESHOLD
+    # Allow subnorms to overflow into normals, otherwise shift off overflow.
+    mantissa = torch.where(
+        torch.bitwise_and(overflow, torch.bitwise_not(is_subnorm)),
+        torch.bitwise_right_shift(mantissa, 1),
+        mantissa,
+    )
+    # Special case where a value is subnorm and has a large mantissa, overflow it.
+    new_biased_exp = torch.where(
+        torch.bitwise_and(new_biased_exp <= 0, mantissa == 2), 1, new_biased_exp
+    )
+    # Remove implicit 1.
+    mantissa = torch.bitwise_and(mantissa, IMPLICIT_1_MASK)
+    # Add overflow to exponent.
+    new_biased_exp = torch.where(overflow, new_biased_exp + 1, new_biased_exp)
+    # If exp overflows, set mantissa so we're at max representable value.
+    mantissa = torch.where(new_biased_exp > EXPONENT_OVERFLOW_THRESHOLD, 1, mantissa)
+
+    # Construct fp4 value from components.
+    new_biased_exp = torch.clamp(new_biased_exp, min=0, max=EXPONENT_OVERFLOW_THRESHOLD)
+    mx4_value = torch.bitwise_or(
+        torch.bitwise_left_shift(new_biased_exp, MBITS - 1), mantissa
+    )
+    mx4_value = torch.bitwise_or(
+        torch.bitwise_left_shift(sign_bit, EBITS + MBITS - 1), mx4_value
+    )
+
+    # Pack int4 values into single int8 outputs.
+    low_mx4 = mx4_value[:, ::2]
+    high_mx4 = mx4_value[:, 1::2]
+    high_mx4 = torch.bitwise_left_shift(high_mx4, 4)
+    packed_mx4 = torch.bitwise_or(low_mx4, high_mx4)
+
+    # Ravel packed values together with shared exponent.
+    packed_mx4 = torch.concat(
+        [
+            packed_mx4.view(-1, group_size // 2),
+            (shared_exp + FP32_EXP_BIAS).to(torch.int8).view(-1, 1),
+        ],
+        dim=1,
+    )
+
+    # Inputs are now fully quantized and ready to return.
+    # Try to return in the original shape if possible.
+    if orig_shape[-1] % group_size == 0:
+        output_shape = list(orig_shape[:-1]) + [-1]
+        return packed_mx4.view(output_shape).view(torch.uint8)
+    # If we cant, return as a flat array.
+    else:
+        return packed_mx4.view(-1).view(torch.uint8)
+
+
+def py_dequantize_mx4(
+    a: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1
+) -> torch.Tensor:
+    """
+    Dequantize a tensor from mx4 format to fp32.
+
+    Args:
+        a (Tensor): [M / 2 + M / group_size] MX4 tensor packed into int8 values
+        with group exponents attached to end of each row.
+        group_size (int): Size of chunks that use the same shared exponent.
+        ebits (int): Number of exponent bits in target mx4 format.
+        mbits (int): Number of mantissa bits in target mx4 format.
+
+    Returns:
+        torch.Tensor: [M] dequantized fp32 tensor.
+    """
+    # If given an empty shape, return an empty tensor.
+    if a.numel() == 0:
+        return torch.empty(a.shape, device=a.device, dtype=torch.float32)
+    # Keep track of starting shape.
+    orig_shape = a.shape
+    device = a.device
+    # Unravel packed inputs from shared exponents.
+    a = a.view(-1, (group_size // 2) + 1).view(torch.int8)
+    num_groups = a.numel() // ((group_size // 2) + 1)
+    packed_input = a[:, :-1]
+    shared_exp = a[:, -1:]
+    # Remove fp32 exponent bias
+    FP32_EXP_BIAS = 127
+    shared_exp = shared_exp - FP32_EXP_BIAS
+    # First pull shared exponent off the end of each row.
+    M, K_2 = packed_input.shape
+
+    # Pull out high and low mx4 values.
+    FP4_BIT_MASK = 0xF
+    low_mx4 = torch.bitwise_and(packed_input, FP4_BIT_MASK)
+    high_mx4 = torch.bitwise_right_shift(packed_input, 4)
+    # Remove sign bit from high values since shift was arithmetic.
+    high_mx4 = torch.bitwise_and(high_mx4, FP4_BIT_MASK)
+    # Recombine into a single tensor.
+    a = torch.stack([low_mx4, high_mx4], dim=0).view(2, -1).t().contiguous()
+
+    # Use a lookup table to convert
+    mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, device)
+    # Convert values into float32 equivalent via lookup.
+    out = torch.index_select(mx4_to_fp_values, 0, a.to(torch.int32).view(-1))
+
+    # Exponent needs to be computed on cpu for perfect precision.
+    _shared_exp = shared_exp.cpu().to(torch.float)
+    scale = (2**_shared_exp).to(device)
+
+    # Finally, apply shared exponent to restore full value.
+    out = out.view(-1, num_groups, group_size) * scale.view(1, num_groups, 1)
+    # Restore original shape and return.
+    out_shape = list(orig_shape[:-1]) + [-1]
+    return out.view(out_shape)
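
For orientation, here is a minimal round-trip sketch of the pure-Python MX4 reference path above. It is not part of the package diff; it assumes the nightly wheel and its dependencies import cleanly on the host, and the shapes follow directly from the packing scheme described in the docstrings.

import torch

from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4

group_size = 32
x = torch.randn(2, 128, dtype=torch.float32)  # last dim divisible by group_size

# Each group of 32 fp32 values becomes 16 packed int8 bytes plus one shared-exponent byte.
packed = py_quantize_mx4(x, group_size=group_size)
assert packed.shape == (2, 128 // 2 + 128 // group_size)

# Dequantization is lossy: only the 4-bit element format plus a per-group exponent survive.
restored = py_dequantize_mx4(packed, group_size=group_size)
assert restored.shape == x.shape
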
fbgemm_gpu/utils/__init__.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from .filestore import FileStore  # noqa F401
+from .torch_library import TorchLibraryFragment  # noqa F401
fbgemm_gpu/utils/filestore.py
@@ -0,0 +1,211 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+# pyre-ignore-all-errors[56]
+
+import io
+import logging
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import BinaryIO, Union
+
+import torch
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class FileStore:
+    """
+    A basic file store implementation for easy data reads / writes / deletes.
+
+    This class is intended to be used as a utility inside the FBGEMM_GPU codebase
+    for consistent writing of tensors and other objects to the filesystem.
+
+    Attribute:
+        bucket (str): A directory in the filesystem.
+    """
+
+    bucket: str
+
+    def __post_init__(self) -> None:
+        if not os.path.isdir(self.bucket):
+            raise ValueError(f"Directory {self.bucket} does not exist")
+
+    def write(
+        self,
+        path: str,
+        raw_input: Union[BinaryIO, torch.Tensor, Path],
+        ttls: int = 864000,
+    ) -> "FileStore":
+        """
+        Writes a binary stream, or a torch.Tensor to the file located at `path`
+        (relative to `self.bucket`).
+
+        Args:
+            path (str): The path of the node or symlink to a directory.
+            raw_input (BinaryIO | torch.Tensor | Path): The data to write.
+
+            ttls (int): The time to live for the data in seconds. Defaults to
+            10 days.
+
+        Returns:
+            self. This allows for method-chaining.
+        """
+
+        filepath = f"{self.bucket}/{path}"
+        event = f"writing to {filepath}"
+        logger.info(f"FileStore: {event}")
+
+        try:
+            if os.path.isfile(filepath):
+                raise FileExistsError(
+                    f"File {filepath} already exists in the filesystem"
+                )
+
+            if isinstance(raw_input, torch.Tensor):
+                torch.save(raw_input, filepath)
+
+            elif isinstance(raw_input, Path):
+                if not os.path.exists(raw_input):
+                    raise FileNotFoundError(f"File {raw_input} does not exist")
+                # Open the source file and destination file, and copy the contents
+                with open(raw_input, "rb") as src_file, open(
+                    filepath, "wb"
+                ) as dst_file:
+                    while chunk := src_file.read(4096):  # Read 4 KB at a time
+                        dst_file.write(chunk)
+
+            elif isinstance(raw_input, io.BytesIO) or isinstance(raw_input, BinaryIO):
+                with open(filepath, "wb") as file:
+                    raw_input.seek(0)
+                    while chunk := raw_input.read(4096):  # Read 4 KB at a time
+                        file.write(chunk)
+            else:
+                raise TypeError(f"Unsupported input type: {type(raw_input)}")
+
+        except Exception as e:
+            logger.error(f"FileStore: exception occurred when {event}: {e}")
+            raise e
+
+        return self
+
+    def read(self, path: str) -> io.BytesIO:
+        """
+        Reads a file into a BytesIO object.
+
+        Args:
+            path (str): The path of the node or symlink to a directory (relative
+            to `self.bucket`) to be read.
+
+        Returns:
+            Data from the file in BytesIO object format.
+        """
+        filepath = f"{self.bucket}/{path}"
+        event = f"reading from {filepath}"
+        logger.info(f"FileStore: {event}")
+
+        try:
+            if not os.path.isfile(filepath):
+                raise FileNotFoundError(
+                    f"File {filepath} does not exist in the FileStore"
+                )
+
+            return io.BytesIO(open(filepath, "rb").read())
+
+        except Exception as e:
+            logger.error(f"FileStore: exception occurred when {event}: {e}")
+            raise e
+
+    def remove(self, path: str) -> "FileStore":
+        """
+        Removes a file or directory from the file store.
+
+        Args:
+            path (str): The path of the node or symlink to a directory (relative
+            to `self.bucket`) to be removed.
+
+        Returns:
+            self. This allows for method-chaining.
+        """
+        filepath = f"{self.bucket}/{path}"
+        event = f"deleting {filepath}"
+        logger.info(f"FileStore: {event}")
+
+        try:
+            if os.path.isfile(filepath):
+                os.remove(filepath)
+
+        except Exception as e:
+            logger.error(f"Manifold: exception occurred when {event}: {e}")
+            raise e
+
+        return self
+
+    def exists(self, path: str) -> bool:
+        """
+        Checks for existence of file in the file store.
+
+        Args:
+            path (str): The Manifold target path (relative to `self.bucket`).
+
+        Returns:
+            True if file exists, False otherwise.
+        """
+        filepath = f"{self.bucket}/{path}"
+        return os.path.exists(filepath)
+
+    def create_directory(self, path: str) -> "FileStore":
+        """
+        Creates a directory in the file store.
+
+        Args:
+            path (str): The path of the node or symlink to a directory (relative
+            to `self.bucket`) to be created.
+
+        Returns:
+            self. This allows for method-chaining.
+        """
+        filepath = f"{self.bucket}/{path}"
+        event = f"creating directory {filepath}"
+        logger.info(f"FileStore: {event}")
+
+        try:
+            if not os.path.exists(filepath):
+                os.makedirs(filepath, exist_ok=True)
+        except Exception as e:
+            logger.error(f"FileStore: exception occurred when {event}: {e}")
+            raise e
+
+        return self
+
+    def remove_directory(self, path: str) -> "FileStore":
+        """
+        Removes a directory from the file store.
+
+        Args:
+            path (str): The path of the node or symlink to a directory (relative
+            to `self.bucket`) to be removed.
+
+        Returns:
+            self. This allows for method-chaining.
+        """
+        filepath = f"{self.bucket}/{path}"
+        event = f"deleting {filepath}"
+        logger.info(f"FileStore: {event}")
+
+        try:
+            if os.path.isdir(filepath):
+                os.rmdir(filepath)
+
+        except Exception as e:
+            logger.error(f"Manifold: exception occurred when {event}: {e}")
+            raise e
+
+        return self
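
A minimal sketch (not part of the diff) of how the FileStore helper above might be used, assuming fbgemm_gpu.utils imports cleanly; the bucket is a temporary directory and the file names are arbitrary.

import io
import tempfile

import torch

from fbgemm_gpu.utils import FileStore

with tempfile.TemporaryDirectory() as bucket:
    store = FileStore(bucket)

    # Tensors are persisted via torch.save; byte streams are copied in 4 KB chunks.
    store.write("weights.pt", torch.arange(8, dtype=torch.float32))
    store.write("notes.bin", io.BytesIO(b"hello fbgemm"))

    assert store.exists("weights.pt")
    restored = torch.load(store.read("weights.pt"))

    # Methods return self, so calls can be chained.
    store.remove("weights.pt").remove("notes.bin")
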
fbgemm_gpu/utils/loader.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+# pyre-ignore-all-errors[56]
+
+from typing import Optional
+
+import torch
+
+
+def load_torch_module(
+    unified_path: str, cuda_path: Optional[str] = None, hip_path: Optional[str] = None
+) -> None:
+    try:
+        torch.ops.load_library(unified_path)
+    except Exception:
+        if torch.version.hip:
+            if not hip_path:
+                hip_path = f"{unified_path}_hip"
+            torch.ops.load_library(hip_path)
+        else:
+            if not cuda_path:
+                cuda_path = f"{unified_path}_cuda"
+            torch.ops.load_library(cuda_path)
+
+
+def load_torch_module_bc(new_path: str, old_path: str) -> None:
+    try:
+        torch.ops.load_library(new_path)
+    except Exception:
+        torch.ops.load_library(old_path)
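
As a hedged illustration (not part of the diff), the fallback behaviour of load_torch_module can be seen with a placeholder library name; the path below is made up and is expected to fail, so the call is wrapped to keep the snippet runnable.

from fbgemm_gpu.utils.loader import load_torch_module

try:
    # Tries the unified path first; on failure it retries with a "_hip" suffix on
    # ROCm builds (torch.version.hip is set) or a "_cuda" suffix otherwise.
    load_torch_module("path/to/example_ops")  # placeholder, not a real library
except Exception as err:
    print(f"Both load attempts failed for the placeholder path: {err}")
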
fbgemm_gpu/utils/torch_library.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import re
+from typing import Callable
+
+import torch
+
+
+class TorchLibraryFragment:
+    """
+    A wrapper class around PyTorch library fragments, which are used to define
+    and register PyTorch operators. Handles duplicate operator definitions and
+    registrations under the hood.
+    """
+
+    def __init__(self, namespace: str) -> None:
+        """
+        Constructs the TorchLibraryFragment class.
+
+        Args:
+            namespace: The namespace for the operators.
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibrary("fbgemm")
+        """
+        self.namespace = namespace
+        self.lib = torch.library.Library(namespace, "FRAGMENT")
+
+    def define(self, schema: str) -> None:
+        """
+        Defines an operator schema. This function handles the case where the
+        opeator name has already been defined.
+
+        Args:
+            schema: The schema of the operator to be defined. The operator name
+            should NOT be prefixed with the operator namespace.
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibrary("fbgemm")
+            lib.define("sll_jagged_jagged_bmm(Tensor x, Tensor y, bool flag=True) -> Tensor")
+        """
+        pattern = re.compile(
+            r"""
+            (\w+) # Match the function name (capturing group)
+            \s*\( # Match the opening parenthesis with optional whitespace
+            ([^)]*) # Match params list (capturing group)
+            \s*\) # Match the closing parenthesis with optional whitespace
+            \s*->\s*.+ # Match '-> <Return Type>'
+            """,
+            re.VERBOSE,
+        )
+
+        match = pattern.search(schema.strip())
+        if match:
+            name = match.group(1)
+            if f"{self.namespace}::{name}" not in torch.library._defs:
+                self.lib.define(schema)
+        else:
+            raise ValueError(
+                f"PyTorch operator schema appears to be ill-defined: '''{schema}'''"
+            )
+
+    # pyre-ignore[24]
+    def register_dispatch(self, op_name: str, dispatch_key: str, fn: Callable) -> None:
+        """
+        Registers a single dispatch for an operator with the given name and dispatch key.
+
+        Args:
+            op_name: operator name
+            dispatch_key: dispatch key that the function should be registered for (e.g., "CUDA")
+            fn: a function that is the operator implementation for the input dispatch key
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibrary("fbgemm")
+            lib.define(...)
+            lib.register_dispatch(lib, "jagged_dense_bmm", jagged_dense_bmm, "CUDA")
+        """
+
+        valid_backends = [
+            "CUDA",
+            "AutogradCUDA",
+            "CPU",
+            "AutogradCPU",
+            "AutogradMeta",
+            "Meta",
+            "CompositeImplicitAutograd",
+        ]
+        assert dispatch_key in valid_backends
+
+        if not torch._C._dispatch_has_kernel_for_dispatch_key(
+            f"{self.namespace}::{op_name}", dispatch_key
+        ):
+            if dispatch_key == "Meta":
+                self.lib._register_fake(op_name, fn)
+            else:
+                self.lib.impl(op_name, fn, dispatch_key)
+
+    # pyre-ignore[24]
+    def register(self, op_name: str, functors: dict[str, Callable]) -> None:
+        """
+        Registers a set of dispatches for a defined operator.
+
+        Args:
+            op_name: operator name
+            functors: A dictionary of dispatch keys to dispatch implementations
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibrary("fbgemm")
+            lib.define(...)
+            lib.register(lib, "jagged_dense_bmm", {"CUDA": jagged_dense_bmm, "Meta": jagged_dense_bmm_meta })
+        """
+        for dispatch, func in functors.items():
+            self.register_dispatch(op_name, dispatch, func)