sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py (new file)
@@ -0,0 +1,423 @@
+ # Adapted from qwen2_moe.py
+
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
+
+ from functools import partial
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     split_tensor_along_last_dim,
+     tensor_model_parallel_all_gather,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
+ from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+ from sglang.srt.utils import add_prefix
+
+ Qwen3MoeConfig = None
+
+
+ class Qwen3MoeSparseMoeBlock(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3MoeConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+
+         if self.tp_size > config.num_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {config.num_experts}."
+             )
+
+         self.experts = FusedMoE(
+             num_experts=config.num_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.moe_intermediate_size,
+             reduce_results=False,
+             renormalize=config.norm_topk_prob,
+             quant_config=quant_config,
+             prefix=add_prefix("experts", prefix),
+         )
+
+         self.gate = ReplicatedLinear(
+             config.hidden_size,
+             config.num_experts,
+             bias=False,
+             quant_config=None,
+             prefix=add_prefix("gate", prefix),
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         num_tokens, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+
+         # router_logits: (num_tokens, n_experts)
+         router_logits, _ = self.gate(hidden_states)
+         final_hidden_states = self.experts(
+             hidden_states=hidden_states, router_logits=router_logits
+         )
+         if self.tp_size > 1:
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+         return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+ class Qwen3MoeAttention(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         head_dim: Optional[int] = None,
+         rms_norm_eps: float = 1e-06,
+         attention_bias: bool = False,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % self.tp_size == 0
+         self.num_heads = self.total_num_heads // self.tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= self.tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % self.tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+         self.head_dim = head_dim or hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+         self.tp_rank = get_tensor_model_parallel_rank()
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=add_prefix("attn", prefix),
+         )
+
+         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+     def _apply_qk_norm(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         q_by_head = q.reshape(-1, self.head_dim)
+         q_by_head = self.q_norm(q_by_head)
+         q = q_by_head.view(q.shape)
+         k_by_head = k.reshape(-1, self.head_dim)
+         k_by_head = self.k_norm(k_by_head)
+         k = k_by_head.view(k.shape)
+         return q, k
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self._apply_qk_norm(q, k)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Qwen3MoeDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3MoeConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         head_dim = getattr(
+             config, "head_dim", config.hidden_size // config.num_attention_heads
+         )
+         rms_norm_eps = config.rms_norm_eps
+         attention_bias = config.attention_bias
+         self.self_attn = Qwen3MoeAttention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             head_dim=head_dim,
+             rms_norm_eps=rms_norm_eps,
+             attention_bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("self_attn", prefix),
+         )
+
+         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+         # `mlp_only_layers` in the config.
+         mlp_only_layers = (
+             [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+         )
+         if (layer_id not in mlp_only_layers) and (
+             config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+         ):
+             self.mlp = Qwen3MoeSparseMoeBlock(
+                 config=config,
+                 quant_config=quant_config,
+                 prefix=add_prefix("mlp", prefix),
+             )
+         else:
+             self.mlp = Qwen3MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+                 prefix=add_prefix("mlp", prefix),
+             )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class Qwen3MoeModel(Qwen2MoeModel):
+     def __init__(
+         self,
+         config: Qwen3MoeConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(
+             config=config,
+             quant_config=quant_config,
+             prefix=prefix,
+             decoder_layer_type=Qwen3MoeDecoderLayer,
+         )
+
+
+ class Qwen3MoeForCausalLM(nn.Module):
+
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: Qwen3MoeConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Qwen3MoeModel(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("lm_head", prefix),
+         )
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.num_experts,
+         )
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 # Skip non-stacked layers and experts (experts handled below).
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if "mlp.experts" in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     if name not in params_dict:
+                         continue
+
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+
+
+ EntryClass = Qwen3MoeForCausalLM
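
The new EntryClass makes Qwen3 MoE checkpoints loadable through the usual sglang entry points; routing happens via the architectures field of the checkpoint's HF config. A minimal offline-engine sketch, not part of the diff (the checkpoint id, tp_size, and sampling parameters are illustrative placeholders):

# Sketch: loading a Qwen3-MoE checkpoint with sglang's offline Engine.
# "Qwen/Qwen3-MoE-example" is a placeholder HF model id; any checkpoint whose
# config lists Qwen3MoeForCausalLM in `architectures` should route to the new class.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="Qwen/Qwen3-MoE-example", tp_size=2)
    outputs = llm.generate(
        ["The capital of France is"],
        {"temperature": 0, "max_new_tokens": 16},
    )
    print(outputs[0]["text"])
    llm.shutdown()
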
sglang/srt/openai_api/adapter.py
@@ -938,6 +938,35 @@ def v1_chat_generate_request(
 
  if chat_template_name is None:
  openai_compatible_messages = []
+ if (
+ tools
+ and tokenizer_manager.server_args.tool_call_parser == "deepseekv3"
+ ):
+ # add function call prompt to deepseekv3
+ openai_compatible_messages.append(
+ {
+ "role": "system",
+ "content": """You are a helpful Assistant.
+ ## Tools
+ ### Function
+ You have the following functions available:
+ """
+ + "".join(
+ [
+ f"""
+ - `{tool['name']}`:
+ ```json
+ {json.dumps(tool)}
+ ```
+ """
+ for tool in tools
+ ]
+ ),
+ }
+ )
+ # TODO fix the compatible issues with xgrammar
+ strict_tag = None
+
  for message in request.messages:
  if isinstance(message.content, str):
  openai_compatible_messages.append(
@@ -950,9 +979,16 @@ def v1_chat_generate_request(
  openai_compatible_messages.append(
  {"role": message.role, "content": content["text"]}
  )
- if openai_compatible_messages[-1]["role"] == "assistant":
- assistant_prefix = openai_compatible_messages[-1]["content"]
- openai_compatible_messages = openai_compatible_messages[:-1]
+ if (
+ openai_compatible_messages
+ and openai_compatible_messages[-1]["role"] == "assistant"
+ ):
+ if request.continue_final_message:
+ # Remove the final assistant message so its content can be continued.
+ assistant_prefix = openai_compatible_messages[-1]["content"]
+ openai_compatible_messages = openai_compatible_messages[:-1]
+ else:
+ assistant_prefix = None
  else:
  assistant_prefix = None
 
@@ -991,7 +1027,33 @@ def v1_chat_generate_request(
  modalities = []
  else:
  conv = generate_chat_conv(request, chat_template_name)
- prompt = conv.get_prompt()
+ # If we should continue the final assistant message, adjust the conversation.
+ if (
+ request.continue_final_message
+ and request.messages
+ and request.messages[-1].role == "assistant"
+ ):
+ # Remove the auto-added blank assistant turn, if present.
+ if conv.messages and conv.messages[-1][1] is None:
+ conv.messages.pop()
+ # Rebuild the prompt from the conversation.
+ prompt = conv.get_prompt()
+ # Strip any trailing stop tokens or separators that indicate end-of-assistant.
+ if isinstance(conv.stop_str, list):
+ for stop_token in conv.stop_str:
+ if prompt.endswith(stop_token):
+ prompt = prompt[: -len(stop_token)]
+ elif isinstance(conv.stop_str, str) and prompt.endswith(
+ conv.stop_str
+ ):
+ prompt = prompt[: -len(conv.stop_str)]
+ if conv.sep and prompt.endswith(conv.sep):
+ prompt = prompt[: -len(conv.sep)]
+ if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
+ prompt = prompt[: -len(conv.sep2)]
+ else:
+ prompt = conv.get_prompt()
+
  image_data = conv.image_data
  audio_data = conv.audio_data
  modalities = conv.modalities
@@ -1003,6 +1065,7 @@ def v1_chat_generate_request(
  else:
  stop.extend(request.stop)
  prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
+
  else:
  # Use the raw prompt and stop strings if the messages is already a string.
  prompt_ids = request.messages
@@ -1042,6 +1105,8 @@ def v1_chat_generate_request(
  sampling_params["json_schema"] = convert_json_schema_to_str(
  request.response_format.json_schema.schema_
  )
+ elif request.response_format and request.response_format.type == "json_object":
+ sampling_params["json_schema"] = '{"type": "object"}'
  elif (
  request.response_format and request.response_format.type == "structural_tag"
  ):
@@ -1109,6 +1174,8 @@ def v1_chat_generate_request(
  rid=request_ids,
  modalities=modalities_list,
  lora_path=lora_paths,
+ bootstrap_host=all_requests[0].bootstrap_host,
+ bootstrap_room=all_requests[0].bootstrap_room,
  )
 
  return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
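
The new continue_final_message flag changes how a trailing assistant message is handled: instead of being treated as a completed turn, its content is kept as a prefix for the model to continue. A client-side sketch, not part of the diff, assuming an sglang server is already listening on localhost:30000; the model name is a placeholder, and the new field is passed through the openai client's extra_body:

# Sketch: continuing a partial assistant message via the OpenAI-compatible API.
import openai

client = openai.Client(base_url="http://localhost:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="default",  # placeholder; a single-model sglang server serves one model
    messages=[
        {"role": "user", "content": "Write one sentence about the sea."},
        # This trailing assistant message is continued rather than answered.
        {"role": "assistant", "content": "The sea is"},
    ],
    extra_body={"continue_final_message": True},
)
print(response.choices[0].message.content)
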
sglang/srt/openai_api/protocol.py
@@ -252,7 +252,7 @@ ChatCompletionMessageContentPart = Union[
 
  class ChatCompletionMessageGenericParam(BaseModel):
      role: Literal["system", "assistant", "tool"]
-     content: Union[str, List[ChatCompletionMessageContentTextPart]]
+     content: Union[str, List[ChatCompletionMessageContentTextPart], None]
 
 
  class ChatCompletionMessageUserParam(BaseModel):
@@ -355,12 +355,17 @@ class ChatCompletionRequest(BaseModel):
      stop_token_ids: Optional[List[int]] = None
      no_stop_trim: bool = False
      ignore_eos: bool = False
+     continue_final_message: bool = False
      skip_special_tokens: bool = True
      lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
      session_params: Optional[Dict] = None
      separate_reasoning: bool = True
      stream_reasoning: bool = True
 
+     # For PD disaggregation
+     bootstrap_host: Optional[str] = None
+     bootstrap_room: Optional[int] = None
+
 
  class FunctionResponse(BaseModel):
      """Function response."""
sglang/srt/reasoning_parser.py
@@ -1,4 +1,3 @@
- import re
  from typing import Dict, Tuple
 
 
sglang/srt/sampling/sampling_batch_info.py
@@ -10,12 +10,11 @@ import torch
  import sglang.srt.sampling.penaltylib as penaltylib
  from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 
- logger = logging.getLogger(__name__)
-
-
  if TYPE_CHECKING:
      from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+ logger = logging.getLogger(__name__)
+
 
  @dataclasses.dataclass
  class SamplingBatchInfo: