fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
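
The RECORD and top_level.txt entries above indicate two top-level import packages shipping in this wheel: fbgemm_gpu and list_versions. As a quick local cross-check, a wheel is an ordinary zip archive, so the member list shown above can be reproduced with Python's standard library. This is a minimal sketch, not part of the diff tooling, and the wheel filename below is illustrative; substitute the file actually downloaded from the registry.

import zipfile

# Illustrative path; replace with the wheel downloaded from the registry.
wheel_path = "fbgemm_gpu_genai_nightly-2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl"

with zipfile.ZipFile(wheel_path) as whl:
    # Top-level directories correspond to the importable packages plus the dist-info dir.
    top_level = sorted({name.split("/", 1)[0] for name in whl.namelist()})
    print(top_level)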
fbgemm_gpu/tbe/utils/offsets.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+
+from .common import to_device
+
+
+# Merged indices with shape (T, B, L) -> (flattened indices with shape
+# (T * B * L), offsets with shape (T * B + 1))
+def get_table_batched_offsets_from_dense(
+    merged_indices: torch.Tensor,
+    L: Optional[int] = None,
+    total_B: Optional[int] = None,
+    use_cpu: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if L is None and total_B is None:
+        (T, B, L) = merged_indices.size()
+        total_B = T * B
+    # pyre-fixme[6]: For 1st argument expected `Union[Sequence[SupportsIndex],
+    #  SupportsIndex]` but got `Optional[int]`.
+    lengths = np.ones(total_B) * L
+    return (
+        to_device(merged_indices.contiguous().view(-1), use_cpu),
+        to_device(
+            torch.tensor(([0] + np.cumsum(lengths).tolist())).long(),
+            use_cpu,
+        ),
+    )
+
+
+def get_offsets_from_dense(indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    (B, L) = indices.size()
+    return (
+        indices.contiguous().view(-1),
+        torch.tensor(
+            np.cumsum(np.asarray([0] + [L for _ in range(B)])[:-1]).astype(np.int64)
+        ),
+    )
+
+
+def b_indices(
+    b: Callable[..., torch.Tensor],
+    x: torch.Tensor,
+    per_sample_weights: Optional[torch.Tensor] = None,
+    use_cpu: bool = False,
+    do_pooling: bool = True,
+) -> torch.Tensor:
+    (indices, offsets) = get_offsets_from_dense(x)
+    if do_pooling:
+        return b(
+            to_device(indices, use_cpu),
+            to_device(offsets, use_cpu),
+            per_sample_weights=per_sample_weights,
+        )
+    else:
+        return b(to_device(indices, use_cpu))
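
For context on the offsets.py hunk above: get_table_batched_offsets_from_dense flattens a dense (T, B, L) index tensor into a 1-D index tensor and a prefix-sum offsets tensor of length T * B + 1, the (indices, offsets) form that the table-batched embedding ops consume. The following is a minimal standalone sketch of that computation using only numpy and torch; it does not import the wheel or its to_device helper.

import numpy as np
import torch

T, B, L = 2, 3, 4  # tables, batch size, bag length
merged_indices = torch.arange(T * B * L).reshape(T, B, L)

# Flatten to (T * B * L,) and build offsets marking every L-th position.
flat_indices = merged_indices.contiguous().view(-1)
lengths = np.ones(T * B) * L
offsets = torch.tensor([0] + np.cumsum(lengths).tolist()).long()

assert offsets.numel() == T * B + 1
assert offsets[-1].item() == T * B * L  # last offset equals the total number of ids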
fbgemm_gpu/tbe/utils/quantize.py
@@ -0,0 +1,251 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+# pyre-ignore-all-errors[61]
+
+from typing import Optional
+
+import torch
+
+from .common import to_device
+from fbgemm_gpu.split_embedding_configs import (
+    FP8QuantizationConfig,
+    SparseType,
+)  # usort:skip
+
+
+def quantize_embs(
+    weight: torch.Tensor,
+    weight_ty: SparseType,
+    fp8_config: Optional[FP8QuantizationConfig] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    weight = weight.detach()
+    if weight_ty == SparseType.FP32:
+        q_weight = weight.float()
+        res_weight = q_weight.view(torch.uint8)
+        return (res_weight, None)
+
+    elif weight_ty == SparseType.FP16:
+        q_weight = weight.half()
+        res_weight = q_weight.view(torch.uint8)
+        return (res_weight, None)
+
+    elif weight_ty == SparseType.FP8:
+        assert fp8_config is not None
+        # Quantize FP32 to HPF8
+        res_weight = torch.ops.fbgemm.FloatToHFP8Quantized(
+            weight.float(),
+            fp8_config.get("exponent_bits"),
+            fp8_config.get("exponent_bias"),
+            fp8_config.get("max_position"),
+        )
+        return (res_weight, None)
+
+    elif weight_ty == SparseType.INT8:
+        # Note that FloatToFused8BitRowwiseQuantized might have additional padding
+        # for alignment if embedding dimension is not a multiple of 4:
+        # https://fburl.com/code/z009xsy6
+        q_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(weight)
+        res_weight = q_weight[:, :-8].view(torch.uint8)
+        res_scale_shift = torch.tensor(
+            q_weight[:, -8:].view(torch.float32).to(torch.float16).view(torch.uint8)
+        )  # [-4, -2]: scale; [-2:]: bias
+        return (res_weight, res_scale_shift)
+
+    elif weight_ty == SparseType.INT4 or weight_ty == SparseType.INT2:
+        # Note that FP32 -> INT4/INT2 conversion op below might have additional padding
+        # for alignment: https://fburl.com/code/xx9kkduf
+        q_weight = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
+            weight,
+            bit_rate=weight_ty.bit_rate(),
+        )
+        res_weight = q_weight[:, :-4].view(torch.uint8)
+        res_scale_shift = torch.tensor(
+            q_weight[:, -4:].view(torch.uint8)
+        )  # [-4, -2]: scale; [-2:]: bias
+        return (res_weight, res_scale_shift)
+
+    else:
+        raise RuntimeError("Unsupported SparseType: {}".format(weight_ty))
+
+
+def dequantize_embs(
+    weights: torch.Tensor,
+    scale_shift: torch.Tensor,
+    weight_ty: SparseType,
+    use_cpu: bool,
+    fp8_config: Optional[FP8QuantizationConfig] = None,
+    # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
+) -> torch.Tensor:
+    print(f"weight_ty: {weight_ty}")
+    assert (
+        weights.dtype == torch.uint8
+    ), "The input tensor for dequantize_embs function needs to be byte tensor"
+    th_weights = weights
+
+    if scale_shift is not None:
+        th_scale_shift: torch.Tensor = scale_shift.view(torch.float16).to(torch.float32)
+
+    if weight_ty == SparseType.INT4:
+        (E, D_2) = th_weights.shape
+        D = D_2 * 2
+
+        def comp(i: int) -> torch.Tensor:
+            subs = th_weights.view(torch.uint8) >> (i * 4)
+            sub_mask = subs & 0xF
+            result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
+                -1, 1
+            ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+            return result.to(torch.float32)
+
+        comps = [comp(i) for i in range(2)]
+        comps = torch.stack(comps)
+        comps = comps.permute(1, 2, 0)
+        comps = comps.reshape(E, D)
+        return to_device(torch.tensor(comps), use_cpu)
+
+    elif weight_ty == SparseType.INT2:
+        (E, D_4) = th_weights.shape
+        D = D_4 * 4
+
+        # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
+        # pyre-fixme[53]: Captured variable `weights` is not annotated.
+        def comp(i: int) -> torch.Tensor:
+            subs = th_weights.view(torch.uint8) >> (i * 2)
+            sub_mask = subs & 0x3
+            result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
+                -1, 1
+            ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+            return result.to(torch.float32)
+
+        comps = [comp(i) for i in range(4)]
+        comps = torch.stack(comps)
+        comps = comps.permute(1, 2, 0)
+        comps = comps.reshape(E, D)
+        return to_device(torch.tensor(comps), use_cpu)
+
+    elif weight_ty == SparseType.INT8:
+        (E, D) = th_weights.shape
+        comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
+            torch.float32
+        ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+        return to_device(torch.tensor(comps), use_cpu)
+
+    elif weight_ty == SparseType.FP8:
+        assert fp8_config is not None
+        assert scale_shift is None
+        # Dequantize HPF8 to FP32
+        comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
+            weights,
+            fp8_config.get("exponent_bits"),
+            fp8_config.get("exponent_bias"),
+        )
+        return to_device(comps, use_cpu)
+
+    elif weight_ty == SparseType.FP16:
+        assert scale_shift is None
+        comps = th_weights.view(torch.half)
+        return to_device(torch.tensor(comps), use_cpu)
+
+    elif weight_ty == SparseType.FP32:
+        assert scale_shift is None
+        comps = th_weights.view(torch.float32)
+        # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
+        return to_device(torch.tensor(comps), use_cpu)
+
+
+def fake_quantize_embs(
+    weights: torch.Tensor,
+    scale_shift: Optional[torch.Tensor],
+    dequant_weights: torch.Tensor,
+    weight_ty: SparseType,
+    use_cpu: bool,
+    fp8_config: Optional[FP8QuantizationConfig] = None,
+) -> None:
+    assert (
+        weights.dtype == torch.uint8
+    ), "The input tensor for dequantize_embs function needs to be byte tensor"
+    th_weights = weights
+
+    if scale_shift is not None:
+        th_scale_shift: torch.Tensor = (
+            scale_shift.contiguous().view(torch.float16).to(torch.float32)
+        )
+
+    if weight_ty == SparseType.INT4:
+        (E, D_2) = th_weights.shape
+        D = D_2 * 2
+
+        def comp(i: int) -> torch.Tensor:
+            subs = th_weights.view(torch.uint8) >> (i * 4)
+            sub_mask = subs & 0xF
+            result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
+                -1, 1
+            ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+            return result.to(torch.float32)
+
+        comps = [comp(i) for i in range(2)]
+        comps = torch.stack(comps)
+        comps = comps.permute(1, 2, 0)
+        comps = comps.reshape(E, D)
+        dequant_weights.copy_(to_device(comps, use_cpu))
+
+    elif weight_ty == SparseType.INT2:
+        (E, D_4) = th_weights.shape
+        D = D_4 * 4
+
+        # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
+        # pyre-fixme[53]: Captured variable `weights` is not annotated.
+        def comp(i: int) -> torch.Tensor:
+            subs = th_weights.view(torch.uint8) >> (i * 2)
+            sub_mask = subs & 0x3
+            result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
+                -1, 1
+            ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+            return result.to(torch.float32)
+
+        comps = [comp(i) for i in range(4)]
+        comps = torch.stack(comps)
+        comps = comps.permute(1, 2, 0)
+        comps = comps.reshape(E, D)
+        dequant_weights.copy_(to_device(comps, use_cpu))
+
+    elif weight_ty == SparseType.INT8:
+        (E, D) = th_weights.shape
+        comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
+            torch.float32
+        ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+        dequant_weights.copy_(to_device(comps, use_cpu))
+
+    elif weight_ty == SparseType.FP8:
+        assert fp8_config is not None
+        assert scale_shift is None
+        # Quantize FP32 to HPF8
+        comps = torch.ops.fbgemm.FloatToHFP8Quantized(
+            dequant_weights.detach().float(),
+            fp8_config.get("exponent_bits"),
+            fp8_config.get("exponent_bias"),
+            fp8_config.get("max_position"),
+        )
+        weights.copy_(comps)
+
+        # Dequantize HPF8 to FP32
+        comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
+            comps,
+            fp8_config.get("exponent_bits"),
+            fp8_config.get("exponent_bias"),
+        )
+        dequant_weights.copy_(to_device(comps, use_cpu))
+
+    elif weight_ty == SparseType.FP16:
+        assert scale_shift is None
+        comps = dequant_weights.detach().half().view(torch.uint8)
+        weights.copy_(comps)
+    elif weight_ty == SparseType.FP32:
+        assert scale_shift is None
+        comps = dequant_weights.detach().float().view(torch.uint8)
+        weights.copy_(comps)
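
For context on the quantize.py hunk above: for SparseType.INT8 the fused row-wise format carries each row's uint8 payload plus a per-row scale and bias, and dequantize_embs reconstructs floats as payload * scale + bias. Below is a standalone sketch of that affine round trip. It is illustrative only: it does not call the torch.ops.fbgemm kernels, and the simple min/max scaling is an assumption rather than the exact fbgemm quantization rule.

import torch

E, D = 4, 8  # number of embedding rows and embedding dimension
weight = torch.randn(E, D)

# Per-row affine quantization to uint8 (illustrative stand-in for
# torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized).
row_min = weight.min(dim=1, keepdim=True).values
row_max = weight.max(dim=1, keepdim=True).values
scale = ((row_max - row_min) / 255.0).clamp(min=1e-8)
bias = row_min
q = torch.clamp(torch.round((weight - bias) / scale), 0, 255).to(torch.uint8)

# Dequantize the way dequantize_embs does for INT8: uint8 payload times the
# per-row scale plus the per-row bias (the scale_shift tensor in the code
# above is held in float16, so round-trip through float16 here too).
deq = (
    q.to(torch.float32) * scale.to(torch.float16).to(torch.float32)
    + bias.to(torch.float16).to(torch.float32)
)
print(float((deq - weight).abs().max()))  # small reconstruction error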