sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py ADDED

```diff
@@ -0,0 +1,423 @@
+# Adapted from qwen2_moe.py
+
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
+from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.utils import add_prefix
+
+Qwen3MoeConfig = None
+
+
+class Qwen3MoeSparseMoeBlock(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}."
+            )
+
+        self.experts = FusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class Qwen3MoeAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        attention_bias: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= self.tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % self.tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def _apply_qk_norm(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        q_by_head = q.reshape(-1, self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.reshape(-1, self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        return q, k
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self._apply_qk_norm(q, k)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Qwen3MoeDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        rms_norm_eps = config.rms_norm_eps
+        attention_bias = config.attention_bias
+        self.self_attn = Qwen3MoeAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            head_dim=head_dim,
+            rms_norm_eps=rms_norm_eps,
+            attention_bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
+        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+        # `mlp_only_layers` in the config.
+        mlp_only_layers = (
+            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        )
+        if (layer_id not in mlp_only_layers) and (
+            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeSparseMoeBlock(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        else:
+            self.mlp = Qwen3MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class Qwen3MoeModel(Qwen2MoeModel):
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            config=config,
+            quant_config=quant_config,
+            prefix=prefix,
+            decoder_layer_type=Qwen3MoeDecoderLayer,
+        )
+
+
+class Qwen3MoeForCausalLM(nn.Module):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen3MoeModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen3MoeForCausalLM
```
sglang/srt/models/stablelm.py CHANGED

sglang/srt/models/xverse.py CHANGED

sglang/srt/models/xverse_moe.py CHANGED

sglang/srt/openai_api/adapter.py CHANGED
```diff
@@ -983,6 +983,8 @@ def v1_chat_generate_request(
                 ):
                     encoded = encoded[1:]
                 prompt_ids += encoded
+            if tokenizer_manager.model_config.is_multimodal:
+                prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
             stop = request.stop
             image_data = None
             audio_data = None
@@ -993,7 +995,8 @@ def v1_chat_generate_request(
             image_data = conv.image_data
             audio_data = conv.audio_data
             modalities = conv.modalities
-            stop = conv.stop_str or []
+            stop = conv.stop_str or [] if not request.ignore_eos else []
+
             if request.stop:
                 if isinstance(request.stop, str):
                     stop.append(request.stop)
```
sglang/srt/patch_torch.py CHANGED

```diff
@@ -14,6 +14,7 @@
 from typing import Callable, Union
 
 import torch
+from packaging import version
 from torch.multiprocessing import reductions
 
 
@@ -69,3 +70,13 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
 
 def _modify_tuple(t, index: int, modifier: Callable):
     return *t[:index], modifier(t[index]), *t[index + 1 :]
+
+
+def monkey_patch_torch_compile():
+    if version.parse(torch.__version__) < version.parse("2.8.0"):
+        # These things are cacheable by torch.compile. torch.compile just doesn't know it.
+        # This was fixed in PyTorch 2.8, but until then, we monkey patch.
+        import torch._higher_order_ops.auto_functionalize as af
+
+        af.auto_functionalized_v2._cacheable = True
+        af.auto_functionalized._cacheable = True
```
sglang/srt/reasoning_parser.py CHANGED

sglang/srt/sampling/sampling_batch_info.py CHANGED

```diff
@@ -10,12 +10,11 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 
-logger = logging.getLogger(__name__)
-
-
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class SamplingBatchInfo:
```
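
The hunk above only moves the module-level logger below the `TYPE_CHECKING` block. For reference, that guard is the standard pattern for type-only imports: the import runs under static analysis but never at runtime, which is commonly used to break circular imports. A minimal sketch of the pattern (the `describe` function is illustrative, not sglang code):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Executed by type checkers only, never at runtime, so importing
    # ScheduleBatch here cannot create a runtime circular import.
    from sglang.srt.managers.schedule_batch import ScheduleBatch


def describe(batch: "ScheduleBatch") -> str:
    # The quoted annotation is resolved lazily; ScheduleBatch does not
    # need to be importable when this module is loaded.
    return f"batch of {len(batch.reqs)} requests"
```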
|