sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +7 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +13 -1
  16. sglang/srt/entrypoints/openai/protocol.py +3 -1
  17. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  18. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  19. sglang/srt/function_call/ebnf_composer.py +10 -3
  20. sglang/srt/function_call/function_call_parser.py +2 -0
  21. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  22. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  24. sglang/srt/layers/attention/vision.py +56 -8
  25. sglang/srt/layers/layernorm.py +26 -1
  26. sglang/srt/layers/logits_processor.py +14 -3
  27. sglang/srt/layers/moe/ep_moe/layer.py +323 -242
  28. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  33. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  34. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  35. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  36. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  37. sglang/srt/layers/moe/topk.py +90 -24
  38. sglang/srt/layers/multimodal.py +11 -8
  39. sglang/srt/layers/quantization/fp8.py +25 -247
  40. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  41. sglang/srt/layers/quantization/modelopt_quant.py +27 -10
  42. sglang/srt/layers/quantization/unquant.py +24 -76
  43. sglang/srt/layers/quantization/w4afp8.py +68 -17
  44. sglang/srt/lora/lora_registry.py +93 -29
  45. sglang/srt/managers/cache_controller.py +9 -7
  46. sglang/srt/managers/data_parallel_controller.py +4 -0
  47. sglang/srt/managers/io_struct.py +12 -0
  48. sglang/srt/managers/mm_utils.py +154 -35
  49. sglang/srt/managers/multimodal_processor.py +3 -14
  50. sglang/srt/managers/schedule_batch.py +14 -8
  51. sglang/srt/managers/scheduler.py +64 -1
  52. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  53. sglang/srt/managers/tokenizer_manager.py +80 -15
  54. sglang/srt/managers/tp_worker.py +8 -0
  55. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  56. sglang/srt/model_executor/model_runner.py +83 -27
  57. sglang/srt/models/deepseek_v2.py +75 -84
  58. sglang/srt/models/glm4_moe.py +1035 -0
  59. sglang/srt/models/glm4_moe_nextn.py +167 -0
  60. sglang/srt/models/interns1.py +328 -0
  61. sglang/srt/models/internvl.py +143 -47
  62. sglang/srt/models/llava.py +9 -5
  63. sglang/srt/models/minicpmo.py +4 -1
  64. sglang/srt/models/qwen2_moe.py +2 -2
  65. sglang/srt/models/qwen3_moe.py +17 -71
  66. sglang/srt/multimodal/processors/base_processor.py +20 -6
  67. sglang/srt/multimodal/processors/clip.py +2 -2
  68. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  69. sglang/srt/multimodal/processors/gemma3.py +2 -2
  70. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  71. sglang/srt/multimodal/processors/internvl.py +21 -8
  72. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  73. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  74. sglang/srt/multimodal/processors/llava.py +4 -4
  75. sglang/srt/multimodal/processors/minicpm.py +2 -3
  76. sglang/srt/multimodal/processors/mlama.py +2 -2
  77. sglang/srt/multimodal/processors/mllama4.py +18 -111
  78. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  79. sglang/srt/multimodal/processors/pixtral.py +2 -2
  80. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  81. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  82. sglang/srt/multimodal/processors/vila.py +3 -1
  83. sglang/srt/poll_based_barrier.py +31 -0
  84. sglang/srt/reasoning_parser.py +2 -1
  85. sglang/srt/server_args.py +65 -6
  86. sglang/srt/two_batch_overlap.py +8 -3
  87. sglang/srt/utils.py +96 -1
  88. sglang/srt/weight_sync/utils.py +119 -0
  89. sglang/test/runners.py +4 -0
  90. sglang/test/test_utils.py +118 -5
  91. sglang/utils.py +19 -0
  92. sglang/version.py +1 -1
  93. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
  94. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
  95. sglang/srt/debug_utils.py +0 -74
  96. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
  97. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
  98. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/multimodal/processors/janus_pro.py CHANGED
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class JanusProImageProcessor(BaseMultimodalProcessor):
     models = [MultiModalityCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
 
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token,
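Note: the same `*args, **kwargs` forwarding is applied to every multimodal processor in this release. A minimal sketch of why it matters, not taken from sglang (the base class body and the `extra_option` parameter are purely hypothetical): subclasses that forward everything keep working when the caller starts passing new constructor arguments.

# Hypothetical illustration only; not the sglang BaseMultimodalProcessor.
class BaseMultimodalProcessor:
    def __init__(self, hf_config, server_args, _processor, extra_option=None):
        self.hf_config = hf_config
        self.server_args = server_args
        self._processor = _processor
        self.extra_option = extra_option


class MyImageProcessor(BaseMultimodalProcessor):
    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        # Forward everything so new base-class parameters flow through unchanged.
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)


proc = MyImageProcessor(None, None, None, extra_option="x")
assert proc.extra_option == "x"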
sglang/srt/multimodal/processors/kimi_vl.py CHANGED
@@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
 class KimiVLImageProcessor(SGLangBaseProcessor):
     models = [KimiVLForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<|media_pad|>",
             # TODO: could we convert in MultimodalSpecialTokens?
sglang/srt/multimodal/processors/llava.py CHANGED
@@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         LlavaMistralForCausalLM,
     ]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
 
     @staticmethod
     def _process_single_image_task(
@@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
             f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
         )
 
-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         assert hasattr(hf_config, "vision_config")
         assert hasattr(hf_config, "text_config")
         self.vision_config = hf_config.vision_config
@@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
 
         if vision_type := getattr(self.vision_config, "model_type"):
             self.inner = self._get_sgl_processor_cls(vision_type)(
-                hf_config, server_args, _processor
+                hf_config, server_args, _processor, *args, **kwargs
             )
         else:
             raise ValueError(
sglang/srt/multimodal/processors/minicpm.py CHANGED
@@ -15,8 +15,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     models = [MiniCPMV, MiniCPMO]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # Collect special token ids
         tokenizer = self._processor.tokenizer
         self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
@@ -26,7 +26,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         self.im_start_id = getattr(tokenizer, "im_start_id", None)
         self.im_end_id = getattr(tokenizer, "im_end_id", None)
         self.im_token_id = getattr(tokenizer, "unk_id", None)
-
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="(<image>./</image>)",
             audio_token="(<audio>./</audio>)",
sglang/srt/multimodal/processors/mlama.py CHANGED
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class MllamaImageProcessor(BaseMultimodalProcessor):
     models = [MllamaForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=self._processor.image_token,
             image_token_id=self._processor.image_token_id,
sglang/srt/multimodal/processors/mllama4.py CHANGED
@@ -18,16 +18,16 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Mllama4ImageProcessor(BaseMultimodalProcessor):
     models = [Llama4ForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
-        self.boi_token_index = hf_config.boi_token_index
-        self.eoi_token_index = hf_config.eoi_token_index
-        self.image_token_index = hf_config.image_token_index
-        self.multimodal_tokens = MultimodalSpecialTokens(
+        self.IM_START_TOKEN_ID = hf_config.boi_token_index
+        self.IM_END_TOKEN_ID = hf_config.eoi_token_index
+        self.IM_TOKEN_ID = hf_config.image_token_index
+        self.mm_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token,
-            image_token_id=self.image_token_index,
+            image_token_id=self.IM_TOKEN_ID,
         ).build(_processor)
 
     async def process_mm_data_async(
@@ -37,114 +37,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if isinstance(input_text, list):
-            assert len(input_text) and isinstance(input_text[0], int)
-            input_text = self._processor.tokenizer.decode(input_text)
-
-        # Process images and text using the base processor's load_mm_data method
-        processed_data = self.load_mm_data(
+        base_output = self.load_mm_data(
             prompt=input_text,
-            multimodal_tokens=self.multimodal_tokens,
             image_data=image_data,
-            return_text=True,
+            multimodal_tokens=self.mm_tokens,
         )
 
-        # Process the images using the processor
-        processor = self._processor
-
         # Process the prompt and images
-        processor_output = self.process_mm_data(
-            input_text=processed_data.input_text,
-            images=processed_data.images,
-        )
-
-        # Handle image resolutions and aspect ratios
-        if "pixel_values" not in processor_output:  # no image processed
-            return None
-
-        image_processor = processor.image_processor
-        tokenizer = self._processor.tokenizer
-
-        # Calculate tile size and find supported resolutions
-        tile_size = self.vision_config.image_size
-        max_num_tiles = getattr(self.vision_config, "max_patches", 1)
-
-        possible_resolutions = find_supported_resolutions(
-            max_num_chunks=max_num_tiles,
-            patch_size=SizeDict(height=tile_size, width=tile_size),
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
         )
 
-        # Find best fit for each image
-        best_fit_sizes = [
-            get_best_fit(
-                (image.size[1], image.size[0]),  # (height, width)
-                torch.tensor(possible_resolutions),
-                resize_to_max_canvas=image_processor.resize_to_max_canvas,
-            )
-            for image in processed_data.images
-        ]
-
-        # Calculate aspect ratios and patches per image
-        aspect_ratios = [
-            (image_size[0] // tile_size, image_size[1] // tile_size)
-            for image_size in best_fit_sizes
-        ]
-
-        patches_per_image = [
-            1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
-        ]
-
-        # Add to image_inputs
-        processor_output["aspect_ratios"] = aspect_ratios
-        processor_output["patches_per_image"] = torch.tensor(patches_per_image)
-
-        # Process embed_is_patch
-        vocab = tokenizer.get_vocab()
-        patch_id = vocab.get(processor.img_patch_token, -1)
-        image_end_id = vocab.get(processor.end_of_img_token, -1)
-
-        if patch_id != -1 and image_end_id != -1:
-            input_ids = processor_output["input_ids"].view(-1)
-
-            # Remove BOS token if present
-            if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
-                input_ids = input_ids[1:]
-
-            # Find image end indices and split input_ids
-            image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
-
-            if image_end_indices.size(0) > 0:
-                # Split at image boundaries
-                split_indices = (image_end_indices + 1)[:-1]
-                split_input_ids = torch.tensor_split(input_ids, split_indices)
-                split_input_ids = [x for x in split_input_ids if x.numel() > 0]
-
-                # Create embed_is_patch for each image
-                embed_is_patch = []
-                for per_image_input_ids in split_input_ids:
-                    embed_is_patch.append(per_image_input_ids == patch_id)
-
-                processor_output["embed_is_patch"] = embed_is_patch
-
-        # Convert to the format expected by SGLang
-        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
-
-        processor_output["im_start_id"] = self.boi_token_index
-        processor_output["im_end_id"] = self.eoi_token_index
-        processor_output["im_token_id"] = self.image_token_index
-
-        image_offsets = self.get_mm_items_offset(
-            input_ids=torch.tensor(processor_output["input_ids"]),
-            mm_token_id=self.image_token_index,
-        )
-
-        # Add metadata for image processing
-        processor_output["mm_items"] = [
-            MultimodalDataItem(
-                feature=processor_output["pixel_values"],
-                modality=Modality.IMAGE,
-                offsets=image_offsets,
-            )
-        ]
-
-        return processor_output
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "im_token_id": self.IM_TOKEN_ID,
+        }
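Note: the rewrite above collapses roughly a hundred lines of Llama-4-specific plumbing into the shared load_mm_data / process_and_combine_mm_data path (mm_utils.py also changes in this diff). A rough sketch of what that single helper call now stands in for, pieced together from the deleted code; this is only an approximation of its contract, not the actual implementation.

# Approximate sketch, assembled from the removed code above; not sglang source.
def process_and_combine_mm_data_sketch(self, base_output, mm_tokens):
    # Run the HF processor over the already-loaded prompt and images.
    processor_output = self.process_mm_data(
        input_text=base_output.input_text, images=base_output.images
    )
    input_ids = processor_output["input_ids"].view(-1)
    # Find where the image tokens sit and wrap the features as MultimodalDataItems.
    offsets = self.get_mm_items_offset(
        input_ids=input_ids, mm_token_id=mm_tokens.image_token_id
    )
    mm_items = [
        MultimodalDataItem(
            feature=processor_output["pixel_values"],
            modality=Modality.IMAGE,
            offsets=offsets,
        )
    ]
    return mm_items, input_ids, processor_output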
sglang/srt/multimodal/processors/phi4mm.py CHANGED
@@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
 class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
     models = [Phi4MMForCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         self.processor = Phi4MMProcessorAdapter(_processor)
-        super().__init__(hf_config, server_args, self.processor)
+        super().__init__(hf_config, server_args, self.processor, *args, **kwargs)
 
         # the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
         # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
sglang/srt/multimodal/processors/pixtral.py CHANGED
@@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor):
 
         return ncols, nrows
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_TOKEN_ID = getattr(
             hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
         )
sglang/srt/multimodal/processors/qwen_audio.py CHANGED
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
     models = [Qwen2AudioForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
         self.AUDIO_TOKEN_REGEX = re.compile(
             r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
sglang/srt/multimodal/processors/qwen_vl.py CHANGED
@@ -201,8 +201,8 @@ async def preprocess_video(
 class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # The regex that matches expanded image tokens.
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
sglang/srt/multimodal/processors/vila.py CHANGED
@@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         hf_config: PretrainedConfig,
         server_args: ServerArgs,
         _processor: VILAProcessor,
+        *args,
+        **kwargs,
     ) -> None:
-        super().__init__(hf_config, server_args, _processor)
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=self._processor.tokenizer.image_token,
             image_token_id=hf_config.image_token_id,
sglang/srt/poll_based_barrier.py ADDED
@@ -0,0 +1,31 @@
+import torch
+
+from sglang.srt.distributed import get_world_group
+
+
+class PollBasedBarrier:
+    def __init__(self, noop: bool = False):
+        self._noop = noop
+        self._local_arrived = False
+
+    def local_arrive(self):
+        assert not self._local_arrived
+        self._local_arrived = True
+
+    def poll_global_arrived(self) -> bool:
+        global_arrived = self._compute_global_arrived()
+        output = self._local_arrived and global_arrived
+        if output:
+            self._local_arrived = False
+        return output
+
+    def _compute_global_arrived(self) -> bool:
+        local_arrived = self._noop or self._local_arrived
+        global_arrived = torch.tensor(local_arrived)
+        # Can optimize if bottleneck
+        torch.distributed.all_reduce(
+            global_arrived,
+            torch.distributed.ReduceOp.MIN,
+            group=get_world_group().cpu_group,
+        )
+        return global_arrived.item()
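Note: a hypothetical driver loop for the new barrier, for illustration only (not taken from sglang). Since _compute_global_arrived() issues an all_reduce on every call, every rank in the group must keep polling in step; the MIN reduction only reports arrival once all ranks have called local_arrive(), and the barrier then resets itself for the next cycle.

# Hypothetical usage sketch; poll_step and this_rank_ready are placeholders.
barrier = PollBasedBarrier()
arrived = False  # track locally so local_arrive() is called at most once per cycle

def poll_step(this_rank_ready: bool) -> bool:
    global arrived
    if this_rank_ready and not arrived:
        barrier.local_arrive()
        arrived = True
    done = barrier.poll_global_arrived()  # True exactly once, when all ranks arrived
    if done:
        arrived = False
    return done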
sglang/srt/reasoning_parser.py CHANGED
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
+        in_reasoning = self._in_reasoning or self.think_start_token in text
 
         if not in_reasoning:
            return StreamingParseResult(normal_text=text)
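Note: the change above broadens one-shot detection so a think-start tag that is not at position 0 (for example, preceded by a newline) still counts as reasoning. A minimal, self-contained illustration; the "<think>" value mirrors the common DeepSeek-R1-style convention, while the real detectors configure their own tokens.

think_start_token = "<think>"

def old_rule(text: str) -> bool:
    return text.startswith(think_start_token)

def new_rule(text: str) -> bool:
    return think_start_token in text

text = "\n<think>reason about the task...</think>The answer is 42."
print(old_rule(text))  # False - the leading newline defeats startswith()
print(new_rule(text))  # True  - the reasoning block is still detected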
@@ -231,6 +231,7 @@ class ReasoningParser:
         "deepseek-r1": DeepSeekR1Detector,
         "qwen3": Qwen3Detector,
         "qwen3-thinking": Qwen3ThinkingDetector,
+        "glm45": Qwen3Detector,
         "kimi": KimiDetector,
     }
 
sglang/srt/server_args.py CHANGED
@@ -19,6 +19,7 @@ import json
 import logging
 import os
 import random
+import sys
 import tempfile
 from typing import List, Literal, Optional, Union
 
@@ -74,6 +75,7 @@ class ServerArgs:
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
+    max_queued_requests: Optional[int] = sys.maxsize
     max_total_tokens: Optional[int] = None
     chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
@@ -151,6 +153,8 @@ class ServerArgs:
 
     # Kernel backend
     attention_backend: Optional[str] = None
+    decode_attention_backend: Optional[str] = None
+    prefill_attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None
     mm_attention_backend: Optional[str] = None
@@ -169,7 +173,8 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
-    enable_flashinfer_moe: bool = False
+    enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
@@ -386,13 +391,19 @@ class ServerArgs:
             )
             self.page_size = 128
 
-        if self.attention_backend == "flashmla":
+        if (
+            self.attention_backend == "flashmla"
+            or self.decode_attention_backend == "flashmla"
+        ):
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."
             )
             self.page_size = 64
 
-        if self.attention_backend == "cutlass_mla":
+        if (
+            self.attention_backend == "cutlass_mla"
+            or self.decode_attention_backend == "cutlass_mla"
+        ):
             logger.warning(
                 "Cutlass MLA only supports a page_size of 128, change page_size to 128."
             )
@@ -428,12 +439,16 @@ class ServerArgs:
         ), "Please enable dp attention when setting enable_dp_lm_head. "
 
         # MoE kernel
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
             os.environ["TRTLLM_ENABLE_PDL"] = "1"
 
+        if self.enable_flashinfer_trtllm_moe:
+            assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
+            logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
+
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":
@@ -458,6 +473,9 @@ class ServerArgs:
                 "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
             )
 
+        if self.enable_eplb:
+            assert self.enable_ep_moe or self.enable_deepep_moe
+
         if self.enable_expert_distribution_metrics and (
             self.expert_distribution_recorder_mode is None
         ):
@@ -497,7 +515,7 @@ class ServerArgs:
             )
 
         model_arch = self.get_hf_config().architectures[0]
-        if model_arch == "DeepseekV3ForCausalLM":
+        if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
             # Auto set draft_model_path DeepSeek-V3/R1
             if self.speculative_draft_model_path is None:
                 self.speculative_draft_model_path = self.model_path
@@ -789,6 +807,12 @@ class ServerArgs:
             default=ServerArgs.max_running_requests,
             help="The maximum number of running requests.",
         )
+        parser.add_argument(
+            "--max-queued-requests",
+            type=int,
+            default=ServerArgs.max_queued_requests,
+            help="The maximum number of queued requests. This option is ignored when using disaggregation-mode.",
+        )
         parser.add_argument(
             "--max-total-tokens",
             type=int,
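Note: --max-queued-requests backs the new ServerArgs.max_queued_requests field shown earlier; the sys.maxsize default keeps the previous behavior (no queue cap) unless the flag is passed. A quick check of the default, with an illustrative launch line in the comment (the model path and request limits are placeholders):

#   python -m sglang.launch_server --model-path <model> \
#       --max-running-requests 128 --max-queued-requests 256
import sys

from sglang.srt.server_args import ServerArgs

# The dataclass default is sys.maxsize, i.e. effectively unbounded queueing.
print(ServerArgs.max_queued_requests == sys.maxsize)  # True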
@@ -1092,6 +1116,7 @@
                 "pythonic",
                 "kimi_k2",
                 "qwen3_coder",
+                "glm45",
             ],
             default=ServerArgs.tool_call_parser,
             help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
@@ -1205,6 +1230,35 @@ class ServerArgs:
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
+        parser.add_argument(
+            "--decode-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.decode_attention_backend,
+            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
+        )
+
+        parser.add_argument(
+            "--prefill-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.prefill_attention_backend,
+            help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
+        )
         parser.add_argument(
             "--sampling-backend",
             type=str,
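Note: the two new flags override --attention-backend per phase, and unset (None) means "fall back to the global choice". Illustrative launch lines only; the model path is a placeholder, and whether a given prefill/decode pairing is supported depends on the model and hardware (the new hybrid_attn_backend.py in this diff wires the two backends together).

#   # one backend for both phases (unchanged behavior)
#   python -m sglang.launch_server --model-path <model> --attention-backend fa3
#
#   # per-phase backends; these take priority over --attention-backend
#   python -m sglang.launch_server --model-path <model> \
#       --prefill-attention-backend fa3 --decode-attention-backend flashinfer
from sglang.srt.server_args import ServerArgs

# Both default to None, meaning "use --attention-backend".
print(ServerArgs.decode_attention_backend, ServerArgs.prefill_attention_backend)  # None None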
@@ -1290,10 +1344,15 @@
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
         parser.add_argument(
-            "--enable-flashinfer-moe",
+            "--enable-flashinfer-cutlass-moe",
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-trtllm-moe",
+            action="store_true",
+            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
sglang/srt/two_batch_overlap.py CHANGED
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import dataclasses
 import logging
 from dataclasses import replace
-from typing import Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
 
 import torch
 
@@ -20,6 +22,9 @@ from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
+
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 
 logger = logging.getLogger(__name__)
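Note: the import changes above are the standard PEP 563 pattern. With from __future__ import annotations, the DispatchOutput annotations added below stay unevaluated strings, so the token_dispatcher import is only needed by type checkers and never executes at runtime (avoiding import cost or cycles). A generic, self-contained illustration; the module and type names here are placeholders, not sglang code.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only type checkers follow this import; it never runs, so a missing or
    # cyclic module costs nothing at runtime.
    from heavy_module import ExpensiveType


def build(**kwargs) -> ExpensiveType:
    # With postponed evaluation, the annotation above is kept as the string
    # "ExpensiveType" and is never resolved when this module is imported.
    ...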
@@ -802,7 +807,7 @@ class MaybeTboDeepEPDispatcher:
     def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
         return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
 
-    def dispatch(self, **kwargs):
+    def dispatch(self, **kwargs) -> DispatchOutput:
         return self._execute("dispatch", **kwargs)
 
     def dispatch_a(self, **kwargs):
@@ -811,7 +816,7 @@ class MaybeTboDeepEPDispatcher:
     def dispatch_b(self, **kwargs):
         return self._execute("dispatch_b", **kwargs)
 
-    def combine(self, **kwargs):
+    def combine(self, **kwargs) -> torch.Tensor:
         return self._execute("combine", **kwargs)
 
     def combine_a(self, **kwargs):