sglang-0.5.0rc2-py3-none-any.whl → sglang-0.5.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/marlin_utils_fp8.py (new file)
@@ -0,0 +1,352 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Optional
+
+import torch
+
+from sglang.srt.layers.quantization.marlin_utils import (
+    USE_FP32_REDUCE_DEFAULT,
+    marlin_make_workspace,
+    marlin_permute_bias,
+    marlin_permute_scales,
+    should_use_atomic_add_reduce,
+)
+from sglang.srt.layers.quantization.utils import get_scalar_types
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm, gptq_marlin_repack
+
+ScalarType, scalar_types = get_scalar_types()
+
+logger = logging.getLogger(__name__)
+
+
+def fp8_fused_exponent_bias_into_scales(scales):
+    fp8_exponent = 4
+    if scales.dtype == torch.half:
+        target_exponent = 5
+    elif scales.dtype == torch.bfloat16:
+        target_exponent = 8
+    # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
+    # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
+    exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)
+    s = torch.ones_like(scales) * 2
+    s = s**exponent_bias
+    return scales * s
+
+
+def apply_fp8_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    size_n: int,
+    size_k: int,
+    bias: Optional[torch.Tensor],
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    # For GPUs that lack FP8 hardware support, we can leverage the
+    # Marlin kernel for fast weight-only FP8 quantization
+
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (size_n,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
+    )
+
+    output = gptq_marlin_gemm(
+        a=reshaped_x,
+        c=None,
+        b_q_weight=weight,
+        b_bias=bias,
+        b_scales=weight_scale,
+        global_scale=None,
+        b_zeros=None,
+        g_idx=None,
+        perm=None,
+        workspace=workspace,
+        b_q_type=scalar_types.float8_e4m3fn,
+        size_m=reshaped_x.size(0),
+        size_n=size_n,
+        size_k=size_k,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+    )
+
+    return output.reshape(out_shape)
+
+
+def prepare_fp8_layer_for_marlin(
+    layer: torch.nn.Module, size_k_first: bool = True
+) -> None:
+    logger.warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads."
+    )
+
+    part_size_n = layer.output_size_per_partition
+    part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
+
+    if size_k_first:
+        assert layer.weight.shape == (part_size_k, part_size_n)
+    else:
+        assert layer.weight.shape == (part_size_n, part_size_k)
+
+    device = layer.weight.device
+
+    # WORKSPACE
+    layer.workspace = marlin_make_workspace(device)
+
+    # WEIGHT
+    # Repack weights to marlin format
+    perm = torch.empty(0, dtype=torch.int, device=device)
+    qweight = pack_fp8_to_int32(layer.weight, size_k_first)
+    if not size_k_first:
+        qweight = qweight.T.contiguous()
+
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=qweight,
+        perm=perm,
+        size_k=part_size_k,
+        size_n=part_size_n,
+        num_bits=8,
+    )
+    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
+
+    # WEIGHT SCALES
+    # Permute scales
+    if "weight_scale" in dir(layer):
+        scales = layer.weight_scale.to(layer.orig_dtype)
+    elif "weight_scale_inv" in dir(layer):
+        scales = layer.weight_scale_inv.to(layer.orig_dtype)
+        del layer.weight_scale_inv
+
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
+
+    # marlin kernel only support channel-wise and group-wise quantization
+    # we need to convert the scales
+    if weight_block_size is None:
+        if scales.nelement() == 1:
+            # tensor-wise quantization -> channel-wise quantization
+            # (1, 1) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
+        elif scales.nelement() > 1 and scales.nelement() != part_size_n:
+            assert part_size_n % scales.nelement() == 0
+            s_size = scales.nelement()
+            # tensor-wise quantization (for gate-up proj)
+            # -> channel-wise quantization
+            # (1, s_size) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, s_size)
+            scales = scales.repeat_interleave(part_size_n // s_size, 1)
+        else:
+            # channel-wise quantization
+            # (1, size_n)
+            scales = scales.view(1, part_size_n)
+    else:
+        # block-wise quantization -> group-wise quantization
+        # (size_k // block_size[1], ceil(size_n / block_size[0]))
+        # =>(repeat)=> (size_k // block_size[1], size_n)
+        if not size_k_first:
+            scales = scales.T.contiguous()
+        block_n = weight_block_size[0]
+        scales = scales.repeat_interleave(block_n, 1)
+        # size_n may not divisible by block_size[0]
+        scales = scales[:, :part_size_n]
+
+    marlin_scales = marlin_permute_scales(
+        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
+    )
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
+
+    if hasattr(layer, "bias") and layer.bias is not None:
+        assert layer.bias.shape == (part_size_n,)
+        bias = marlin_permute_bias(layer.bias)
+        layer.bias = torch.nn.Parameter(bias, requires_grad=False)
+
+
+def prepare_moe_fp8_layer_for_marlin(
+    layer: torch.nn.Module, size_k_first: bool = True
+) -> None:
+    logger.warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads."
+    )
+
+    e = layer.num_experts
+    k = layer.hidden_size
+    n = layer.intermediate_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
+
+    # WORKSPACE
+    device = layer.w13_weight.device
+    layer.workspace = marlin_make_workspace(device, 4)
+    perm = torch.empty(0, dtype=torch.int, device=device)
+
+    # WEIGHT
+    # Repack weights to marlin format
+    for name in ["w13_weight", "w2_weight"]:
+        weight = getattr(layer, name)
+        tensor_list = []
+        if "w13" in name:
+            size_n, size_k = n * 2, k
+        else:
+            size_n, size_k = k, n
+
+        if size_k_first:
+            assert weight.shape == (e, size_k, size_n)
+        else:
+            assert weight.shape == (e, size_n, size_k)
+
+        for i in range(e):
+            qweight = pack_fp8_to_int32(weight[i], size_k_first)
+            if not size_k_first:
+                qweight = qweight.T.contiguous()
+
+            marlin_qweight = gptq_marlin_repack(
+                b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
+            )
+            tensor_list.append(marlin_qweight)
+
+        weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        setattr(layer, name, weight)
+
+    # WEIGHT SCALES
+    # Permute scales
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
+
+    for name in ["w13", "w2"]:
+        if name + "_weight_scale" in dir(layer):
+            new_name = name + "_weight_scale"
+            scales = getattr(layer, new_name).to(layer.orig_dtype)
+            delattr(layer, new_name)
+        elif name + "_weight_scale_inv" in dir(layer):
+            new_name = name + "_weight_scale_inv"
+            scales = getattr(layer, new_name).to(layer.orig_dtype)
+            delattr(layer, new_name)
+
+        tensor_list = []
+        if "w13" in name:
+            size_n, size_k = n * 2, k
+        else:
+            size_n, size_k = k, n
+
+        # marlin kernel only support channel-wise and group-wise quantization
+        # we need to convert the scales
+        if weight_block_size is None:
+            if scales.nelement() == e:
+                # tensor-wise quantization -> channel-wise quantization
+                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
+                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
+            elif scales.nelement() > e and scales.nelement() != e * size_n:
+                assert (e * size_n) % scales.nelement() == 0
+                s_size = scales.nelement() // e
+                # tensor-wise quantization (for gate-up proj)
+                # -> channel-wise quantization
+                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
+                scales = scales.view(e, 1, s_size)
+                scales = scales.repeat_interleave(size_n // s_size, 2)
+            else:
+                # channel-wise quantization
+                # (e, 1, size_n)
+                scales = scales.view(e, 1, size_n)
+        else:
+            # block-wise quantization -> group-wise quantization
+            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
+            # =>(repeat)=> (e, size_k // block_size[1], size_n)
+            if not size_k_first:
+                scales = scales.permute(0, 2, 1)
+            block_n = weight_block_size[0]
+            scales = scales.repeat_interleave(block_n, 2)
+            # size_n may not divisible by block_size[0]
+            scales = scales[..., :size_n].contiguous()

+        for i in range(e):
+            marlin_scales = marlin_permute_scales(
+                s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
+            )
+            tensor_list.append(marlin_scales)
+
+        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        scales = fp8_fused_exponent_bias_into_scales(scales)
+        scales = torch.nn.Parameter(scales, requires_grad=False)
+
+        setattr(layer, name + "_weight_scale", scales)
+
+    # BIAS
+    # Permute bias
+    for name in ["w13_bias", "w2_bias"]:
+        if not hasattr(layer, name):
+            continue
+        bias = getattr(layer, name).to(layer.orig_dtype)
+
+        tensor_list = []
+        for i in range(e):
+            expert_bias = bias[i]
+
+            tensor_list.append(marlin_permute_bias(expert_bias))
+
+        bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        bias = torch.nn.Parameter(bias, requires_grad=False)
+        setattr(layer, name, bias)
+
+
+def pack_fp8_to_int32(
+    fp8_tensor: torch.Tensor, size_k_first: bool = True
+) -> torch.Tensor:
+    """
+    Repack FP8 weights to gptq format (packed int32 elements)
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+    assert fp8_tensor.ndim == 2
+
+    fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
+    fp8_tensor = fp8_tensor.contiguous()
+    # fp8_tensor is contiguous and have shape (N, K) now
+    # with `.view(torch.int32)`, it become (N, K // 4)
+    int32_tensor = fp8_tensor.view(torch.int32)
+    return int32_tensor.T.contiguous() if size_k_first else int32_tensor
+
+
+def marlin_quant_fp8_torch(weight, group_size):
+    size_n, size_k = weight.shape
+    device = weight.device
+
+    if group_size != -1:
+        scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(group_size, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+    else:
+        scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(size_k, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+
+    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=packed_weight,
+        perm=torch.empty(0, dtype=torch.int, device=device),
+        size_k=size_k,
+        size_n=size_n,
+        num_bits=8,
+    )
+
+    marlin_scales = marlin_permute_scales(
+        s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
+    )
+
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+
+    return weight_ref.T, marlin_qweight, marlin_scales
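For orientation, here is a minimal sketch (not part of the diff) of how these helpers fit together on a CUDA device with sgl_kernel available, assuming the new module lands at sglang.srt.layers.quantization.marlin_utils_fp8 as the file list above indicates; the 4096x4096 shape, fp16 dtype, and 128-element group size are illustrative choices.

```python
# Hypothetical end-to-end check of the weight-only FP8 Marlin path added in this release.
# Shapes, dtype, and group_size are illustrative assumptions.
import torch

from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
    apply_fp8_marlin_linear,
    marlin_quant_fp8_torch,
)

size_n, size_k = 4096, 4096
weight = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")

# Quantize to FP8, repack into the Marlin layout, and permute the scales.
weight_ref, marlin_qweight, marlin_scales = marlin_quant_fp8_torch(weight, group_size=128)

x = torch.randn(8, size_k, dtype=torch.half, device="cuda")
workspace = marlin_make_workspace(x.device)

out = apply_fp8_marlin_linear(
    input=x,
    weight=marlin_qweight,
    weight_scale=marlin_scales,
    workspace=workspace,
    size_n=size_n,
    size_k=size_k,
    bias=None,
)

# weight_ref is the dequantized reference weight, already transposed to (K, N).
ref = x @ weight_ref
print((out - ref).abs().max())
```

Note that the scales returned by marlin_quant_fp8_torch (and installed by prepare_fp8_layer_for_marlin) already carry the exponent-bias factor applied by fp8_fused_exponent_bias_into_scales: 2**8 for fp16 scales and 2**120 for bf16, matching the comments inside that function.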