sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2.py CHANGED
@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         head_dim = getattr(config, "head_dim", None)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = Qwen2MLP(
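
The same pattern repeats in the qwen2_moe.py and qwen3_moe.py hunks below: an optional dual_chunk_attention_config dict is read off the HF config with getattr (so older configs fall back to None) and threaded from the decoder layer into the attention module, where it is forwarded to the rotary embedding. A minimal sketch of the pattern, using toy class names rather than the real sglang modules:

from typing import Any, Optional


class ToyAttention:
    """Stand-in for the Qwen2Attention-style modules touched in this diff."""

    def __init__(
        self,
        hidden_size: int,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    ) -> None:
        # In sglang the config is simply forwarded again, into get_rope().
        self.dual_chunk_attention_config = dual_chunk_attention_config


class ToyDecoderLayer:
    def __init__(self, config: Any) -> None:
        # A missing attribute degrades gracefully to None (feature disabled).
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = ToyAttention(
            hidden_size=config.hidden_size,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )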
sglang/srt/models/qwen2_moe.py CHANGED
@@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module):
         max_position_embeddings: int = 8192,
         qkv_bias: int = True,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -267,6 +268,7 @@ class Qwen2MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         qkv_bias = getattr(config, "qkv_bias", True)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -317,6 +322,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             qkv_bias=qkv_bias,
             prefix=add_prefix("self_attn", prefix),
         )
sglang/srt/models/qwen3_moe.py CHANGED
@@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
@@ -353,6 +354,7 @@ class Qwen3MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -471,6 +476,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            dual_chunk_attention_config=dual_chunk_attention_config,
             alt_stream=alt_stream,
         )
 
@@ -766,7 +772,10 @@ class Qwen3MoeForCausalLM(nn.Module):
             num_experts=self.config.num_experts,
         )
 
-        params_dict = dict(self.named_parameters())
+        # Cache params_dict to avoid repeated expensive traversal of model parameters
+        if not hasattr(self, "_cached_params_dict"):
+            self._cached_params_dict = dict(self.named_parameters())
+        params_dict = self._cached_params_dict
         for name, loaded_weight in weights:
             layer_id = get_layer_id(name)
             if (
@@ -805,11 +814,22 @@ class Qwen3MoeForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
+
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -821,6 +841,10 @@ class Qwen3MoeForCausalLM(nn.Module):
                     )
                     break
                 else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
@@ -837,11 +861,13 @@ class Qwen3MoeForCausalLM(nn.Module):
                         logger.warning(f"Parameter {name} not found in params_dict")
 
         # TODO mimic deepseek
-        self.routed_experts_weights_of_layer = {
-            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-            for layer_id in range(self.start_layer, self.end_layer)
-            if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
-        }
+        # Lazy initialization of expert weights cache to avoid slowing down load_weights
+        if not hasattr(self, "routed_experts_weights_of_layer"):
+            self.routed_experts_weights_of_layer = {
+                layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                for layer_id in range(self.start_layer, self.end_layer)
+                if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
+            }
 
     @classmethod
     def get_model_config_for_expert_location(cls, config):
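
Two optimizations are at work in the Qwen3MoeForCausalLM.load_weights hunks: the parameter dict and the routed-experts cache are built once and reused across calls, and Python's for/else lets the loader skip expert weights that no parameter on the current rank owns instead of falling through to the generic path. A condensed, self-contained sketch of that control flow; the mapping format and the copy_ call stand in for sglang's real weight_loader machinery:

def load_expert_weights(weights, params_dict, expert_params_mapping):
    """Illustrative only: shows the for/else expert-skip pattern."""
    for name, loaded_weight in weights:
        is_expert_weight = False
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            # It is an expert weight even if this rank cannot load it.
            is_expert_weight = True
            mapped = name.replace(weight_name, param_name)
            if mapped not in params_dict:
                # The expert lives on another rank; keep scanning.
                continue
            params_dict[mapped].data.copy_(loaded_weight)
            break
        else:
            # No mapping loaded a parameter on this rank.
            if is_expert_weight:
                continue  # expert weight owned elsewhere: skip generic handling
            if name in params_dict:
                params_dict[name].data.copy_(loaded_weight)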
sglang/srt/models/step3_vl.py CHANGED
@@ -531,11 +531,18 @@ class Step3VisionMLP(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        # Since this is a dense model,
+        # the MLP component likewise adopts a DP-MLP approach modeled after DP Attention.
+        # This choice may not represent the optimal solution and remains open to further deliberation.
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
         self.fc1 = ColumnParallelLinear(
             dim,
             intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("gate_proj", prefix),
         )
         self.act = ACT2FN[hidden_act]  # quick_gelu
@@ -544,6 +551,8 @@ class Step3VisionMLP(nn.Module):
             dim,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("down_proj", prefix),
         )
 
sglang/srt/models/transformers.py CHANGED
@@ -211,16 +211,13 @@ class TransformersForCausalLM(nn.Module):
         Apply the model's tensor parallelization plan.
         Currently only supports linear layers.
         """
-        if not self.model.supports_tp_plan:
-            if tp_size <= 1:
-                return
+        tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
 
+        if not tp_plan and self.tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!"
             )
 
-        tp_plan = self.model._tp_plan
-
         def _tensor_parallel(module: nn.Module, prefix: str = ""):
             for child_name, child_module in module.named_children():
                 qual_name = maybe_prefix(prefix, child_name)
sglang/srt/multimodal/processors/step3_vl.py CHANGED
@@ -8,7 +8,7 @@ import torch
 from PIL import Image
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
-from transformers import BatchFeature, TensorType
+from transformers import BatchFeature, ProcessorMixin, TensorType
 
 from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
@@ -276,6 +276,8 @@ class Step3VLProcessor:
         super().__init__()
 
         self.config = config
+        if isinstance(tokenizer, ProcessorMixin):
+            tokenizer = tokenizer.tokenizer
         self.tokenizer = tokenizer
 
         self.image_size = 728
sglang/srt/reasoning_parser.py CHANGED
@@ -131,7 +131,7 @@ class DeepSeekR1Detector(BaseReasoningFormatDetector):
             If True, streams reasoning content as it arrives.
     """
 
-    def __init__(self, stream_reasoning: bool = True):
+    def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = True):
         # DeepSeek-R1 is assumed to be reasoning until `</think>` token
         super().__init__(
             "<think>",
@@ -144,7 +144,7 @@ class DeepSeekR1Detector(BaseReasoningFormatDetector):
 
 class Qwen3Detector(BaseReasoningFormatDetector):
     """
-    Detector for standard Qwen3 models (e.g., Qwen/Qwen3-235B-A22B).
+    Detector for Qwen3 models (e.g., Qwen/Qwen3-235B-A22B).
     Assumes reasoning format:
       (<think>)*(.*)</think>
 
@@ -153,47 +153,16 @@ class Qwen3Detector(BaseReasoningFormatDetector):
     - enable_thinking=True: "<think>reasoning content</think>The answer is 42."
     - enable_thinking=False: "The answer is 42." (no thinking tokens)
 
-    This detector handles both cases.
-
-    NOTE: Do NOT use this detector for Qwen3-Thinking models (e.g., Qwen3-Thinking-2507).
-    Those models always generate thinking content without <think> start tags.
-    Use "qwen3-thinking" parser type for those models instead.
-
-    Args:
-        stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
-            If True, streams reasoning content as it arrives.
-    """
-
-    def __init__(self, stream_reasoning: bool = True):
-        super().__init__(
-            "<think>",
-            "</think>",
-            force_reasoning=False,
-            stream_reasoning=stream_reasoning,
-        )
-
-
-class Qwen3ThinkingDetector(BaseReasoningFormatDetector):
-    """
-    Detector for Qwen3-Thinking models (e.g., Qwen3-Thinking-2507).
-    Assumes reasoning format:
-      *(.*)</think>
-
-    These models always generate thinking content without <think> start tag.
-    They do not support the enable_thinking parameter and always think.
-
-    Format: "I need to think about this...</think>The answer is 42."
-
     Args:
         stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
             If True, streams reasoning content as it arrives.
     """
 
-    def __init__(self, stream_reasoning: bool = True):
+    def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False):
         super().__init__(
             "<think>",
             "</think>",
-            force_reasoning=True,
+            force_reasoning=force_reasoning,
             stream_reasoning=stream_reasoning,
         )
 
@@ -207,7 +176,7 @@ class KimiDetector(BaseReasoningFormatDetector):
     and the rest of the text as `normal_text`.
     """
 
-    def __init__(self, stream_reasoning: bool = True):
+    def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False):
         super().__init__(
             "◁think▷",
             "◁/think▷",
@@ -230,13 +199,18 @@ class ReasoningParser:
     DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
         "deepseek-r1": DeepSeekR1Detector,
         "qwen3": Qwen3Detector,
-        "qwen3-thinking": Qwen3ThinkingDetector,
+        "qwen3-thinking": Qwen3Detector,
         "glm45": Qwen3Detector,
         "kimi": KimiDetector,
         "step3": DeepSeekR1Detector,
    }
 
-    def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
+    def __init__(
+        self,
+        model_type: Optional[str] = None,
+        stream_reasoning: bool = True,
+        force_reasoning: bool = False,
+    ):
         if not model_type:
             raise ValueError("Model type must be specified")
 
@@ -244,7 +218,12 @@ class ReasoningParser:
         if not detector_class:
             raise ValueError(f"Unsupported model type: {model_type}")
 
-        self.detector = detector_class(stream_reasoning=stream_reasoning)
+        if model_type.lower() == "qwen3-thinking":
+            force_reasoning = True
+
+        self.detector = detector_class(
+            stream_reasoning=stream_reasoning, force_reasoning=force_reasoning
+        )
 
     def parse_non_stream(self, full_text: str) -> Tuple[str, str]:
         """Non-streaming call: one-time parsing"""
sglang/srt/server_args.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.utils import (
     is_hip,
     is_port_available,
     is_remote_url,
+    is_triton_kernels_available,
     is_valid_ipv6_address,
     nullable_str,
 )
@@ -201,6 +202,7 @@ class ServerArgs:
     hicache_io_backend: str = "kernel"
     hicache_mem_layout: str = "layer_first"
     hicache_storage_backend: Optional[str] = None
+    hicache_storage_prefetch_policy: str = "best_effort"
 
     # Double Sparsity
     enable_double_sparsity: bool = False
@@ -229,6 +231,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_dp_lm_head: bool = False
     enable_two_batch_overlap: bool = False
+    tbo_token_distribution_threshold: float = 0.48
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     torchao_config: str = ""
@@ -247,6 +250,8 @@ class ServerArgs:
    disable_fast_image_processor: bool = False
    enable_return_hidden_states: bool = False
    enable_triton_kernel_moe: bool = False
+    enable_flashinfer_mxfp4_moe: bool = False
+    scheduler_recv_interval: int = 1
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -273,6 +278,9 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
+    # For tool server
+    tool_server: Optional[str] = None
+
     # Deprecated arguments
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
@@ -441,6 +449,81 @@ class ServerArgs:
                 "trtllm_mla backend does not support speculative decoding yet."
             )
 
+        if (
+            self.attention_backend == "trtllm_mha"
+            or self.decode_attention_backend == "trtllm_mha"
+            or self.prefill_attention_backend == "trtllm_mha"
+        ):
+            if not is_sm100_supported():
+                raise ValueError(
+                    "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
+                )
+
+            if self.page_size not in [16, 32, 64]:
+                logger.warning(
+                    f"TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from {self.page_size} to 64."
+                )
+                self.page_size = 64
+
+            if self.speculative_algorithm is not None:
+                raise ValueError(
+                    "trtllm_mha backend does not support speculative decoding yet."
+                )
+
+        model_arch = self.get_hf_config().architectures[0]
+        if model_arch in ["GptOssForCausalLM"]:
+            if self.attention_backend is None:
+                # default is triton, but we could have trtllm_mha as an option
+                self.attention_backend = "triton"
+            assert (
+                self.attention_backend == "trtllm_mha"
+                or self.attention_backend == "triton"
+            )
+            quantization_config = getattr(
+                self.get_hf_config(), "quantization_config", None
+            )
+            is_mxfp4_quant_format = (
+                quantization_config is not None
+                and quantization_config.get("quant_method") == "mxfp4"
+            )
+
+            if is_sm100_supported() and is_mxfp4_quant_format:
+                self.enable_flashinfer_mxfp4_moe = True
+                self.enable_triton_kernel_moe = False
+                logger.info(
+                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
+                )
+            else:
+                if self.enable_triton_kernel_moe:
+                    assert (
+                        self.ep_size == 1
+                    ), "Triton kernel MoE is only supported when ep_size == 1"
+                if not self.enable_triton_kernel_moe and self.ep_size == 1:
+                    self.enable_triton_kernel_moe = True
+                    logger.info(
+                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+                    )
+
+            self.disable_hybrid_swa_memory = True
+
+            if is_mxfp4_quant_format:
+                # use bf16 for mxfp4 triton kernels
+                self.dtype = "bfloat16"
+
+        if self.attention_backend == "dual_chunk_flash_attn":
+            logger.warning(
+                "Mixed chunk is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Radix cache is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Cuda graph is disabled because of using dual chunk flash attention backend"
+            )
+            self.enable_mixed_chunk = False
+            self.disable_cuda_graph = True
+            self.disable_radix_cache = True
+
         # Set page size
         if self.page_size is None:
             self.page_size = 1
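
The block above encodes three constraints: trtllm_mha is restricted to SM100 (Blackwell) and coerces page_size to 64 when an unsupported value is given, GPT-OSS checkpoints auto-select either the FlashInfer MXFP4 or the Triton-kernels MoE path depending on hardware and quantization format, and dual_chunk_flash_attn switches off mixed chunking, the radix cache, and CUDA graphs. A hedged construction sketch; the model paths are placeholders, and these checks only run in __post_init__ against a real checkpoint:

from sglang.srt.server_args import ServerArgs

# Sketch only: get_hf_config() in __post_init__ needs a real model directory.
args = ServerArgs(
    model_path="/path/to/gpt-oss-checkpoint",  # placeholder
    attention_backend="trtllm_mha",            # SM100-only; otherwise raises
    page_size=32,                              # 16/32/64 accepted, else coerced to 64
)

# dual_chunk_flash_attn implies enable_mixed_chunk=False,
# disable_radix_cache=True, and disable_cuda_graph=True after validation.
dual_chunk_args = ServerArgs(
    model_path="/path/to/long-context-checkpoint",  # placeholder
    attention_backend="dual_chunk_flash_attn",
)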
@@ -481,6 +564,13 @@ class ServerArgs:
             self.tp_size,
         ], "The expert parallel size must be 1 or the same as the tensor parallel size"
 
+        if self.enable_flashinfer_trtllm_moe:
+            if not self.disable_shared_experts_fusion:
+                self.disable_shared_experts_fusion = True
+                logger.warning(
+                    "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
+                )
+
         # DeepEP MoE
         if self.moe_a2a_backend == "deepep":
             if self.deepep_mode == "normal":
@@ -806,6 +896,7 @@ class ServerArgs:
                 "moe_wna16",
                 "qoq",
                 "w4afp8",
+                "mxfp4",
             ],
             help="The quantization method.",
         )
@@ -868,7 +959,7 @@ class ServerArgs:
             "--schedule-policy",
             type=str,
             default=ServerArgs.schedule_policy,
-            choices=["lpm", "random", "fcfs", "dfs-weight"],
+            choices=["lpm", "random", "fcfs", "dfs-weight", "lof"],
             help="The scheduling policy of the requests.",
         )
         parser.add_argument(
@@ -1267,6 +1358,8 @@ class ServerArgs:
                 "ascend",
                 "triton",
                 "trtllm_mla",
+                "trtllm_mha",
+                "dual_chunk_flash_attn",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
@@ -1527,6 +1620,13 @@ class ServerArgs:
             default=ServerArgs.hicache_storage_backend,
             help="The storage backend for hierarchical KV cache.",
         )
+        parser.add_argument(
+            "--hicache-storage-prefetch-policy",
+            type=str,
+            choices=["best_effort", "wait_complete", "timeout"],
+            default=ServerArgs.hicache_storage_prefetch_policy,
+            help="Control when prefetching from the storage backend should stop.",
+        )
 
         # Double Sparsity
         parser.add_argument(
@@ -1658,6 +1758,12 @@ class ServerArgs:
             action="store_true",
             help="Enabling two micro batches to overlap.",
         )
+        parser.add_argument(
+            "--tbo-token-distribution-threshold",
+            type=float,
+            default=ServerArgs.tbo_token_distribution_threshold,
+            help="The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
@@ -1755,6 +1861,17 @@ class ServerArgs:
             action="store_true",
             help="Use triton moe grouped gemm kernel.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-mxfp4-moe",
+            action="store_true",
+            help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
+        )
+        parser.add_argument(
+            "--scheduler-recv-interval",
+            type=int,
+            default=ServerArgs.scheduler_recv_interval,
+            help="The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this.",
+        )
 
         # Debug tensor dumps
         parser.add_argument(
@@ -1868,6 +1985,14 @@ class ServerArgs:
             help="Disable mmap while loading weight using safetensors.",
         )
 
+        # For tool server
+        parser.add_argument(
+            "--tool-server",
+            type=str,
+            default=None,
+            help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
+        )
+
         # Deprecated arguments
         parser.add_argument(
             "--enable-ep-moe",
@@ -1936,10 +2061,18 @@ class ServerArgs:
         if "Llama4" in model_arch:
             assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
 
-        if "Gemma2ForCausalLM" in model_arch:
+        if model_arch in [
+            "Gemma2ForCausalLM",
+            "Gemma3ForCausalLM",
+            "Gemma3ForConditionalGeneration",
+            "Gemma3nForCausalLM",
+            "Gemma3nForConditionalGeneration",
+        ]:
             # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
             # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
-            logger.warning("Disable hybrid SWA memory for Gemma2ForCausalLM.")
+            logger.warning(
+                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
+            )
             self.disable_hybrid_swa_memory = True
 
         # Check LoRA
@@ -1977,21 +2110,23 @@ class ServerArgs:
 
         if self.enable_lora:
             # Normalize lora_paths to a dictionary if it is a list.
+            # TODO (lifuhuang): support specifying pinned adapters in server_args.
             if isinstance(self.lora_paths, list):
                 lora_paths = self.lora_paths
                 self.lora_paths = {}
                 for lora_path in lora_paths:
                     if "=" in lora_path:
                         name, path = lora_path.split("=", 1)
-                        self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
+                        self.lora_paths[name] = LoRARef(
+                            lora_name=name, lora_path=path, pinned=False
+                        )
                     else:
                         self.lora_paths[lora_path] = LoRARef(
-                            lora_name=lora_path,
-                            lora_path=lora_path,
+                            lora_name=lora_path, lora_path=lora_path, pinned=False
                         )
             elif isinstance(self.lora_paths, dict):
                 self.lora_paths = {
-                    k: LoRARef(lora_name=k, lora_path=v)
+                    k: LoRARef(lora_name=k, lora_path=v, pinned=False)
                     for k, v in self.lora_paths.items()
                 }
             elif self.lora_paths is None:
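
Beyond the backend checks, the ServerArgs diff adds several user-facing knobs (hicache_storage_prefetch_policy, tbo_token_distribution_threshold, scheduler_recv_interval, enable_flashinfer_mxfp4_moe, tool_server, plus the new schedule-policy and quantization choices) and normalizes every lora_paths entry into a LoRARef with pinned=False. A hedged sketch pulling a few of them together; the paths are placeholders and the values are examples, not recommendations:

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/model",                       # placeholder
    enable_lora=True,
    # "name=path" entries get an explicit name; bare paths use the path as the name.
    lora_paths=["coder=/adapters/coder", "/adapters/chat"],
    hicache_storage_prefetch_policy="wait_complete",   # best_effort | wait_complete | timeout
    scheduler_recv_interval=4,                         # poll incoming requests less often
    tbo_token_distribution_threshold=0.48,             # 0 disables two-chunk-overlap
    tool_server="demo",                                # or a comma-separated list of URLs
)
# After __post_init__, args.lora_paths is a dict mapping each adapter name to a
# LoRARef(lora_name=..., lora_path=..., pinned=False).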