sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py (+80 -53)

@@ -6,53 +6,98 @@ from copy import deepcopy
 from typing import Callable, Dict, Optional, Type, Union
 
 import torch
-from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
-from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig,
-)
-from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
-from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
+try:
+    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinConfig,
+        AWQMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+        GPTQMarlin24Config,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+    from vllm.model_executor.layers.quantization.qqq import QQQConfig
+    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
+    # Define empty classes as placeholders when vllm is not available
+    class DummyConfig:
+        def override_quantization_method(self, *args, **kwargs):
+            return None
+
+    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
+        DeepSpeedFPConfig
+    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
+        MarlinConfig
+    ) = QQQConfig = Int8TpuConfig = DummyConfig
+
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig,
+)
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)
 
-QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
+# Base quantization methods that don't depend on vllm
+BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
+    "fp8": Fp8Config,
+    "blockwise_int8": BlockInt8Config,
+    "modelopt": ModelOptFp8Config,
+    "w8a8_int8": W8A8Int8Config,
+    "w8a8_fp8": W8A8Fp8Config,
+    "compressed-tensors": CompressedTensorsConfig,
+}
+
+# VLLM-dependent quantization methods
+VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
     "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
-    "fp8": Fp8Config,
-    "blockwise_int8": BlockInt8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
-    "modelopt": ModelOptFp8Config,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "gptq_marlin": GPTQMarlinConfig,
     "awq_marlin": AWQMarlinConfig,
-    "gptq": GPTQConfig,
-    "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "w8a8_int8": W8A8Int8Config,
-    "w8a8_fp8": W8A8Fp8Config,
+    "gptq_marlin": GPTQMarlinConfig,
+    "gptq": GPTQConfig,
 }
 
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
+
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
@@ -60,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
+    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
+        raise ValueError(
+            f"{quantization} quantization requires some operators from vllm. "
+            "Pleaes install vllm by `pip install vllm==0.7.2`"
+        )
+
     return QUANTIZATION_METHODS[quantization]
 
 
@@ -124,13 +175,6 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
-
-    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -157,14 +201,6 @@ def get_linear_quant_method(
 
 
 def gptq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
-    )
-
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
     if isinstance(layer, FusedMoE):
         return GPTQMarlinMoEMethod(self)
 
@@ -187,6 +223,8 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
     can recognize sglang layers
     """
+    if not VLLM_AVAILABLE:
+        return
 
     if reverse:
         builtins.isinstance = original_isinstance
@@ -270,13 +308,6 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
-    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-        CompressedTensorsW8A8Fp8MoEMethod,
-        CompressedTensorsWNA16MoEMethod,
-    )
-    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod
-
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
@@ -286,10 +317,6 @@ def monkey_patch_quant_configs():
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
 
-
-
-
-__all__ = [
-    "get_quantization_config",
-    "QUANTIZATION_METHODS",
-]
+# Only apply monkey patches if vllm is available
+if VLLM_AVAILABLE:
+    monkey_patch_quant_configs()
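The hunks above split the registry into BASE_QUANTIZATION_METHODS (pure sglang) and VLLM_QUANTIZATION_METHODS (requires vllm), and get_quantization_config now rejects vllm-backed methods when the optional import failed. A minimal usage sketch of that behavior, illustrative only and assuming this wheel is installed (not part of the diff):

from sglang.srt.layers.quantization import get_quantization_config

# "fp8" lives in BASE_QUANTIZATION_METHODS, so it resolves even without vllm.
fp8_config_cls = get_quantization_config("fp8")

# "awq_marlin" lives in VLLM_QUANTIZATION_METHODS; without vllm installed,
# the new guard raises a ValueError asking for `pip install vllm==0.7.2`.
try:
    awq_marlin_cls = get_quantization_config("awq_marlin")
except ValueError as err:
    print(err)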
sglang/srt/layers/quantization/awq.py (new file, +200)

@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from sgl_kernel import awq_dequantize
+
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+class AWQConfig(QuantizationConfig):
+    """Config class for AWQ.
+
+    Reference: https://arxiv.org/abs/2306.00978
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+        modules_to_not_convert: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.zero_point = zero_point
+        self.modules_to_not_convert = modules_to_not_convert or []
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"AWQ, but got {self.weight_bits} bits."
+            )
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return (
+            f"AWQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"zero_point={self.zero_point}, "
+            f"modules_to_not_convert={self.modules_to_not_convert})"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def get_name(self) -> str:
+        return "awq"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Turing or newer GPUs.
+        return 75
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
+            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None
+        )
+        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["LinearMethodBase"]:
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return AWQLinearMethod(self)
+        return None
+
+
+class AWQLinearMethod(LinearMethodBase):
+    """Linear method for AWQ.
+
+    Args:
+        quant_config: The AWQ quantization config.
+    """
+
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        scales = GroupQuantScaleParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=0,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("qzeros", qzeros)
+        layer.register_parameter("scales", scales)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
+        pack_factor = self.quant_config.pack_factor
+        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        out = awq_dequantize(qweight, scales, qzeros)
+        out = torch.matmul(reshaped_x, out)
+
+        if bias is not None:
+            out.add_(bias)
+        return out.reshape(out_shape)
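The new AWQLinearMethod stores 4-bit weights packed eight-per-int32 (pack_factor = 32 // weight_bits) with per-group zero points and scales, then dequantizes with sgl_kernel's awq_dequantize and applies a plain matmul. A small sketch of the resulting parameter shapes, using hypothetical partition sizes; only the formulas come from create_weights above:

# Hypothetical sizes; the shape formulas mirror create_weights above.
weight_bits = 4
pack_factor = 32 // weight_bits        # 8 packed int4 values per int32
group_size = 128
input_size_per_partition = 4096
output_size_per_partition = 4096

qweight_shape = (input_size_per_partition,
                 output_size_per_partition // pack_factor)   # (4096, 512), int32
qzeros_shape = (input_size_per_partition // group_size,
                output_size_per_partition // pack_factor)    # (32, 512), int32
scales_shape = (input_size_per_partition // group_size,
                output_size_per_partition)                   # (32, 4096), params_dtype

print(qweight_shape, qzeros_shape, scales_shape)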
sglang/srt/layers/quantization/base_config.py (+5)

@@ -38,6 +38,11 @@ class QuantizeMethodBase(ABC):
 class QuantizationConfig(ABC):
     """Base class for quantization configs."""
 
+    def __init__(self):
+        super().__init__()
+        # mapping is updated by models as they initialize
+        self.packed_modules_mapping: Dict[str, List[str]] = dict()
+
     @abstractmethod
     def get_name(self) -> str:
         """Name of the quantization method."""
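The new packed_modules_mapping dict gives models a place to record which checkpoint sub-modules are fused into one packed module; the diff itself only says models update it as they initialize. A hypothetical example of the kind of entry a model might register (module names are illustrative, not taken from this diff):

# Hypothetical mapping; keys are fused modules, values the original shards.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}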
sglang/srt/layers/quantization/blockwise_int8.py (+1 -1)

@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn import Module
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
@@ -19,6 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
+from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]