sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py

@@ -6,53 +6,82 @@ from copy import deepcopy
 from typing import Callable, Dict, Optional, Type, Union
 
 import torch
-
-
-from vllm.model_executor.layers.quantization.
-from vllm.model_executor.layers.quantization.
-from vllm.model_executor.layers.quantization.
-
-
-from vllm.model_executor.layers.quantization.
-from vllm.model_executor.layers.quantization.
-from vllm.model_executor.layers.quantization.
-from vllm.model_executor.layers.quantization.
-
-
-from vllm.model_executor.layers.quantization.
-from vllm.model_executor.layers.quantization.
+
+try:
+    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+    from vllm.model_executor.layers.quantization.awq import AWQConfig
+    from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+        GPTQMarlin24Config,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+    from vllm.model_executor.layers.quantization.qqq import QQQConfig
+    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+    from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
+    # Define empty classes as placeholders when vllm is not available
+    class DummyConfig:
+        pass
+
+    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
+        CompressedTensorsConfig
+    ) = DummyConfig
+    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
+        GPTQMarlin24Config
+    ) = DummyConfig
+    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig,
+)
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
-
-
-    "awq": AWQConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
+# Base quantization methods that don't depend on vllm
+BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
-    "fbgemm_fp8": FBGEMMFp8Config,
-    "marlin": MarlinConfig,
     "modelopt": ModelOptFp8Config,
-    "gguf": GGUFConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "awq_marlin": AWQMarlinConfig,
-    "gptq": GPTQConfig,
-    "compressed-tensors": CompressedTensorsConfig,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "experts_int8": ExpertsInt8Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "compressed-tensors": CompressedTensorsConfig,
 }
 
+# Add vllm-dependent methods if available
+QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
+if VLLM_AVAILABLE:
+    VLLM_QUANTIZATION_METHODS = {
+        "aqlm": AQLMConfig,
+        "awq": AWQConfig,
+        "deepspeedfp": DeepSpeedFPConfig,
+        "tpu_int8": Int8TpuConfig,
+        "fbgemm_fp8": FBGEMMFp8Config,
+        "marlin": MarlinConfig,
+        "gguf": GGUFConfig,
+        "gptq_marlin_24": GPTQMarlin24Config,
+        "awq_marlin": AWQMarlinConfig,
+        "bitsandbytes": BitsAndBytesConfig,
+        "qqq": QQQConfig,
+        "experts_int8": ExpertsInt8Config,
+        "gptq_marlin": GPTQMarlinConfig,
+        "gptq": GPTQConfig,
+    }
+    QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
+
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
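The net effect of this hunk is that vllm becomes an optional dependency: every vllm import sits behind a single try/except that records the outcome in VLLM_AVAILABLE, and the registry is split into a base table plus a vllm-only table merged in on success. Below is a minimal standalone sketch of the same pattern; the module name optional_dep and its class name are hypothetical, not part of sglang.

# Optional-import pattern, reduced to its skeleton.
try:
    from optional_dep import FancyConfig  # may not be installed

    DEP_AVAILABLE = True
except ImportError:
    DEP_AVAILABLE = False

    class FancyConfig:  # placeholder so the name is always bound
        pass

# Methods that work in every environment.
BASE_METHODS = {"plain": dict}

# Extend the registry only when the optional dependency imported cleanly.
METHODS = BASE_METHODS.copy()
if DEP_AVAILABLE:
    METHODS["fancy"] = FancyConfig

def get_method(name: str):
    # Unknown names fail fast with the set of valid choices for this install.
    if name not in METHODS:
        raise ValueError(f"Invalid method: {name}. Must be one of {list(METHODS)}.")
    return METHODS[name]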
@@ -157,25 +186,31 @@ def get_linear_quant_method(
 
 
 def gptq_get_quant_method(self, layer, prefix):
-
-
-
-
-
+    if not VLLM_AVAILABLE:
+        return None
+
+    try:
+        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinLinearMethod,
+            GPTQMarlinMoEMethod,
+        )
 
-
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
-
-
+        if isinstance(layer, FusedMoE):
+            return GPTQMarlinMoEMethod(self)
 
-
-
-
-
-
-
-
-
+        if isinstance(self, GPTQConfig):
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )
+        elif isinstance(self, GPTQMarlinConfig):
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+            )
+    except ImportError:
+        pass
     return None
 
 
@@ -187,33 +222,40 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
     can recognize sglang layers
     """
+    if not VLLM_AVAILABLE:
+        return
 
     if reverse:
         builtins.isinstance = original_isinstance
         return
 
-
-
-
-
-
-
-    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-    )
+    try:
+        from vllm.model_executor.layers.fused_moe import FusedMoE
+        from vllm.model_executor.layers.linear import LinearBase
+        from vllm.model_executor.layers.vocab_parallel_embedding import (
+            VocabParallelEmbedding,
+        )
 
-
-
-
-
-
-
-
-        return original_isinstance(obj, classinfo)
+        from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+        from sglang.srt.layers.moe.fused_moe_triton.layer import (
+            FusedMoE as PatchedFusedMoE,
+        )
+        from sglang.srt.layers.vocab_parallel_embedding import (
+            VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+        )
 
-
+        def patched_isinstance(obj, classinfo):
+            if classinfo is LinearBase:
+                return original_isinstance(obj, PatchedLinearBase)
+            if classinfo is FusedMoE:
+                return original_isinstance(obj, PatchedFusedMoE)
+            if classinfo is VocabParallelEmbedding:
+                return original_isinstance(obj, PatchedVocabParallelEmbedding)
+            return original_isinstance(obj, classinfo)
+
+        builtins.isinstance = patched_isinstance
+    except ImportError:
+        return
 
 
 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
@@ -221,72 +263,88 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
     Convert sglang arguments to vllm arguments.
     """
-[45 lines removed; contents not preserved in this diff view]
+    if not VLLM_AVAILABLE:
+        return
+
+    try:
+        original_apply = class_obj.apply
+        sig = inspect.signature(original_apply)
+        param_names = list(sig.parameters.keys())
+        has_correction_bias = "e_score_correction_bias" in param_names
+
+        def new_apply(
+            self,
+            layer: torch.nn.Module,
+            x: torch.Tensor,
+            router_logits: torch.Tensor,
+            top_k: int,
+            renormalize: bool,
+            use_grouped_topk: bool,
+            topk_group: Optional[int] = None,
+            num_expert_group: Optional[int] = None,
+            custom_routing_function: Optional[Callable] = None,
+            correction_bias: Optional[torch.Tensor] = None,
+            activation: str = "silu",
+            inplace: bool = True,
+            no_combine: bool = False,
+        ):
+            assert activation == "silu"
+            assert inplace and not no_combine
+
+            kwargs = {
+                "self": self,
+                "layer": layer,
+                "x": x,
+                "router_logits": router_logits,
+                "top_k": top_k,
+                "renormalize": renormalize,
+                "use_grouped_topk": use_grouped_topk,
+                "topk_group": topk_group,
+                "num_expert_group": num_expert_group,
+                "custom_routing_function": custom_routing_function,
+            }
+            if correction_bias is not None:
+                if not has_correction_bias:
+                    raise ValueError(
+                        "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                    )
+                kwargs["e_score_correction_bias"] = correction_bias
+            return original_apply(**kwargs)
+
+        setattr(class_obj, "apply", new_apply)
+    except (ImportError, AttributeError):
+        return
 
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-
-
-        CompressedTensorsW8A8Fp8MoEMethod,
-        CompressedTensorsWNA16MoEMethod,
-    )
-    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod
+    if not VLLM_AVAILABLE:
+        return
 
-
-
+    try:
+        from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+            CompressedTensorsW8A8Fp8MoEMethod,
+            CompressedTensorsWNA16MoEMethod,
+        )
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinMoEMethod,
+        )
 
-
-
-
-
+        setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+        setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
+
+        monkey_patch_moe_apply(AWQMoEMethod)
+        monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+        monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+        monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
+    except ImportError:
+        return
 
 
-
+# Only apply monkey patches if vllm is available
+if VLLM_AVAILABLE:
+    monkey_patch_quant_configs()
 
 
 __all__ = [
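One detail of the rewritten monkey_patch_moe_apply deserves a note: instead of assuming a vllm version, it feature-detects the installed one by checking whether apply's signature accepts e_score_correction_bias, and forwards that kwarg only when present. A reduced sketch of the technique, using hypothetical stand-in functions rather than the real vllm methods:

import inspect

def apply_old(x, top_k):  # hypothetical older API without the kwarg
    return x, top_k

def apply_new(x, top_k, e_score_correction_bias=None):  # hypothetical newer API
    return x, top_k, e_score_correction_bias

def call_apply(fn, x, top_k, correction_bias=None):
    # Inspect the callee's parameters rather than pinning a version.
    supported = "e_score_correction_bias" in inspect.signature(fn).parameters
    kwargs = {"x": x, "top_k": top_k}
    if correction_bias is not None:
        if not supported:
            raise ValueError("installed dependency too old for correction_bias")
        kwargs["e_score_correction_bias"] = correction_bias
    return fn(**kwargs)

print(call_apply(apply_old, 1.0, 2))       # works: kwarg never forwarded
print(call_apply(apply_new, 1.0, 2, 0.5))  # kwarg forwarded when supported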
sglang/srt/layers/quantization/base_config.py

@@ -38,6 +38,11 @@ class QuantizeMethodBase(ABC):
 class QuantizationConfig(ABC):
     """Base class for quantization configs."""
 
+    def __init__(self):
+        super().__init__()
+        # mapping is updated by models as they initialize
+        self.packed_modules_mapping: Dict[str, List[str]] = dict()
+
     @abstractmethod
     def get_name(self) -> str:
         """Name of the quantization method."""
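The new __init__ gives every quantization config a packed_modules_mapping that model code can fill in while it builds fused layers, letting the config map a fused module name back to the per-tensor names in the checkpoint. A toy illustration; the q/k/v and gate/up fusions shown are a common convention in vllm-style loaders, used here as an assumption rather than sglang's exact entries:

from typing import Dict, List

class ToyQuantConfig:
    # Minimal stand-in mirroring the new QuantizationConfig.__init__.
    def __init__(self) -> None:
        # mapping is updated by models as they initialize
        self.packed_modules_mapping: Dict[str, List[str]] = {}

cfg = ToyQuantConfig()
# A model records which checkpoint modules were fused into one module:
cfg.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
cfg.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]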
sglang/srt/layers/quantization/blockwise_int8.py

@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn import Module
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
@@ -19,6 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
+from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
File without changes