sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
CHANGED
@@ -12,22 +12,21 @@
 # limitations under the License.
 # ==============================================================================
 
+from __future__ import annotations
+
 import math
-from typing import Callable, Optional
+from typing import TYPE_CHECKING, Callable, NamedTuple, Optional
 
 import torch
 import torch.nn.functional as F
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.eplb import expert_location_dispatch
-from sglang.srt.eplb.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -56,6 +55,168 @@ if _use_aiter:
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 
+if _is_npu:
+    import torch_npu
+
+
+class TopKOutput(NamedTuple):
+    topk_weights: torch.Tensor
+    topk_ids: torch.Tensor
+    router_logits: torch.Tensor
+
+
+class TopK(CustomOp):
+
+    # TODO(ch-wan): support triton_kernels
+
+    def __init__(
+        self,
+        top_k: int,
+        *,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        renormalize: bool = True,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        correction_bias: Optional[torch.Tensor] = None,
+        routed_scaling_factor: Optional[float] = None,
+    ):
+        # NOTE: scoring_func is not used for now, but we keep it for future use
+        # see https://github.com/sgl-project/sglang/pull/4505 for more details
+        super().__init__()
+        if use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.top_k = top_k
+        self.use_grouped_topk = use_grouped_topk
+        self.renormalize = renormalize
+        self.topk_group = topk_group
+        self.num_expert_group = num_expert_group
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.custom_routing_function = custom_routing_function
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        torch_native = True
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            torch_native=torch_native,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        torch_native = False
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            torch_native=torch_native,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_cpu(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_npu(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        global_num_experts = router_logits.shape[-1]
+
+        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+        if global_num_experts == 256:
+            return torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=self.top_k,
+                bias=self.correction_bias,
+                k_group=self.topk_group,
+                group_count=self.num_expert_group,
+                group_select_mode=1,
+                renorm=0,
+                norm_type=1,
+                routed_scaling_factor=1,
+                eps=float(1e-20),
+            )
+        else:
+            torch_native = True
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
+
 
 def fused_topk_torch_native(
     hidden_states: torch.Tensor,
@@ -97,6 +258,19 @@ def fused_topk_cpu(
     return topk_weights, topk_ids
 
 
+def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
+    if not need_apply:
+        return inputs, topk_weights
+
+    # TODO: fuse below processing in fused_experts_cpu kernel
+    inputs = inputs * topk_weights.to(inputs.dtype)
+    topk_weights = torch.ones_like(
+        topk_weights, dtype=torch.float32
+    )  # clear topk_weights as already applied
+
+    return inputs, topk_weights
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -325,7 +499,7 @@ def biased_grouped_topk_gpu(
         and is_power_of_two(correction_bias.shape[0])
     ):
         topk_weights, topk_ids = moe_fused_gate(
-            gating_output,
+            gating_output.to(dtype=torch.float32),
            correction_bias,
            num_expert_group,
            topk_group,
@@ -350,7 +524,7 @@ def biased_grouped_topk_gpu(
     topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
     topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
     aiter_biased_grouped_topk(
-        gating_output,
+        gating_output.to(dtype=torch.float32),
         correction_bias,
         topk_weights,
         topk_ids,
@@ -427,8 +601,9 @@ def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
+    *,
+    use_grouped_topk: bool = False,
+    renormalize: bool = False,
     topk_group: Optional[int] = None,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -438,7 +613,7 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-):
+) -> TopKOutput:
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
             router_logits=router_logits,
@@ -513,4 +688,4 @@ def select_experts(
 
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
 
-    return topk_weights, topk_ids
+    return TopKOutput(topk_weights, topk_ids, router_logits)
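For context, a minimal self-contained sketch of the routing flow this change introduces: routing results are now bundled into a `TopKOutput` named tuple instead of a bare `(topk_weights, topk_ids)` pair. The `naive_topk` helper below is illustrative only (a plain softmax top-k); the real `TopK` module in sglang additionally handles grouped top-k, correction bias, fused shared experts, and hardware-specific backends.

```python
from typing import NamedTuple

import torch
import torch.nn.functional as F


class TopKOutput(NamedTuple):
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor


def naive_topk(router_logits: torch.Tensor, top_k: int, renormalize: bool = True) -> TopKOutput:
    # Softmax over the expert dimension, then keep the top-k probabilities per token.
    probs = F.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return TopKOutput(topk_weights, topk_ids, router_logits)


logits = torch.randn(4, 8)  # 4 tokens, 8 experts
out = naive_topk(logits, top_k=2)
print(out.topk_ids.shape, out.topk_weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```

Bundling the three tensors keeps the router logits available to downstream MoE kernels without threading extra arguments through every `apply()` signature, which is what the quantization changes below rely on.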
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -1,18 +1,14 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+from __future__ import annotations
+
 import builtins
 import inspect
-import re
-from copy import deepcopy
-from typing import Callable, Dict, Optional, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
 
 import torch
 
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq_marlin import (
-        AWQMarlinConfig,
-        AWQMoEMethod,
-    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
     from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
         CompressedTensorsW8A8Fp8MoEMethod,
@@ -22,10 +18,6 @@ try:
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        GPTQMarlinLinearMethod,
-    )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -42,15 +34,14 @@ except ImportError:
         def override_quantization_method(self, *args, **kwargs):
             return None
 
-    AQLMConfig =
-
-    ) =
-
-    ) =
+    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
+        ExpertsInt8Config
+    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
+        Int8TpuConfig
+    ) = DummyConfig
 
 
-from sglang.srt.layers.
-from sglang.srt.layers.quantization.awq import AWQConfig
+from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
@@ -59,7 +50,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import (
     GPTQConfig,
+    GPTQLinearMethod,
     GPTQMarlinConfig,
+    GPTQMarlinLinearMethod,
     GPTQMarlinMoEMethod,
 )
 from sglang.srt.layers.quantization.modelopt_quant import (
@@ -67,11 +60,16 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
+from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
@@ -84,6 +82,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
+    "petit_nvfp4": PetitNvFp4Config,
 }
 
 # VLLM-dependent quantization methods
@@ -122,99 +121,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-# Match dynamic rules with module name (prefix) and override quantize
-# config if module (prefix) matches a rule
-def override_config(config: QuantizationConfig, prefix: str):
-    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
-    if isinstance(weight_bits, int):
-        config.weight_bits = weight_bits
-    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
-    if isinstance(group_size, int):
-        config.group_size = group_size
-    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
-    if isinstance(desc_act, bool):
-        config.desc_act = desc_act
-
-    config.pack_factor = 32 // config.weight_bits  # packed into int32
-    if config.get_name() == "gptq_marlin":
-        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
-        if isinstance(is_sym, bool):
-            config.is_sym = is_sym
-
-        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
-            raise ValueError(
-                "Unsupported quantization config: "
-                f"bits={config.weight_bits}, sym={config.is_sym}"
-            )
-
-        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
-    elif config.get_name() == "gptq":
-        if config.weight_bits not in [2, 3, 4, 8]:
-            raise ValueError(
-                "Currently, only 2/3/4/8-bit weight quantization is "
-                f"supported for GPTQ, but got {config.weight_bits} bits."
-            )
-
-
-def get_dynamic_override(
-    config: QuantizationConfig,
-    layer_name: str,
-    key: Optional[str] = None,
-    default_value: Union[int, bool, None] = None,
-) -> Union[Dict, int, bool, None]:
-    for pattern, pattern_dict in config.dynamic.items():
-        # Negative match: matched modules are excluded from quantized init
-        if pattern.startswith("-:"):
-            if re.match(pattern.removeprefix("-:"), layer_name):
-                return False
-        # Positive match: matched modules have quant properties overrides
-        # base quant config
-        elif re.match(pattern.removeprefix("+:"), layer_name):
-            if key is None:
-                return pattern_dict
-            else:
-                return pattern_dict.get(key, default_value)
-    return default_value
-
-
-def get_linear_quant_method(
-    config: QuantizationConfig,
-    layer: torch.nn.Module,
-    prefix: str,
-    linear_method_cls: type,
-):
-    # Move import here to avoid circular import. This is only used in monkey patching
-    # of vllm's QuantizationConfig.
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
-    cloned_config = deepcopy(config)
-    parallel_lm_head_quantized = (
-        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
-    )
-
-    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
-        # False = skip module, None = no override, else = Positive match
-        if (
-            get_dynamic_override(  # noqa: E712
-                cloned_config, layer_name=prefix  # noqa: E712
-            )
-            == False
-        ):  # noqa: E712
-            if parallel_lm_head_quantized:
-                return UnquantizedEmbeddingMethod()
-            return UnquantizedLinearMethod()
-
-        if prefix:
-            # Dynamic per module/layer rules may override base config
-            override_config(cloned_config, prefix=prefix)
-
-        return linear_method_cls(cloned_config)
-    return None
-
-
 def gptq_get_quant_method(self, layer, prefix):
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
@@ -285,15 +191,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -307,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
             "self": self,
             "layer": layer,
             "x": x,
-            "router_logits": router_logits,
-            "top_k": top_k,
-            "renormalize": renormalize,
-            "use_grouped_topk": use_grouped_topk,
-            "topk_group": topk_group,
-            "num_expert_group": num_expert_group,
-            "custom_routing_function": custom_routing_function,
+            "topk_output": topk_output,
         }
-        if correction_bias is not None:
-            if not has_correction_bias:
-                raise ValueError(
-                    "Please increase the version of your vllm. Try `pip install vllm==0.9.0.1`"
-                )
-            kwargs["e_score_correction_bias"] = correction_bias
         return original_apply(**kwargs)
 
     setattr(class_obj, "apply", new_apply)
@@ -331,7 +218,6 @@ def monkey_patch_quant_configs():
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-    monkey_patch_moe_apply(AWQMoEMethod)
    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
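The removed per-argument routing parameters in `monkey_patch_moe_apply` are replaced by the single `topk_output` argument. The sketch below shows the general signature-adaptation idea behind that kind of monkey patch, using stand-in classes rather than the real sglang/vllm ones; `LegacyMoEMethod`, its argument names, and the toy math are all hypothetical.

```python
from typing import NamedTuple

import torch


class TopKOutput(NamedTuple):
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor


class LegacyMoEMethod:
    """Hypothetical method whose apply() still expects unbundled routing arguments."""

    def apply(self, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor):
        # Stand-in computation: scale each token by its top-1 routing weight.
        return x * topk_weights[:, :1]


def monkey_patch_apply(cls) -> None:
    original_apply = cls.apply

    def new_apply(self, x: torch.Tensor, topk_output: TopKOutput):
        # Unpack the bundled routing result into the legacy keyword arguments.
        return original_apply(
            self, x, topk_weights=topk_output.topk_weights, topk_ids=topk_output.topk_ids
        )

    cls.apply = new_apply


monkey_patch_apply(LegacyMoEMethod)
routing = TopKOutput(
    topk_weights=torch.tensor([[0.6, 0.4], [0.7, 0.3]]),
    topk_ids=torch.tensor([[1, 3], [0, 2]]),
    router_logits=torch.randn(2, 4),
)
y = LegacyMoEMethod().apply(torch.ones(2, 8), routing)
print(y.shape)  # torch.Size([2, 8])
```

The design point is that callers construct one routing object once (via the new `TopK` module) and every quantized or unquantized MoE backend consumes it through the same `apply(layer, x, topk_output, ...)` shape, instead of each backend re-declaring a dozen routing keywords.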