sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4.py
ADDED
@@ -0,0 +1,312 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Modeling from:
+# ./llama.py and
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py
+"""Inference-only GLM4 model compatible with THUDM weights."""
+
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import Glm4Config
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaMLP as Glm4MLP
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class Glm4Attention(nn.Module):
+    def __init__(
+        self,
+        config,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = getattr(config, "rope_theta", 1000000)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=self.rope_scaling,
+            partial_rotary_factor=partial_rotary_factor,
+            is_neox_style=False,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            forward_batch,
+        )
+        attn_output, _ = self.o_proj(context_layer)
+        return attn_output
+
+
+class Glm4DecoderLayer(nn.Module):
+    """A single transformer layer.
+
+    Transformer layer takes input with size [s, b, h] and returns an
+    output of the same size.
+    """
+
+    def __init__(
+        self,
+        config,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        # Self attention.
+        self.self_attn = Glm4Attention(
+            config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
+        )
+
+        # MLP
+        self.mlp = Glm4MLP(
+            config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_self_attn_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = self.post_self_attn_layernorm(hidden_states)
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_mlp_layernorm(hidden_states)
+
+        return hidden_states, residual
+
+
+class Glm4Model(nn.Module):
+    def __init__(
+        self,
+        config: Glm4Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Glm4DecoderLayer(
+                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+            ),
+            prefix="model.layers",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        return hidden_states
+
+
+class Glm4ForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: Glm4Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config: Glm4Config = config
+        self.quant_config = quant_config
+        self.model = Glm4Model(config, quant_config, add_prefix("model", prefix))
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix="lm_head",
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, weight_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    raise KeyError(f"Parameter '{name}' not found in model.")
+
+
+EntryClass = [Glm4ForCausalLM]
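The new glm4.py registers Glm4ForCausalLM through EntryClass, so GLM-4 checkpoints can be served like any other architecture sglang supports. A minimal offline-engine sketch follows; the checkpoint id and sampling parameters are assumptions for illustration, not part of this diff.

# Sketch: run the newly added GLM-4 model through sglang's offline Engine API.
# The model id below is an assumed GLM-4 checkpoint, not taken from this diff.
import sglang as sgl

llm = sgl.Engine(model_path="THUDM/GLM-4-9B-0414")
outputs = llm.generate(
    ["Give me a one-sentence summary of GLM-4."],
    {"temperature": 0.0, "max_new_tokens": 64},
)
print(outputs[0]["text"])
llm.shutdown()
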
sglang/srt/models/mimo_mtp.py
CHANGED
@@ -7,33 +7,17 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig

-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.
-from sglang.srt.models.qwen2 import (
-    Qwen2Attention,
-    Qwen2DecoderLayer,
-    Qwen2MLP,
-    Qwen2Model,
-)
-from sglang.srt.utils import add_prefix
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer


 class MiMoMultiTokenPredictorLayer(nn.Module):
sglang/srt/reasoning_parser.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple, Type


 class StreamingParseResult:
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-
-
+        in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
+
+        if not in_reasoning:
+            return StreamingParseResult(normal_text=text)
+
+        # The text is considered to be in a reasoning block.
+        processed_text = text.replace(self.think_start_token, "").strip()
+
+        if self.think_end_token not in processed_text:
             # Assume reasoning was truncated before `</think>` token
-            return StreamingParseResult(reasoning_text=
+            return StreamingParseResult(reasoning_text=processed_text)

         # Extract reasoning content
-        splits =
+        splits = processed_text.split(self.think_end_token, maxsplit=1)
         reasoning_text = splits[0]
-
+        normal_text = splits[1].strip()

-        return StreamingParseResult(
+        return StreamingParseResult(
+            normal_text=normal_text, reasoning_text=reasoning_text
+        )

     def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
         """
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
         if not self.stripped_think_start and self.think_start_token in current_text:
             current_text = current_text.replace(self.think_start_token, "")
             self.stripped_think_start = True
+            self._in_reasoning = True

         # Handle end of reasoning block
         if self._in_reasoning and self.think_end_token in current_text:
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
     """

     def __init__(self, stream_reasoning: bool = True):
-        # Qwen3
+        # Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
         super().__init__(
             "<think>",
             "</think>",
-            force_reasoning=
+            force_reasoning=False,
             stream_reasoning=stream_reasoning,
         )
@@ -151,12 +161,12 @@ class ReasoningParser:
             If True, streams reasoning content as it arrives.
     """

-    DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
+    DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
         "deepseek-r1": DeepSeekR1Detector,
         "qwen3": Qwen3Detector,
     }

-    def __init__(self, model_type: str = None, stream_reasoning: bool = True):
+    def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
         if not model_type:
             raise ValueError("Model type must be specified")

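The reworked detect_and_parse above now passes text straight through as normal output unless the detector is already inside a reasoning block or the text starts with the think-start token, and it treats a block with no closing token as reasoning-only. A small standalone sketch of that behavior (local helper names only, not the sglang API):

# Standalone sketch mirroring the one-shot parsing logic added above;
# `split_reasoning` is a local helper, not part of sglang.
def split_reasoning(text, start="<think>", end="</think>", force=False):
    in_reasoning = force or text.startswith(start)
    if not in_reasoning:
        return "", text  # plain text passes through untouched
    body = text.replace(start, "").strip()
    if end not in body:
        return body, ""  # truncated block: everything is reasoning
    reasoning, normal = body.split(end, maxsplit=1)
    return reasoning, normal.strip()

print(split_reasoning("<think>weigh the options</think>Paris."))  # ('weigh the options', 'Paris.')
print(split_reasoning("Just an answer."))                         # ('', 'Just an answer.')
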
sglang/srt/server_args.py
CHANGED
@@ -152,6 +152,7 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
+    enable_flashinfer_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -234,6 +235,10 @@ class ServerArgs:
     num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
     pdlb_url: Optional[str] = None

+    # For model weight update
+    custom_weight_loader: Optional[List[str]] = None
+    weight_loader_disable_mmap: bool = False
+
     def __post_init__(self):
         # Expert parallelism
         if self.enable_ep_moe:
@@ -241,7 +246,15 @@
             logger.warning(
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
-
+        if self.enable_flashinfer_moe:
+            assert (
+                self.quantization == "modelopt_fp4"
+            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
+            os.environ["TRTLLM_ENABLE_PDL"] = "1"
+            self.disable_shared_experts_fusion = True
+            logger.warning(
+                f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
+            )
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -384,7 +397,6 @@
             ), "Please enable dp attention when setting enable_dp_attention. "

         # DeepEP MoE
-        self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
             if self.deepep_mode == "auto":
                 assert (
@@ -394,9 +406,6 @@
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
             self.ep_size = self.tp_size
-            self.enable_sp_layernorm = (
-                self.dp_size < self.tp_size if self.enable_dp_attention else True
-            )
             logger.warning(
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
@@ -538,6 +547,9 @@
             "1" if self.disable_outlines_disk_cache else "0"
         )

+        if self.custom_weight_loader is None:
+            self.custom_weight_loader = []
+
     def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
         larger_tp = max(decode_tp, prefill_tp)
         smaller_tp = min(decode_tp, prefill_tp)
@@ -1160,6 +1172,11 @@
             action="store_true",
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-moe",
+            action="store_true",
+            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
@@ -1576,6 +1593,18 @@
             default=None,
             help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
         )
+        parser.add_argument(
+            "--custom-weight-loader",
+            type=str,
+            nargs="*",
+            default=None,
+            help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
+        )
+        parser.add_argument(
+            "--weight-loader-disable-mmap",
+            action="store_true",
+            help="Disable mmap while loading weight using safetensors.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -1700,9 +1729,8 @@
             dist_init_host, dist_init_port = dist_init_addr
             port_base = int(dist_init_port) + 1
             if dp_rank is None:
-
-
-                ) # TokenizerManager to DataParallelController
+                # TokenizerManager to DataParallelController
+                scheduler_input_port = port_base + 3
             else:
                 scheduler_input_port = port_base + 3 + 1 + dp_rank

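Taken together, the server_args.py changes add three user-facing knobs: --enable-flashinfer-moe (which requires modelopt_fp4 quantization and disables shared-expert fusion), --custom-weight-loader, and --weight-loader-disable-mmap. A hedged sketch of setting them programmatically; the model path is a placeholder and my_package.weight_load_func is the hypothetical import path reused from the help text above.

# Sketch only: construct a ServerArgs with the new fields from this diff.
# Equivalent CLI flags: --enable-flashinfer-moe --enable-ep-moe
#   --custom-weight-loader my_package.weight_load_func --weight-loader-disable-mmap
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/fp4-quantized-model",  # placeholder path
    quantization="modelopt_fp4",                # required by the new assert when FlashInfer MoE is on
    enable_flashinfer_moe=True,
    enable_ep_moe=True,
    custom_weight_loader=["my_package.weight_load_func"],  # hypothetical loader path
    weight_loader_disable_mmap=True,
)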