sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4.py ADDED
@@ -0,0 +1,312 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # Modeling from:
+ # ./llama.py and
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py
+ """Inference-only GLM4 model compatible with THUDM weights."""
+
+ from typing import Iterable, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import Glm4Config
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.llama import LlamaMLP as Glm4MLP
+ from sglang.srt.utils import add_prefix, make_layers
+
+
+ class Glm4Attention(nn.Module):
+     def __init__(
+         self,
+         config,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = config.num_attention_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = config.num_key_value_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = config.hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = getattr(config, "rope_theta", 1000000)
+         self.rope_scaling = getattr(config, "rope_scaling", None)
+
+         self.qkv_proj = QKVParallelLinear(
+             self.hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=config.attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             self.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=config.max_position_embeddings,
+             base=self.rope_theta,
+             rope_scaling=self.rope_scaling,
+             partial_rotary_factor=partial_rotary_factor,
+             is_neox_style=False,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             quant_config=quant_config,
+             prefix=add_prefix("attn", prefix),
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         context_layer = self.attn(
+             q,
+             k,
+             v,
+             forward_batch,
+         )
+         attn_output, _ = self.o_proj(context_layer)
+         return attn_output
+
+
+ class Glm4DecoderLayer(nn.Module):
+     """A single transformer layer.
+
+     Transformer layer takes input with size [s, b, h] and returns an
+     output of the same size.
+     """
+
+     def __init__(
+         self,
+         config,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         # Self attention.
+         self.self_attn = Glm4Attention(
+             config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
+         )
+
+         # MLP
+         self.mlp = Glm4MLP(
+             config.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+             prefix=add_prefix("mlp", prefix),
+         )
+
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_self_attn_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+         hidden_states = self.post_self_attn_layernorm(hidden_states)
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = self.post_mlp_layernorm(hidden_states)
+
+         return hidden_states, residual
+
+
+ class Glm4Model(nn.Module):
+     def __init__(
+         self,
+         config: Glm4Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("embed_tokens", prefix),
+         )
+         self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: Glm4DecoderLayer(
+                 config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+             ),
+             prefix="model.layers",
+         )
+
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for layer in self.layers:
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 forward_batch,
+                 residual,
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+
+         return hidden_states
+
+
+ class Glm4ForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config: Glm4Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.config: Glm4Config = config
+         self.quant_config = quant_config
+         self.model = Glm4Model(config, quant_config, add_prefix("model", prefix))
+         if config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=quant_config,
+                 prefix="lm_head",
+             )
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, forward_batch)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, weight_name, shard_id)
+             (".qkv_proj", ".q_proj", "q"),
+             (".qkv_proj", ".k_proj", "k"),
+             (".qkv_proj", ".v_proj", "v"),
+             (".gate_up_proj", ".gate_proj", 0),
+             (".gate_up_proj", ".up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 if name in params_dict.keys():
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+                 else:
+                     raise KeyError(f"Parameter '{name}' not found in model.")
+
+
+ EntryClass = [Glm4ForCausalLM]
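The hunk above adds Glm4ForCausalLM as a new entry class: dense attention with partial rotary embeddings plus the extra post-self-attention and post-MLP RMSNorm layers that distinguish GLM-4. A minimal usage sketch follows, assuming sglang's offline Engine API; the checkpoint id and sampling parameters are placeholders, not part of this diff.

# Sketch only (not part of the diff): exercising the new GLM-4 entry class
# through sglang's offline Engine. The checkpoint id below is a placeholder;
# any GLM-4 weights whose architecture maps to Glm4ForCausalLM should behave
# the same way.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="THUDM/GLM-4-9B-0414")
    out = llm.generate(
        ["Briefly explain rotary position embeddings."],
        {"temperature": 0.0, "max_new_tokens": 64},
    )
    print(out[0]["text"])
    llm.shutdown()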
sglang/srt/models/mimo_mtp.py CHANGED
@@ -7,33 +7,17 @@ import torch
  from torch import nn
  from transformers import PretrainedConfig

- from sglang.srt.distributed import (
-     get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size,
-     split_tensor_along_last_dim,
-     tensor_model_parallel_all_gather,
- )
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.layernorm import RMSNorm
- from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.models.mimo import MiMoForCausalLM
- from sglang.srt.models.qwen2 import (
-     Qwen2Attention,
-     Qwen2DecoderLayer,
-     Qwen2MLP,
-     Qwen2Model,
- )
- from sglang.srt.utils import add_prefix
+ from sglang.srt.models.qwen2 import Qwen2DecoderLayer


  class MiMoMultiTokenPredictorLayer(nn.Module):
sglang/srt/reasoning_parser.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Dict, Tuple
+ from typing import Dict, Optional, Tuple, Type


  class StreamingParseResult:
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
          One-time parsing: Detects and parses reasoning sections in the provided text.
          Returns both reasoning content and normal text separately.
          """
-         text = text.replace(self.think_start_token, "").strip()
-         if self.think_end_token not in text:
+         in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
+
+         if not in_reasoning:
+             return StreamingParseResult(normal_text=text)
+
+         # The text is considered to be in a reasoning block.
+         processed_text = text.replace(self.think_start_token, "").strip()
+
+         if self.think_end_token not in processed_text:
              # Assume reasoning was truncated before `</think>` token
-             return StreamingParseResult(reasoning_text=text)
+             return StreamingParseResult(reasoning_text=processed_text)

          # Extract reasoning content
-         splits = text.split(self.think_end_token, maxsplit=1)
+         splits = processed_text.split(self.think_end_token, maxsplit=1)
          reasoning_text = splits[0]
-         text = splits[1].strip()
+         normal_text = splits[1].strip()

-         return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
+         return StreamingParseResult(
+             normal_text=normal_text, reasoning_text=reasoning_text
+         )

      def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
          """
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
          if not self.stripped_think_start and self.think_start_token in current_text:
              current_text = current_text.replace(self.think_start_token, "")
              self.stripped_think_start = True
+             self._in_reasoning = True

          # Handle end of reasoning block
          if self._in_reasoning and self.think_end_token in current_text:
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
      """

      def __init__(self, stream_reasoning: bool = True):
-         # Qwen3 is assumed to be reasoning until `</think>` token
+         # Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
          super().__init__(
              "<think>",
              "</think>",
-             force_reasoning=True,
+             force_reasoning=False,
              stream_reasoning=stream_reasoning,
          )

@@ -151,12 +161,12 @@ class ReasoningParser:
              If True, streams reasoning content as it arrives.
      """

-     DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
+     DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
          "deepseek-r1": DeepSeekR1Detector,
          "qwen3": Qwen3Detector,
      }

-     def __init__(self, model_type: str = None, stream_reasoning: bool = True):
+     def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
          if not model_type:
              raise ValueError("Model type must be specified")

sglang/srt/server_args.py CHANGED
@@ -152,6 +152,7 @@ class ServerArgs:
      ep_size: int = 1
      enable_ep_moe: bool = False
      enable_deepep_moe: bool = False
+     enable_flashinfer_moe: bool = False
      deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
      ep_num_redundant_experts: int = 0
      ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -234,6 +235,10 @@ class ServerArgs:
      num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
      pdlb_url: Optional[str] = None

+     # For model weight update
+     custom_weight_loader: Optional[List[str]] = None
+     weight_loader_disable_mmap: bool = False
+
      def __post_init__(self):
          # Expert parallelism
          if self.enable_ep_moe:
@@ -241,7 +246,15 @@ class ServerArgs:
              logger.warning(
                  f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
              )
-
+         if self.enable_flashinfer_moe:
+             assert (
+                 self.quantization == "modelopt_fp4"
+             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
+             os.environ["TRTLLM_ENABLE_PDL"] = "1"
+             self.disable_shared_experts_fusion = True
+             logger.warning(
+                 f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
+             )
          # Set missing default values
          if self.tokenizer_path is None:
              self.tokenizer_path = self.model_path
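The new enable_flashinfer_moe option is only valid with modelopt_fp4 quantization and forces shared-expert fusion off, as the hunk above shows. A hedged launch sketch using the offline Engine follows (keyword arguments are assumed to mirror the ServerArgs fields); the checkpoint path and parallelism settings are placeholders, and the backend targets Blackwell-class GPUs.

# Sketch only (not part of the diff): enabling the FlashInfer CUTLASS MoE backend.
# Assumes an NVFP4 (modelopt_fp4) quantized MoE checkpoint and Blackwell-class GPUs;
# the path and parallelism settings are placeholders.
import sglang as sgl

llm = sgl.Engine(
    model_path="/models/some-nvfp4-moe-checkpoint",  # placeholder
    quantization="modelopt_fp4",    # required by the new assert in __post_init__
    enable_ep_moe=True,             # FlashInfer MoE is meant to pair with MoE-EP
    enable_flashinfer_moe=True,     # new flag in 0.4.8
    tp_size=8,                      # placeholder
)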
@@ -384,7 +397,6 @@ class ServerArgs:
          ), "Please enable dp attention when setting enable_dp_attention. "

          # DeepEP MoE
-         self.enable_sp_layernorm = False
          if self.enable_deepep_moe:
              if self.deepep_mode == "auto":
                  assert (
@@ -394,9 +406,6 @@ class ServerArgs:
                  logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                  self.disable_cuda_graph = True
              self.ep_size = self.tp_size
-             self.enable_sp_layernorm = (
-                 self.dp_size < self.tp_size if self.enable_dp_attention else True
-             )
              logger.warning(
                  f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
              )
@@ -538,6 +547,9 @@ class ServerArgs:
              "1" if self.disable_outlines_disk_cache else "0"
          )

+         if self.custom_weight_loader is None:
+             self.custom_weight_loader = []
+
      def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
          larger_tp = max(decode_tp, prefill_tp)
          smaller_tp = min(decode_tp, prefill_tp)
@@ -1160,6 +1172,11 @@ class ServerArgs:
              action="store_true",
              help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
          )
+         parser.add_argument(
+             "--enable-flashinfer-moe",
+             action="store_true",
+             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
+         )
          parser.add_argument(
              "--enable-deepep-moe",
              action="store_true",
@@ -1576,6 +1593,18 @@ class ServerArgs:
              default=None,
              help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
          )
+         parser.add_argument(
+             "--custom-weight-loader",
+             type=str,
+             nargs="*",
+             default=None,
+             help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
+         )
+         parser.add_argument(
+             "--weight-loader-disable-mmap",
+             action="store_true",
+             help="Disable mmap while loading weight using safetensors.",
+         )

      @classmethod
      def from_cli_args(cls, args: argparse.Namespace):
@@ -1700,9 +1729,8 @@ class PortArgs:
              dist_init_host, dist_init_port = dist_init_addr
              port_base = int(dist_init_port) + 1
              if dp_rank is None:
-                 scheduler_input_port = (
-                     port_base + 3
-                 )  # TokenizerManager to DataParallelController
+                 # TokenizerManager to DataParallelController
+                 scheduler_input_port = port_base + 3
              else:
                  scheduler_input_port = port_base + 3 + 1 + dp_rank
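The last hunk only moves the comment; the port arithmetic is unchanged. A small worked example of how scheduler_input_port is derived in this branch, with an assumed dist_init_addr of 10.0.0.1:5000:

# Worked example (sketch) of the port derivation shown above,
# assuming dist_init_addr == "10.0.0.1:5000".
dist_init_port = 5000
port_base = dist_init_port + 1                 # 5001

# dp_rank is None: one shared TokenizerManager -> DataParallelController port
scheduler_input_port = port_base + 3           # 5004

# dp_rank == 2: each data-parallel rank gets its own port
scheduler_input_port = port_base + 3 + 1 + 2   # 5007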