sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
@@ -82,6 +82,10 @@ class ExpertLocationMetadata:
82
82
  def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
83
83
  """Trivial location - logical expert i corresponds to physical expert i"""
84
84
  common = ExpertLocationMetadata._init_common(server_args, model_config)
85
+
86
+ if common is None:
87
+ return None
88
+
85
89
  num_physical_experts = common["num_physical_experts"]
86
90
  model_config_for_expert_location = common["model_config_for_expert_location"]
87
91
  num_layers = model_config_for_expert_location.num_layers
@@ -109,6 +113,10 @@ class ExpertLocationMetadata:
109
113
  physical_to_logical_map = physical_to_logical_map.to(server_args.device)
110
114
 
111
115
  common = ExpertLocationMetadata._init_common(server_args, model_config)
116
+
117
+ if common is None:
118
+ return None
119
+
112
120
  model_config_for_expert_location = common["model_config_for_expert_location"]
113
121
  logical_to_all_physical_map = _compute_logical_to_all_physical_map(
114
122
  physical_to_logical_map,
@@ -133,6 +141,10 @@ class ExpertLocationMetadata:
133
141
  logical_count = logical_count.to(server_args.device)
134
142
 
135
143
  common = ExpertLocationMetadata._init_common(server_args, model_config)
144
+
145
+ if common is None:
146
+ return None
147
+
136
148
  model_config_for_expert_location = common["model_config_for_expert_location"]
137
149
  num_physical_experts = common["num_physical_experts"]
138
150
  num_groups = model_config_for_expert_location.num_groups
@@ -168,6 +180,9 @@ class ExpertLocationMetadata:
168
180
  ModelConfigForExpertLocation.from_model_config(model_config)
169
181
  )
170
182
 
183
+ if model_config_for_expert_location is None:
184
+ return None
185
+
171
186
  num_physical_experts = (
172
187
  model_config_for_expert_location.num_logical_experts
173
188
  + server_args.ep_num_redundant_experts
@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
398
413
  num_logical_experts: int
399
414
  num_groups: Optional[int] = None
400
415
 
401
- @staticmethod
402
- def init_dummy():
403
- return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
404
-
405
416
  @staticmethod
406
417
  def from_model_config(model_config: ModelConfig):
407
418
  model_class, _ = get_model_architecture(model_config)
@@ -410,12 +421,12 @@ class ModelConfigForExpertLocation:
410
421
  model_config.hf_config
411
422
  )
412
423
  else:
413
- return ModelConfigForExpertLocation.init_dummy()
424
+ return None
414
425
 
415
426
 
416
427
  def compute_initial_expert_location_metadata(
417
428
  server_args: ServerArgs, model_config: ModelConfig
418
- ) -> ExpertLocationMetadata:
429
+ ) -> Optional[ExpertLocationMetadata]:
419
430
  data = server_args.init_expert_location
420
431
  if data == "trivial":
421
432
  return ExpertLocationMetadata.init_trivial(server_args, model_config)
@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
36
36
  def init_new(cls, layer_id: int):
37
37
  ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
38
38
  expert_location_metadata = get_global_expert_location_metadata()
39
+ assert expert_location_metadata is not None
39
40
 
40
41
  if ep_dispatch_algorithm is None:
41
42
  return None
@@ -50,6 +50,8 @@ class ExpertLocationUpdater:
50
50
  torch.cuda.empty_cache()
51
51
 
52
52
  old_expert_location_metadata = get_global_expert_location_metadata()
53
+ assert old_expert_location_metadata is not None
54
+
53
55
  _update_expert_weights(
54
56
  routed_experts_weights_of_layer=routed_experts_weights_of_layer,
55
57
  old_expert_location_metadata=old_expert_location_metadata,
@@ -17,6 +17,7 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
17
17
  from sglang.srt.function_call.pythonic_detector import PythonicDetector
18
18
  from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
19
19
  from sglang.srt.function_call.qwen25_detector import Qwen25Detector
20
+ from sglang.srt.function_call.step3_detector import Step3Detector
20
21
 
21
22
  logger = logging.getLogger(__name__)
22
23
 
@@ -39,6 +40,7 @@ class FunctionCallParser:
39
40
  "kimi_k2": KimiK2Detector,
40
41
  "qwen3_coder": Qwen3CoderDetector,
41
42
  "glm45": Glm4MoeDetector,
43
+ "step3": Step3Detector,
42
44
  }
43
45
 
44
46
  def __init__(self, tools: List[Tool], tool_call_parser: str):
@@ -0,0 +1,436 @@
1
+ import ast
2
+ import json
3
+ import logging
4
+ import re
5
+ from typing import Any, Dict, List
6
+
7
+ from sglang.srt.entrypoints.openai.protocol import Tool
8
+ from sglang.srt.function_call.base_format_detector import BaseFormatDetector
9
+ from sglang.srt.function_call.core_types import (
10
+ StreamingParseResult,
11
+ ToolCallItem,
12
+ _GetInfoFunc,
13
+ )
14
+ from sglang.srt.function_call.ebnf_composer import EBNFComposer
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def get_argument_type(func_name: str, arg_key: str, defined_tools: List[Tool]) -> Any:
    """Look up the JSON-schema ``type`` declared for one argument of a tool.

    Args:
        func_name: Name of the function the argument belongs to.
        arg_key: The argument (parameter) name.
        defined_tools: Tools supplied with the request. If several tools share
            ``func_name``, the last one wins.

    Returns:
        The ``type`` entry of the argument's schema (usually a string such as
        ``"integer"``; JSON schema also allows a list of type names), or None
        when the tool, the argument, or its ``type`` is not declared.
    """
    tool = None
    for candidate in reversed(defined_tools or []):
        if candidate.function.name == func_name:
            tool = candidate
            break
    if tool is None:
        return None

    # "properties": null or a non-dict schema must not crash the lookup.
    properties = (tool.function.parameters or {}).get("properties") or {}
    if not isinstance(properties, dict):
        return None
    arg_schema = properties.get(arg_key)
    if not isinstance(arg_schema, dict):
        return None
    return arg_schema.get("type", None)
30
+
31
+
32
def parse_arguments(value: str) -> tuple[Any, bool]:
    """Best-effort conversion of a raw parameter string into a Python value.

    JSON is tried first. If that fails, the text is tried as a Python literal
    via ``ast.literal_eval``, which evaluates literals only and never arbitrary
    code. Inputs such as ``{'a': 1}`` or ``True`` are therefore accepted too.

    Args:
        value: The raw text captured between the parameter tags.

    Returns:
        ``(parsed_value, True)`` on success, otherwise ``(value, False)`` with
        the original string unchanged.
    """
    try:
        return json.loads(value), True
    except Exception:
        pass
    try:
        return ast.literal_eval(value), True
    except Exception:
        # literal_eval may raise ValueError, SyntaxError, TypeError, MemoryError
        # or RecursionError on malformed input; all of them mean "not a literal".
        # KeyboardInterrupt/SystemExit are deliberately not caught.
        return value, False
42
+
43
+
44
class Step3Detector(BaseFormatDetector):
    """
    Detector for Step3 model function call format.

    The Step3 format uses special marker tokens to delimit function calls,
    with steptml XML format for invocations.

    Format Structure:
    ```
    <|tool_calls_begin|>
    <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="function_name">
    <steptml:parameter name="param1">value1</steptml:parameter>
    <steptml:parameter name="param2">value2</steptml:parameter>
    </steptml:invoke><|tool_call_end|>
    <|tool_calls_end|>
    ```
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|tool_calls_begin|>"
        self.eot_token = "<|tool_calls_end|>"
        self.tool_call_begin = "<|tool_call_begin|>"
        self.tool_call_end = "<|tool_call_end|>"
        self.tool_sep = "<|tool_sep|>"

        # One individual tool call, between the per-call markers.
        self.tool_call_regex = re.compile(
            f"{re.escape(self.tool_call_begin)}(.*?){re.escape(self.tool_call_end)}",
            re.DOTALL,
        )
        # Bodies are matched lazily up to their closing tag, so invocations may
        # have zero parameters and values may contain "<" (code, HTML, ...).
        self.invoke_regex = re.compile(
            r'<steptml:invoke name="([^"]+)">(.*?)</steptml:invoke>', re.DOTALL
        )
        self.invoke_name_regex = re.compile(r'<steptml:invoke name="([^"]+)">')
        self.param_regex = re.compile(
            r'<steptml:parameter name="([^"]+)">(.*?)</steptml:parameter>', re.DOTALL
        )

        # Streaming state variables
        self._in_tool_block: bool = False  # bot_token has been consumed
        self._tool_block_finished: bool = False  # eot_token has been consumed
        self._current_function_name: str = ""
        # Parameters of the current call that have already been streamed.
        self._current_parameters: Dict[str, Any] = {}
        self._in_tool_call: bool = False  # tool_call_begin seen, end not yet
        self._function_name_sent: bool = False

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Step3 format tool call."""
        return self.bot_token in text

    def _convert_param_value(
        self, func_name: str, param_name: str, raw_value: str, tools: List[Tool]
    ) -> Any:
        """Convert a raw parameter string according to the tool schema.

        Parameters that are undeclared, or declared as strings (including
        nullable strings such as ``["string", "null"]``), are kept verbatim.
        Any other declared type is parsed with ``parse_arguments``, falling
        back to the raw string if parsing fails.
        """
        arg_type = get_argument_type(func_name, param_name, tools)
        if not arg_type or arg_type == "string":
            return raw_value
        if isinstance(arg_type, list) and "string" in arg_type:
            return raw_value
        parsed_value, _ = parse_arguments(raw_value)
        return parsed_value

    def _parse_steptml_invoke(
        self, text: str, tools: List[Tool] = None
    ) -> "tuple[str | None, Dict[str, Any]]":
        """Parse steptml invoke format to extract function name and parameters.

        Returns ``(None, {})`` when no complete ``<steptml:invoke>`` element is
        found. With ``tools``, values are converted according to the schema;
        without them, every value goes through generic ``parse_arguments``.
        """
        invoke_match = self.invoke_regex.search(text)
        if not invoke_match:
            return None, {}

        func_name = invoke_match.group(1)
        params_text = invoke_match.group(2)

        params = {}
        for param_match in self.param_regex.finditer(params_text):
            param_name = param_match.group(1)
            param_value = param_match.group(2).strip()
            if tools:
                params[param_name] = self._convert_param_value(
                    func_name, param_name, param_value, tools
                )
            else:
                # Fallback to generic parsing if no tools provided
                parsed_value, _ = parse_arguments(param_value)
                params[param_name] = parsed_value

        return func_name, params

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        Text before ``bot_token`` and after ``eot_token`` is returned as normal
        text. If the tool block is never closed, the whole input is returned as
        normal text and no calls are reported.
        """
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=text, calls=[])

        try:
            pre_text, rest = text.split(self.bot_token, 1)

            # If no end token, return everything as normal text
            if self.eot_token not in rest:
                return StreamingParseResult(normal_text=text, calls=[])

            tool_section, post_text = rest.split(self.eot_token, 1)

            calls: List[ToolCallItem] = []
            for match in self.tool_call_regex.finditer(tool_section):
                call_content = match.group(1)

                # Only "function<|tool_sep|>..." calls are supported.
                if self.tool_sep not in call_content:
                    continue
                type_part, invoke_part = call_content.split(self.tool_sep, 1)
                if type_part.strip() != "function":
                    continue

                func_name, params = self._parse_steptml_invoke(invoke_part, tools)
                if func_name:
                    # parse_base_json validates the name and builds the ToolCallItem.
                    action = {"name": func_name, "arguments": params}
                    calls.extend(self.parse_base_json(action, tools))

            return StreamingParseResult(normal_text=pre_text + post_text, calls=calls)

        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # Return the original text if parsing fails
            return StreamingParseResult(normal_text=text)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for Step3 format.

        Every tool call that is available in the buffer is processed, even
        when the begin/end markers arrive in the same chunk as the tool calls.
        """
        self._buffer += new_text

        # Build tool indices for validation
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = self._get_tool_indices(tools)

        # If we've finished the tool block, everything is normal text
        if self._tool_block_finished:
            normal_text = self._buffer
            self._buffer = ""
            return StreamingParseResult(normal_text=normal_text)

        normal_text = ""
        if not self._in_tool_block:
            idx = self._buffer.find(self.bot_token)
            if idx == -1:
                # Hold back text that may be the start of a split bot_token.
                if self._ends_with_partial_token(self._buffer, self.bot_token):
                    return StreamingParseResult()  # Wait for more text
                normal_text = self._buffer
                self._buffer = ""
                return StreamingParseResult(normal_text=normal_text)
            normal_text = self._buffer[:idx]
            self._buffer = self._buffer[idx + len(self.bot_token) :]
            self._in_tool_block = True
            # Fall through: this chunk may already contain tool calls.

        calls = self._consume_tool_calls(tools)

        eot_idx = self._buffer.find(self.eot_token)
        if eot_idx != -1:
            if self._in_tool_call:
                logger.warning("Tool block ended with incomplete tool call")
            normal_text += self._buffer[eot_idx + len(self.eot_token) :]
            self._buffer = ""
            self._tool_block_finished = True
            # Reset any partial tool call state
            self._reset_streaming_state()

        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def _consume_tool_calls(self, tools: List[Tool]) -> List[ToolCallItem]:
        """Advance through every tool call that starts before the end of the block.

        Returns the streaming items produced. Stops when the current call is
        still incomplete, or when no further ``tool_call_begin`` precedes
        ``eot_token``.
        """
        calls: List[ToolCallItem] = []
        while True:
            if not self._in_tool_call:
                eot_idx = self._buffer.find(self.eot_token)
                begin_idx = self._buffer.find(self.tool_call_begin)
                if begin_idx == -1 or (eot_idx != -1 and eot_idx < begin_idx):
                    break
                # Discard anything before the marker (e.g. whitespace between calls).
                self._buffer = self._buffer[begin_idx + len(self.tool_call_begin) :]
                self._reset_streaming_state()
                self._in_tool_call = True

            calls.extend(self._parse_partial_tool_call(tools).calls)
            if self._in_tool_call:
                # Current call is incomplete; wait for more text.
                break
            # The call finished or was rejected; each pass trims past one
            # tool_call_begin, so this loop always terminates.
        return calls

    def _parse_partial_tool_call(self, tools: List[Tool]) -> StreamingParseResult:
        """Parse the tool call at the head of the buffer for streaming.

        The function name is emitted once it is known. JSON argument fragments
        follow as each ``<steptml:parameter>`` closes. When ``tool_call_end``
        arrives, the closing brace is emitted (``{}`` for a call without
        arguments), the call is removed from the buffer and
        ``current_tool_id`` advances.

        Only text up to this call's end marker, or the end of the tool block,
        is examined, so parameters of later buffered calls cannot leak in.
        Invalid calls reset the per-call state without emitting anything
        further.
        """
        calls: List[ToolCallItem] = []

        end_idx = self._buffer.find(self.tool_call_end)
        eot_idx = self._buffer.find(self.eot_token)
        if end_idx != -1 and (eot_idx == -1 or end_idx < eot_idx):
            segment, call_complete = self._buffer[:end_idx], True
        else:
            limit = eot_idx if eot_idx != -1 else len(self._buffer)
            segment, call_complete = self._buffer[:limit], False

        # Check if we have tool_sep (means we're past the type declaration)
        if self.tool_sep not in segment:
            if call_complete:
                # Malformed call without a type separator: skip it.
                self._reset_streaming_state()
            return StreamingParseResult(calls=calls)

        type_part, invoke_part = segment.split(self.tool_sep, 1)
        if type_part.strip() != "function":
            # Invalid tool type, skip this tool call
            self._reset_streaming_state()
            return StreamingParseResult(calls=calls)

        # Try to extract function name if not sent yet
        if not self._function_name_sent:
            name_match = self.invoke_name_regex.search(invoke_part)
            if not name_match:
                if call_complete:
                    # Call closed without an invoke element: skip it.
                    self._reset_streaming_state()
                # Otherwise the function name is not complete yet.
                return StreamingParseResult(calls=calls)

            func_name = name_match.group(1)
            if func_name not in self._tool_indices:
                logger.warning(f"Invalid function name: {func_name}")
                self._reset_streaming_state()
                return StreamingParseResult(calls=calls)

            self._current_function_name = func_name
            self._function_name_sent = True

            # Initialize tool tracking
            if self.current_tool_id == -1:
                self.current_tool_id = 0
            # Ensure tracking arrays are large enough
            while len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})
            while len(self.streamed_args_for_tool) <= self.current_tool_id:
                self.streamed_args_for_tool.append("")

            self.prev_tool_call_arr[self.current_tool_id] = {
                "name": func_name,
                "arguments": {},
            }
            # Send tool name with empty parameters
            calls.append(
                ToolCallItem(
                    tool_index=self.current_tool_id,
                    name=func_name,
                    parameters="",
                )
            )

        # Collect every parameter whose closing tag has arrived.
        new_params: Dict[str, Any] = {}
        for param_match in self.param_regex.finditer(invoke_part):
            param_name = param_match.group(1)
            new_params[param_name] = self._convert_param_value(
                self._current_function_name,
                param_name,
                param_match.group(2).strip(),
                tools,
            )

        if new_params != self._current_parameters:
            # Stream JSON without its closing brace; the brace is sent on completion.
            new_json = json.dumps(new_params, ensure_ascii=False)
            if not self._current_parameters:
                diff = new_json[:-1]
            else:
                old_prefix = json.dumps(self._current_parameters, ensure_ascii=False)[
                    :-1
                ]
                new_prefix = new_json[:-1]
                # Parameters only ever get appended; anything else cannot be
                # retracted from an already-streamed prefix.
                if new_prefix.startswith(old_prefix):
                    diff = new_prefix[len(old_prefix) :]
                else:
                    diff = ""

            if diff:
                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        parameters=diff,
                    )
                )
                self.streamed_args_for_tool[self.current_tool_id] += diff

            self._current_parameters = new_params
            self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params

        if call_complete:
            # Close the JSON object; a call without parameters still yields "{}".
            closing = "}" if self.streamed_args_for_tool[self.current_tool_id] else "{}"
            calls.append(
                ToolCallItem(
                    tool_index=self.current_tool_id,
                    parameters=closing,
                )
            )
            self.streamed_args_for_tool[self.current_tool_id] += closing

            # Remove the processed tool call from buffer
            self._buffer = self._buffer[end_idx + len(self.tool_call_end) :]
            # Reset state for next tool call
            self._reset_streaming_state()
            self.current_tool_id += 1

        return StreamingParseResult(calls=calls)

    def _reset_streaming_state(self):
        """Reset streaming state for the next tool call"""
        self._in_tool_call = False
        self._function_name_sent = False
        self._current_function_name = ""
        self._current_parameters = {}

    def supports_structural_tag(self) -> bool:
        """Return True if this detector supports structural tag format."""
        return False

    def structure_info(self) -> _GetInfoFunc:
        raise NotImplementedError()

    def build_ebnf(self, tools: List[Tool]) -> str:
        """
        Build EBNF grammar for Step3 tool call format.
        """
        # Custom call rule for steptml format
        call_rule_fmt = (
            '"function" "<|tool_sep|>" "<steptml:invoke name=\\"{name}\\">" '
            '{arguments_rule} "</steptml:invoke>"'
        )

        # Custom key-value rule for steptml parameters
        key_value_rule_fmt = (
            '"<steptml:parameter name=\\"{key}\\">" {valrule} "</steptml:parameter>"'
        )

        return EBNFComposer.build_ebnf(
            tools,
            sequence_start_token=self.bot_token,
            sequence_end_token=self.eot_token,
            individual_call_start_token=self.tool_call_begin,
            individual_call_end_token=self.tool_call_end,
            tool_call_separator="",
            function_format="xml",
            call_rule_fmt=call_rule_fmt,
            key_value_rule_fmt=key_value_rule_fmt,
            key_value_separator="",
        )
@@ -41,6 +41,7 @@ from sglang.srt.configs import (
41
41
  ExaoneConfig,
42
42
  KimiVLConfig,
43
43
  MultiModalityConfig,
44
+ Step3VLConfig,
44
45
  )
45
46
  from sglang.srt.configs.internvl import InternVLChatConfig
46
47
  from sglang.srt.connector import create_remote_connector
@@ -54,6 +55,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
54
55
  MultiModalityConfig.model_type: MultiModalityConfig,
55
56
  KimiVLConfig.model_type: KimiVLConfig,
56
57
  InternVLChatConfig.model_type: InternVLChatConfig,
58
+ Step3VLConfig.model_type: Step3VLConfig,
57
59
  }
58
60
 
59
61
  for name, cls in _CONFIG_REGISTRY.items():
@@ -165,7 +165,7 @@ def process_content_for_template_format(
165
165
  new_msg["content"] = processed_content_parts
166
166
  return new_msg
167
167
 
168
- else: # content_format == "string"
168
+ elif content_format == "string":
169
169
  # String format: flatten to text only (for templates like DeepSeek)
170
170
  text_parts = []
171
171
  for chunk in msg_dict["content"]:
@@ -179,3 +179,6 @@ def process_content_for_template_format(
179
179
  new_msg["content"] = " ".join(text_parts) if text_parts else ""
180
180
  new_msg = {k: v for k, v in new_msg.items() if v is not None}
181
181
  return new_msg
182
+
183
+ else:
184
+ raise ValueError(f"Invalid content format: {content_format}")
@@ -209,7 +209,8 @@ def cutlass_fused_experts_fp8(
209
209
  )
210
210
 
211
211
  result = torch.empty((m, k), device=device, dtype=out_dtype)
212
- return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
212
+ apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
213
+ return result
213
214
 
214
215
 
215
216
  FLOAT4_E2M1_MAX = 6.0