sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
CHANGED
@@ -12,17 +12,19 @@
 # limitations under the License.
 # ==============================================================================

+import os
 from typing import Callable, Optional

 import torch
 import torch.nn.functional as F

+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()

-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder

 expert_distribution_recorder = ExpertDistributionRecorder()

@@ -102,11 +104,13 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
     )  # [n, n_group]
@@ -122,15 +126,30 @@ def grouped_topk(
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.

     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum

     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def biased_grouped_topk(
+def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -138,11 +157,13 @@ def biased_grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     scores = gating_output.sigmoid()
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
     group_scores = (
         scores_for_choice.view(num_token, num_expert_group, -1)
@@ -165,12 +186,59 @@ def biased_grouped_topk(
     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
     topk_weights = scores.gather(1, topk_ids)

+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.
+
     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum

     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


+def biased_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    n_share_experts_fusion: int = 0,
+):
+    biased_grouped_topk_fn = (
+        torch.compile(
+            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+        )
+        if compiled
+        else biased_grouped_topk_impl
+    )
+    return biased_grouped_topk_fn(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        n_share_experts_fusion=n_share_experts_fusion,
+    )
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
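Note: in the hunk above, biased_grouped_topk becomes a thin wrapper that applies torch.compile to biased_grouped_topk_impl only when compiled=True. Below is a minimal, self-contained sketch of that opt-in compilation pattern; the function names and the toy routing kernel are illustrative, not taken from the package.

import torch


def _route_impl(scores: torch.Tensor, k: int) -> torch.Tensor:
    # Toy stand-in for the real routing kernel: plain top-k over expert scores.
    return torch.topk(scores, k=k, dim=-1).indices


def route(scores: torch.Tensor, k: int, compiled: bool = True) -> torch.Tensor:
    # Pick the compiled or eager implementation at call time instead of baking
    # torch.compile into a decorator on the implementation itself.
    # (A production version would typically cache the compiled callable.)
    fn = torch.compile(_route_impl, dynamic=True) if compiled else _route_impl
    return fn(scores, k)


if __name__ == "__main__":
    print(route(torch.randn(4, 8), k=2, compiled=False))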
@@ -183,7 +251,10 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-[1 removed line not rendered in this view]
+    n_share_experts_fusion = 0
+    if global_server_args_dict["n_share_experts_fusion"] is not None:
+        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    # DeekSeek V2/V3/R1 serices models uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -195,6 +266,7 @@ def select_experts(
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
+                n_share_experts_fusion=n_share_experts_fusion,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -205,6 +277,7 @@ def select_experts(
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
+                n_share_experts_fusion=n_share_experts_fusion,
             )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
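Note: taken together, the topk.py hunks above add an n_share_experts_fusion path: when it is non-zero, the last routed slot is overwritten with a randomly chosen fused shared expert (indexed after the regular experts), its weight is set to the sum of the other slots divided by the routed scaling factor (2.5 in the hunks above), and renormalization then excludes that slot. A minimal standalone sketch of the idea; the sizes and tensors below are made up for illustration and are not the package's API.

import torch

num_tokens, num_experts, top_k = 4, 16, 3
n_share_experts_fusion = 2     # shared experts appended after the routed ones
routed_scaling_factor = 2.5    # the constant hard-coded in the hunks above

scores = torch.rand(num_tokens, num_experts)
topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)

if n_share_experts_fusion:
    # Redirect the last slot to one of the fused shared experts ...
    topk_ids[:, -1] = torch.randint(
        low=num_experts,
        high=num_experts + n_share_experts_fusion,
        size=(num_tokens,),
        dtype=topk_ids.dtype,
    )
    # ... and give it a weight derived from the routed slots.
    topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

# Renormalize over the routed slots only when fusion is active.
denom = (
    topk_weights.sum(dim=-1, keepdim=True)
    if n_share_experts_fusion == 0
    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / denom
print(topk_ids, topk_weights, sep="\n")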
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -9,13 +9,24 @@ import torch

 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.
-[1 removed line not rendered in this view]
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinConfig,
+        AWQMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -23,33 +34,39 @@ try:
     from vllm.model_executor.layers.quantization.qqq import QQQConfig
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

-    from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False

     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
-[1 removed line not rendered in this view]
+        def override_quantization_method(self, *args, **kwargs):
+            return None
+
+    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
+        DeepSpeedFPConfig
+    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
+        MarlinConfig
+    ) = QQQConfig = Int8TpuConfig = DummyConfig

-    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
-        CompressedTensorsConfig
-    ) = DummyConfig
-    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
-        GPTQMarlin24Config
-    ) = DummyConfig
-    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig

+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -58,29 +75,29 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt": ModelOptFp8Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
 }

-#
-[18 removed lines not rendered in this view]
-QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
+# VLLM-dependent quantization methods
+VLLM_QUANTIZATION_METHODS = {
+    "aqlm": AQLMConfig,
+    "awq": AWQConfig,
+    "deepspeedfp": DeepSpeedFPConfig,
+    "tpu_int8": Int8TpuConfig,
+    "fbgemm_fp8": FBGEMMFp8Config,
+    "marlin": MarlinConfig,
+    "gguf": GGUFConfig,
+    "gptq_marlin_24": GPTQMarlin24Config,
+    "awq_marlin": AWQMarlinConfig,
+    "bitsandbytes": BitsAndBytesConfig,
+    "qqq": QQQConfig,
+    "experts_int8": ExpertsInt8Config,
+    "gptq_marlin": GPTQMarlinConfig,
+    "gptq": GPTQConfig,
+}
+
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
@@ -89,6 +106,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
+    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
+        raise ValueError(
+            f"{quantization} quantization requires some operators from vllm. "
+            "Pleaes install vllm by `pip install vllm==0.7.2`"
+        )
+
     return QUANTIZATION_METHODS[quantization]


@@ -153,13 +176,6 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
-
-    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -186,31 +202,19 @@


 def gptq_get_quant_method(self, layer, prefix):
-[1 removed line not rendered in this view]
-    return None
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

-[2 removed lines not rendered in this view]
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
-    )
+    if isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)

-[8 removed lines not rendered in this view]
-    )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
-    except ImportError:
-        pass
+    if isinstance(self, GPTQConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+        )
+    elif isinstance(self, GPTQMarlinConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+        )
     return None


@@ -229,33 +233,28 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
         builtins.isinstance = original_isinstance
         return

-[5 removed lines not rendered in this view]
-    )
+    from vllm.model_executor.layers.fused_moe import FusedMoE
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding,
+    )

-[5 removed lines not rendered in this view]
-        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-    )
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+    )

-[10 removed lines not rendered in this view]
-    except ImportError:
-        return
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        if classinfo is FusedMoE:
+            return original_isinstance(obj, PatchedFusedMoE)
+        if classinfo is VocabParallelEmbedding:
+            return original_isinstance(obj, PatchedVocabParallelEmbedding)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance


 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
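Note: the monkey_patch_isinstance_for_vllm_base_layer hunk above swaps builtins.isinstance for a shim so that isinstance checks against vllm's base layer classes succeed for sglang's replacement classes. A stripped-down sketch of that mechanism with two stand-in classes; the class names here are illustrative, not the real ones.

import builtins


class UpstreamBase:      # stands in for a vllm base layer class
    pass


class LocalBase:         # stands in for the local replacement class
    pass


original_isinstance = builtins.isinstance


def patched_isinstance(obj, classinfo):
    # Checks against the upstream class are answered using the local class.
    if classinfo is UpstreamBase:
        return original_isinstance(obj, LocalBase)
    return original_isinstance(obj, classinfo)


builtins.isinstance = patched_isinstance
try:
    print(isinstance(LocalBase(), UpstreamBase))   # True while the patch is active
finally:
    builtins.isinstance = original_isinstance      # restore, mirroring reverse=True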
@@ -263,91 +262,64 @@
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
     Convert sglang arguments to vllm arguments.
     """
-[45 removed lines not rendered in this view]
-            kwargs["e_score_correction_bias"] = correction_bias
-            return original_apply(**kwargs)
-
-        setattr(class_obj, "apply", new_apply)
-    except (ImportError, AttributeError):
-        return
+    original_apply = class_obj.apply
+    sig = inspect.signature(original_apply)
+    param_names = list(sig.parameters.keys())
+    has_correction_bias = "e_score_correction_bias" in param_names
+
+    def new_apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ):
+        assert activation == "silu"
+        assert inplace and not no_combine
+
+        kwargs = {
+            "self": self,
+            "layer": layer,
+            "x": x,
+            "router_logits": router_logits,
+            "top_k": top_k,
+            "renormalize": renormalize,
+            "use_grouped_topk": use_grouped_topk,
+            "topk_group": topk_group,
+            "num_expert_group": num_expert_group,
+            "custom_routing_function": custom_routing_function,
+        }
+        if correction_bias is not None:
+            if not has_correction_bias:
+                raise ValueError(
+                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                )
+            kwargs["e_score_correction_bias"] = correction_bias
+        return original_apply(**kwargs)
+
+    setattr(class_obj, "apply", new_apply)


 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-[2 removed lines not rendered in this view]
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

-[4 removed lines not rendered in this view]
-            CompressedTensorsWNA16MoEMethod,
-        )
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinMoEMethod,
-        )
-
-        setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-        setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
-
-        monkey_patch_moe_apply(AWQMoEMethod)
-        monkey_patch_moe_apply(GPTQMarlinMoEMethod)
-        monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-        monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-    except ImportError:
-        return
+    monkey_patch_moe_apply(AWQMoEMethod)
+    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


 # Only apply monkey patches if vllm is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()
-
-
-__all__ = [
-    "get_quantization_config",
-    "QUANTIZATION_METHODS",
-]
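Note: the quantization __init__.py hunks split the registry into a base dict and a vllm-dependent dict, merge them into QUANTIZATION_METHODS, and make get_quantization_config fail early when a vllm-backed method is requested but vllm is not installed. A minimal sketch of that lookup pattern; the config classes below are placeholders standing in for the real ones.

from typing import Dict, Type


class QuantizationConfig:            # placeholder base class
    pass


class W8A8Int8Config(QuantizationConfig):
    pass


class AWQMarlinConfig(QuantizationConfig):
    pass


VLLM_AVAILABLE = False  # would normally be set by a try/except import of vllm

BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "w8a8_int8": W8A8Int8Config,
}
VLLM_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "awq_marlin": AWQMarlinConfig,
}
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
        raise ValueError(f"{quantization} requires vllm, which is not installed")
    return QUANTIZATION_METHODS[quantization]


print(get_quantization_config("w8a8_int8"))  # resolvable without vllm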