sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py ADDED
@@ -0,0 +1,423 @@
+# Adapted from qwen2_moe.py
+
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
+from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.utils import add_prefix
+
+Qwen3MoeConfig = None
+
+
+class Qwen3MoeSparseMoeBlock(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}."
+            )
+
+        self.experts = FusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class Qwen3MoeAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        attention_bias: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= self.tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % self.tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def _apply_qk_norm(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        q_by_head = q.reshape(-1, self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.reshape(-1, self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        return q, k
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self._apply_qk_norm(q, k)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Qwen3MoeDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        rms_norm_eps = config.rms_norm_eps
+        attention_bias = config.attention_bias
+        self.self_attn = Qwen3MoeAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            head_dim=head_dim,
+            rms_norm_eps=rms_norm_eps,
+            attention_bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
+        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+        # `mlp_only_layers` in the config.
+        mlp_only_layers = (
+            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        )
+        if (layer_id not in mlp_only_layers) and (
+            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeSparseMoeBlock(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        else:
+            self.mlp = Qwen3MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class Qwen3MoeModel(Qwen2MoeModel):
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            config=config,
+            quant_config=quant_config,
+            prefix=prefix,
+            decoder_layer_type=Qwen3MoeDecoderLayer,
+        )
+
+
+class Qwen3MoeForCausalLM(nn.Module):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen3MoeModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen3MoeForCausalLM
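The new model is exported via `EntryClass = Qwen3MoeForCausalLM`, so a checkpoint whose `config.json` declares that architecture resolves to this file at load time. Below is a minimal, hedged sketch of running such a checkpoint through sglang's offline engine; the checkpoint path is a placeholder and the sampling parameters are illustrative only.

```python
# Sketch: load a Qwen3-MoE checkpoint with sglang's offline Engine and
# generate a short completion. The model path is a placeholder; any
# HuggingFace-format checkpoint whose architecture is Qwen3MoeForCausalLM
# should be routed to the EntryClass defined in the new file above.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="/path/to/qwen3-moe-checkpoint")
    outputs = llm.generate(
        ["The capital of France is"],
        {"temperature": 0.0, "max_new_tokens": 16},
    )
    print(outputs[0]["text"])
    llm.shutdown()
```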
sglang/srt/openai_api/adapter.py CHANGED
@@ -938,6 +938,35 @@ def v1_chat_generate_request(
 
         if chat_template_name is None:
             openai_compatible_messages = []
+            if (
+                tools
+                and tokenizer_manager.server_args.tool_call_parser == "deepseekv3"
+            ):
+                # add function call prompt to deepseekv3
+                openai_compatible_messages.append(
+                    {
+                        "role": "system",
+                        "content": """You are a helpful Assistant.
+## Tools
+### Function
+You have the following functions available:
+"""
+                        + "".join(
+                            [
+                                f"""
+- `{tool['name']}`:
+```json
+{json.dumps(tool)}
+```
+"""
+                                for tool in tools
+                            ]
+                        ),
+                    }
+                )
+                # TODO fix the compatible issues with xgrammar
+                strict_tag = None
+
             for message in request.messages:
                 if isinstance(message.content, str):
                     openai_compatible_messages.append(
@@ -950,9 +979,16 @@ def v1_chat_generate_request(
                             openai_compatible_messages.append(
                                 {"role": message.role, "content": content["text"]}
                             )
-            if openai_compatible_messages[-1]["role"] == "assistant":
-                assistant_prefix = openai_compatible_messages[-1]["content"]
-                openai_compatible_messages = openai_compatible_messages[:-1]
+            if (
+                openai_compatible_messages
+                and openai_compatible_messages[-1]["role"] == "assistant"
+            ):
+                if request.continue_final_message:
+                    # Remove the final assistant message so its content can be continued.
+                    assistant_prefix = openai_compatible_messages[-1]["content"]
+                    openai_compatible_messages = openai_compatible_messages[:-1]
+                else:
+                    assistant_prefix = None
             else:
                 assistant_prefix = None
 
@@ -991,7 +1027,33 @@ def v1_chat_generate_request(
             modalities = []
         else:
             conv = generate_chat_conv(request, chat_template_name)
-            prompt = conv.get_prompt()
+            # If we should continue the final assistant message, adjust the conversation.
+            if (
+                request.continue_final_message
+                and request.messages
+                and request.messages[-1].role == "assistant"
+            ):
+                # Remove the auto-added blank assistant turn, if present.
+                if conv.messages and conv.messages[-1][1] is None:
+                    conv.messages.pop()
+                # Rebuild the prompt from the conversation.
+                prompt = conv.get_prompt()
+                # Strip any trailing stop tokens or separators that indicate end-of-assistant.
+                if isinstance(conv.stop_str, list):
+                    for stop_token in conv.stop_str:
+                        if prompt.endswith(stop_token):
+                            prompt = prompt[: -len(stop_token)]
+                elif isinstance(conv.stop_str, str) and prompt.endswith(
+                    conv.stop_str
+                ):
+                    prompt = prompt[: -len(conv.stop_str)]
+                if conv.sep and prompt.endswith(conv.sep):
+                    prompt = prompt[: -len(conv.sep)]
+                if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
+                    prompt = prompt[: -len(conv.sep2)]
+            else:
+                prompt = conv.get_prompt()
+
             image_data = conv.image_data
             audio_data = conv.audio_data
             modalities = conv.modalities
@@ -1003,6 +1065,7 @@ def v1_chat_generate_request(
             else:
                 stop.extend(request.stop)
             prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
+
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
             prompt_ids = request.messages
@@ -1042,6 +1105,8 @@ def v1_chat_generate_request(
             sampling_params["json_schema"] = convert_json_schema_to_str(
                 request.response_format.json_schema.schema_
            )
+        elif request.response_format and request.response_format.type == "json_object":
+            sampling_params["json_schema"] = '{"type": "object"}'
         elif (
             request.response_format and request.response_format.type == "structural_tag"
         ):
@@ -1109,6 +1174,8 @@ def v1_chat_generate_request(
         rid=request_ids,
         modalities=modalities_list,
         lora_path=lora_paths,
+        bootstrap_host=all_requests[0].bootstrap_host,
+        bootstrap_room=all_requests[0].bootstrap_room,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
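The new `json_object` branch above maps OpenAI's `response_format={"type": "json_object"}` onto sglang's existing JSON-schema constrained decoding by supplying the permissive schema `{"type": "object"}`. A hedged client-side sketch follows; the endpoint, API key, and model name are placeholders, and it assumes the `openai` Python client is installed.

```python
# Sketch: request JSON-mode output from a locally running sglang
# OpenAI-compatible server. Base URL, API key, and model name are
# placeholders, not values taken from the diff.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
completion = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Describe Paris as a JSON object."}],
    # Routed by the adapter change above to json_schema='{"type": "object"}'.
    response_format={"type": "json_object"},
)
print(completion.choices[0].message.content)
```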
sglang/srt/openai_api/protocol.py CHANGED
@@ -252,7 +252,7 @@ ChatCompletionMessageContentPart = Union[
 
 class ChatCompletionMessageGenericParam(BaseModel):
     role: Literal["system", "assistant", "tool"]
-    content: Union[str, List[ChatCompletionMessageContentTextPart]]
+    content: Union[str, List[ChatCompletionMessageContentTextPart], None]
 
 
 class ChatCompletionMessageUserParam(BaseModel):
@@ -355,12 +355,17 @@ class ChatCompletionRequest(BaseModel):
     stop_token_ids: Optional[List[int]] = None
     no_stop_trim: bool = False
     ignore_eos: bool = False
+    continue_final_message: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
     separate_reasoning: bool = True
     stream_reasoning: bool = True
 
+    # For PD disaggregation
+    bootstrap_host: Optional[str] = None
+    bootstrap_room: Optional[int] = None
+
 
 class FunctionResponse(BaseModel):
     """Function response."""
sglang/srt/reasoning_parser.py CHANGED

sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -10,12 +10,11 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 
-logger = logging.getLogger(__name__)
-
-
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class SamplingBatchInfo: