sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
--- sglang/srt/layers/quantization/awq.py (0.4.9.post2)
+++ sglang/srt/layers/quantization/awq.py (0.4.9.post3)
@@ -1,21 +1,65 @@
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
 import logging
-
+import warnings
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 
-from sglang.srt.layers.linear import
-
+from sglang.srt.layers.linear import LinearBase, set_weight_attrs
+from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
     LinearMethodBase,
-
+    QuantizationConfig,
+    QuantizeMethodBase,
 )
-from sglang.srt.layers.
-
-
+from sglang.srt.layers.quantization.marlin_utils import (
+    apply_awq_marlin_linear,
+    awq_to_marlin_zero_points,
+    check_marlin_supported,
+    check_marlin_supports_layer,
+    check_moe_marlin_supports_layer,
+    marlin_make_empty_g_idx,
+    marlin_make_workspace,
+    marlin_moe_permute_scales,
+    marlin_permute_scales,
+    moe_awq_to_marlin_zero_points,
+    verify_marlin_supported,
+    verify_marlin_supports_shape,
+)
+from sglang.srt.layers.quantization.scalar_type import scalar_types
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.utils import replace_parameter
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
+try:
+    from vllm import _custom_ops as ops
+
+    warnings.warn(
+        f"Using kernels directly from vllm. This might lead to performance degradation or "
+        f"missing functionalities as certain kernels may not be optimized. "
+    )
+except ImportError:
+    ops = None
+
+from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 if _is_cuda:
-    from sgl_kernel import awq_dequantize
+    from sgl_kernel import awq_dequantize, fused_marlin_moe
+elif _is_hip:
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_triton as awq_dequantize,
+    )
+
+    warnings.warn(f"HIP does not support fused_marlin_moe currently.")
+else:
+    warnings.warn(f"Only CUDA and HIP support AWQ currently.")
 
 logger = logging.getLogger(__name__)
 
@@ -81,7 +125,7 @@ class AWQConfig(QuantizationConfig):
         ]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> AWQConfig:
         weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
         group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
         zero_point = cls.get_from_keys(config, ["zero_point"])
@@ -92,7 +136,8 @@ class AWQConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
+    ) -> Optional[LinearMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
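The two small hunks above only work because the first hunk adds `from __future__ import annotations` (PEP 563): with postponed evaluation, `from_config` can name `AWQConfig` directly instead of quoting the forward reference. A minimal, standalone illustration (the class name below is a stand-in, not the sglang class):

```python
# Minimal illustration (not sglang code) of why the unquoted return annotation
# above is now legal: with `from __future__ import annotations` (PEP 563),
# annotations are stored as strings instead of being evaluated while the class
# body is still being built.
from __future__ import annotations

from typing import Any, Dict


class ConfigDemo:  # stand-in name, not the real AWQConfig
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ConfigDemo:
        # Without the future import, evaluating this annotation at definition
        # time would raise NameError because ConfigDemo is not yet bound.
        return cls()


print(ConfigDemo.from_config({}))
```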
@@ -101,6 +146,176 @@ class AWQConfig(QuantizationConfig):
         return None
 
 
+class AWQMarlinConfig(QuantizationConfig):
+    """Config class for AWQ Marlin"""
+
+    # num_bits -> type
+    TYPE_MAP = {
+        4: scalar_types.uint4,
+        8: scalar_types.uint8,
+    }
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+        lm_head_quantized: bool,
+        modules_to_not_convert: Optional[list[str]],
+        full_config: dict[str, Any],
+    ) -> None:
+        super().__init__()
+        self.pack_factor = 32 // weight_bits  # packed into int32
+        self.group_size = group_size
+        self.zero_point = zero_point
+        self.lm_head_quantized = lm_head_quantized
+        self.weight_bits = weight_bits
+        self.modules_to_not_convert = modules_to_not_convert or []
+        self.full_config = full_config
+
+        if self.weight_bits not in self.TYPE_MAP:
+            raise ValueError(
+                f"Unsupported num_bits = {self.weight_bits}. "
+                f"Supported num_bits = {self.TYPE_MAP.keys()}"
+            )
+
+        self.quant_type = self.TYPE_MAP[self.weight_bits]
+
+        verify_marlin_supported(
+            self.quant_type, group_size=self.group_size, has_zp=self.zero_point
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"AWQMarlinConfig(quant_type={self.quant_type}, "
+            f"group_size={self.group_size}, "
+            f"zero_point={self.zero_point}, "
+            f"lm_head_quantized={self.lm_head_quantized}, "
+            f"modules_to_not_convert={self.modules_to_not_convert})"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "awq_marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.half, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> AWQMarlinConfig:
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None
+        )
+        return cls(
+            weight_bits,
+            group_size,
+            zero_point,
+            lm_head_quantized,
+            modules_to_not_convert,
+            config,
+        )
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
+        )
+
+        if can_convert and is_valid_user_quant:
+            msg = (
+                "The model is convertible to {} during runtime."
+                " Using {} kernel.".format(cls.get_name(), cls.get_name())
+            )
+            logger.info(msg)
+            return cls.get_name()
+
+        if can_convert and user_quant == "awq":
+            logger.info(
+                "Detected that the model can run with awq_marlin"
+                ", however you specified quantization=awq explicitly,"
+                " so forcing awq. Use quantization=awq_marlin for"
+                " faster inference"
+            )
+        return None
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        if isinstance(layer, LinearBase) or (
+            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+        ):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            # Check if the layer is supported by AWQMarlin.
+            if not check_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.",  # noqa: E501
+                    prefix,
+                )
+                return AWQConfig.from_config(self.full_config).get_quant_method(
+                    layer, prefix
+                )
+            return AWQMarlinLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+
+            if not check_moe_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
+                    "Falling back to Moe WNA16 kernels."
+                )
+                return MoeWNA16Config.from_config(self.full_config).get_quant_method(
+                    layer, prefix
+                )
+            return AWQMoEMethod(self)
+        return None
+
+    @classmethod
+    def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
+        # Extract data from quant config.
+        quant_method = quant_config.get("quant_method", "").lower()
+        num_bits = quant_config.get("bits")
+        group_size = quant_config.get("group_size")
+        zero_point = quant_config.get("zero_point")
+
+        if not _is_cuda:
+            return False
+
+        if quant_method != "awq":
+            return False
+
+        # If we cannot find the info needed in the config, cannot convert.
+        if num_bits is None or group_size is None or zero_point is None:
+            return False
+
+        if num_bits not in cls.TYPE_MAP:
+            return False
+
+        return check_marlin_supported(
+            quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
+        )
+
+
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.
 
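For readers skimming the hunk above, here is a small standalone sketch of the gate that the new `AWQMarlinConfig` applies before routing an AWQ checkpoint onto the Marlin kernels. It mirrors `is_awq_marlin_compatible` and the `pack_factor` arithmetic from the diff, but imports nothing from sglang; the function name and the hard-coded group-size list are illustrative assumptions, not sglang API.

```python
# Standalone sketch of the compatibility gate shown in the diff above.
# Runs anywhere: no sglang, no GPU. Names and group sizes are assumptions.
from typing import Any, Dict

SUPPORTED_BITS = (4, 8)                    # keys of AWQMarlinConfig.TYPE_MAP
SUPPORTED_GROUP_SIZES = (-1, 32, 64, 128)  # typical Marlin group sizes (assumption)


def looks_awq_marlin_compatible(quant_config: Dict[str, Any], is_cuda: bool = True) -> bool:
    """Rough re-implementation of the decision the new config class makes."""
    if not is_cuda:
        return False
    if quant_config.get("quant_method", "").lower() != "awq":
        return False
    bits = quant_config.get("bits")
    group_size = quant_config.get("group_size")
    zero_point = quant_config.get("zero_point")
    if bits is None or group_size is None or zero_point is None:
        return False
    if bits not in SUPPORTED_BITS:
        return False
    return group_size in SUPPORTED_GROUP_SIZES


# Example: a typical AutoAWQ quantization_config section.
cfg = {"quant_method": "awq", "bits": 4, "group_size": 128, "zero_point": True}
print(looks_awq_marlin_compatible(cfg))  # True
print(32 // cfg["bits"])                 # pack_factor = 8 int4 values per int32
```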
@@ -195,10 +410,362 @@ class AWQLinearMethod(LinearMethodBase):
         pack_factor = self.quant_config.pack_factor
         out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
         reshaped_x = x.reshape(-1, x.shape[-1])
-
         out = awq_dequantize(qweight, scales, qzeros)
         out = torch.matmul(reshaped_x, out)
 
         if bias is not None:
             out.add_(bias)
         return out.reshape(out_shape)
+
+
+class AWQMarlinLinearMethod(LinearMethodBase):
+    """Linear method for AWQ Marlin.
+
+    Args:
+        quant_config: The AWQ Marlin quantization config.
+    """
+
+    def __init__(self, quant_config: AWQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        del output_size
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        # Normalize group_size
+        if self.quant_config.group_size != -1:
+            group_size = self.quant_config.group_size
+        else:
+            group_size = input_size
+
+        verify_marlin_supports_shape(
+            output_size_per_partition=output_size_per_partition,
+            input_size_per_partition=input_size_per_partition,
+            input_size=input_size,
+            group_size=group_size,
+        )
+
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        num_groups = input_size_per_partition // group_size
+
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
+                num_groups,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        scales = GroupQuantScaleParameter(
+            data=torch.empty(
+                num_groups,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=0,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("qzeros", qzeros)
+        layer.register_parameter("scales", scales)
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.num_groups = num_groups
+
+    # TODO: Update this docs
+    # Checkpoints are serialized in AutoAWQ format, which is different from the
+    # marlin format. This function is called after the weights are loaded.
+    # Here, we handle the repacking
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        device = layer.qweight.device
+        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+        # Allocate marlin workspace
+        layer.workspace = marlin_make_workspace(device)
+
+        # Repack weights from AWQ format to marlin format.
+        marlin_qweight = ops.awq_marlin_repack(
+            layer.qweight,
+            size_k=layer.input_size_per_partition,
+            size_n=layer.output_size_per_partition,
+            num_bits=self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "qweight", marlin_qweight)
+
+        # Permute scales from AWQ format to marlin format.
+        marlin_scales = marlin_permute_scales(
+            layer.scales,
+            size_k=layer.input_size_per_partition,
+            size_n=layer.output_size_per_partition,
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "scales", marlin_scales)
+
+        # Permute zero-points from AWQ format to marlin format.
+        marlin_zp = awq_to_marlin_zero_points(
+            layer.qzeros,
+            size_k=layer.num_groups,
+            size_n=layer.output_size_per_partition,
+            num_bits=self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "qzeros", marlin_zp)
+
+        # Not-used
+        layer.g_idx = marlin_make_empty_g_idx(device)
+        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return apply_awq_marlin_linear(
+            input=x,
+            weight=layer.qweight,
+            weight_scale=layer.scales,
+            weight_zp=layer.qzeros,
+            g_idx=layer.g_idx,
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=layer.workspace,
+            quant_type=self.quant_config.quant_type,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            bias=bias,
+        )
+
+
+class AWQMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: AWQMarlinConfig):
+        self.quant_config = quant_config
+        if self.quant_config.weight_bits != 4:
+            raise ValueError("AWQMoEMethod only supports 4bit now.")
+        self.quant_type = scalar_types.uint4
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        # Delay the import to avoid circular dependency
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        extra_weight_attrs.update(
+            {
+                "is_transposed": True,
+                "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
+            }
+        )
+
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+
+        num_groups_w13 = hidden_size // self.quant_config.group_size
+        num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        w13_scales = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                num_groups_w13,
+                intermediate_size_per_partition * 2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+
+        # WEIGHT_ZERO_POINT
+        # Allocate 2 zero points for w1 and w3 respectively.
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                num_groups_w13,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                num_groups_w2,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
+        device = layer.w13_qweight.device
+        layer.workspace = marlin_make_workspace(device, 4)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        num_experts = layer.w13_qweight.shape[0]
+        device = layer.w13_qweight.device
+
+        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+            requires_grad=False,
+        )
+        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+            requires_grad=False,
+        )
+
+        marlin_w13_qweight = ops.awq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            size_k=layer.w13_qweight.shape[1],
+            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits,
+        )
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+
+        marlin_w2_qweight = ops.awq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            size_k=layer.w2_qweight.shape[1],
+            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits,
+        )
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+
+        # hidden_size->intermediate_size
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
+
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
+
+        marlin_w13_zp = moe_awq_to_marlin_zero_points(
+            layer.w13_qzeros,
+            size_k=layer.w13_qzeros.shape[1],
+            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits,
+        )
+        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
+
+        marlin_w2_zp = moe_awq_to_marlin_zero_points(
+            layer.w2_qzeros,
+            size_k=layer.w2_qzeros.shape[1],
+            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits,
+        )
+        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        *,
+        activation: str = "silu",
+        **kwargs,
+    ) -> torch.Tensor:
+
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        # The input must currently be float16
+        orig_dtype = x.dtype
+        x = x.half()
+
+        topk_weights, topk_ids, router_logits = topk_output
+
+        return fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            layer.w13_scales,
+            layer.w2_scales,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            w1_zeros=layer.w13_qzeros,
+            w2_zeros=layer.w2_qzeros,
+            num_bits=self.quant_config.weight_bits,
+        ).to(orig_dtype)
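As a sanity check on the shape bookkeeping in the new `AWQMoEMethod.create_weights` above, the packed-tensor dimensions can be reproduced with plain arithmetic. The model dimensions below are made up for illustration and are not taken from any real checkpoint.

```python
# Shape bookkeeping behind AWQMoEMethod.create_weights in the diff above,
# written as plain arithmetic so it can be checked without torch or a GPU.
weight_bits = 4
pack_factor = 32 // weight_bits            # 8 int4 values packed per int32 word
group_size = 128

num_experts = 8                            # illustrative values only
hidden_size = 4096
intermediate_size_per_partition = 1408     # e.g. after a tensor-parallel split

w13_qweight_shape = (
    num_experts,
    hidden_size,
    2 * intermediate_size_per_partition // pack_factor,
)
w2_qweight_shape = (
    num_experts,
    intermediate_size_per_partition,
    hidden_size // pack_factor,
)
num_groups_w13 = hidden_size // group_size
num_groups_w2 = intermediate_size_per_partition // group_size

print(w13_qweight_shape)              # (8, 4096, 352)
print(w2_qweight_shape)               # (8, 1408, 512)
print(num_groups_w13, num_groups_w2)  # 32 11
```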