sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -238,6 +238,9 @@ async def health() -> Response:
|
|
238
238
|
@app.get("/health_generate")
|
239
239
|
async def health_generate(request: Request) -> Response:
|
240
240
|
"""Check the health of the inference server by generating one token."""
|
241
|
+
if _global_state.tokenizer_manager.gracefully_exit:
|
242
|
+
logger.info("Health check request received during shutdown. Returning 503.")
|
243
|
+
return Response(status_code=503)
|
241
244
|
|
242
245
|
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
243
246
|
rid = f"HEALTH_CHECK_{time.time()}"
|
@@ -260,9 +263,14 @@ async def health_generate(request: Request) -> Response:
|
|
260
263
|
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
261
264
|
break
|
262
265
|
|
263
|
-
|
266
|
+
# This request is a special request.
|
267
|
+
# If the server already has something running, this request will be ignored, so it creates zero overhead.
|
268
|
+
# If the server is not running, this request will be run, so we know whether the server is healthy.
|
264
269
|
task = asyncio.create_task(gen())
|
265
|
-
|
270
|
+
|
271
|
+
# As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
|
272
|
+
tic = time.time()
|
273
|
+
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
|
266
274
|
await asyncio.sleep(1)
|
267
275
|
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
268
276
|
task.cancel()
|
@@ -127,12 +127,12 @@ class OpenAIServingChat(OpenAIServingBase):
|
|
127
127
|
request.skip_special_tokens = False
|
128
128
|
if not isinstance(request.tool_choice, str):
|
129
129
|
tools = [
|
130
|
-
item.model_dump()
|
130
|
+
item.function.model_dump()
|
131
131
|
for item in request.tools
|
132
132
|
if item.function.name == request.tool_choice.function.name
|
133
133
|
]
|
134
134
|
else:
|
135
|
-
tools = [item.model_dump() for item in request.tools]
|
135
|
+
tools = [item.function.model_dump() for item in request.tools]
|
136
136
|
|
137
137
|
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
|
138
138
|
parser = FunctionCallParser(request.tools, tool_call_parser)
|
@@ -178,25 +178,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|
178
178
|
audio_data,
|
179
179
|
modalities,
|
180
180
|
)
|
181
|
-
|
182
|
-
if "tool_calls" in processed_msg and isinstance(
|
183
|
-
processed_msg.get("tool_calls"), list
|
184
|
-
):
|
185
|
-
for call in processed_msg["tool_calls"]:
|
186
|
-
try:
|
187
|
-
if "arguments" in call["function"] and isinstance(
|
188
|
-
call["function"]["arguments"], str
|
189
|
-
):
|
190
|
-
call["function"]["arguments"] = json.loads(
|
191
|
-
call["function"]["arguments"]
|
192
|
-
)
|
193
|
-
except json.JSONDecodeError as e:
|
194
|
-
# Log a warning or error if JSON parsing fails for arguments
|
195
|
-
logger.warning(
|
196
|
-
f"Failed to parse tool call arguments as JSON: {e}"
|
197
|
-
)
|
198
|
-
# Decide whether to continue or raise the exception based on desired behavior
|
199
|
-
continue # Or raise e if strict parsing is required
|
200
181
|
openai_compatible_messages.append(processed_msg)
|
201
182
|
|
202
183
|
# Handle assistant prefix for continue_final_message
|
@@ -47,6 +47,11 @@ class ExpertDistributionRecorder(ABC):
|
|
47
47
|
rank: int,
|
48
48
|
):
|
49
49
|
if server_args.expert_distribution_recorder_mode is not None:
|
50
|
+
assert (
|
51
|
+
expert_location_metadata is not None
|
52
|
+
), "ExpertLocationMetadata is required for expert distribution recording. One possible"
|
53
|
+
"reason is that you are using a model that does not support expert distribution"
|
54
|
+
"recording. Try setting `get_model_config_for_expert_location` in your model."
|
50
55
|
return _ExpertDistributionRecorderReal(
|
51
56
|
server_args, expert_location_metadata, rank
|
52
57
|
)
|
@@ -82,6 +82,10 @@ class ExpertLocationMetadata:
|
|
82
82
|
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
|
83
83
|
"""Trivial location - logical expert i corresponds to physical expert i"""
|
84
84
|
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
85
|
+
|
86
|
+
if common is None:
|
87
|
+
return None
|
88
|
+
|
85
89
|
num_physical_experts = common["num_physical_experts"]
|
86
90
|
model_config_for_expert_location = common["model_config_for_expert_location"]
|
87
91
|
num_layers = model_config_for_expert_location.num_layers
|
@@ -109,6 +113,10 @@ class ExpertLocationMetadata:
|
|
109
113
|
physical_to_logical_map = physical_to_logical_map.to(server_args.device)
|
110
114
|
|
111
115
|
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
116
|
+
|
117
|
+
if common is None:
|
118
|
+
return None
|
119
|
+
|
112
120
|
model_config_for_expert_location = common["model_config_for_expert_location"]
|
113
121
|
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
|
114
122
|
physical_to_logical_map,
|
@@ -133,6 +141,10 @@ class ExpertLocationMetadata:
|
|
133
141
|
logical_count = logical_count.to(server_args.device)
|
134
142
|
|
135
143
|
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
144
|
+
|
145
|
+
if common is None:
|
146
|
+
return None
|
147
|
+
|
136
148
|
model_config_for_expert_location = common["model_config_for_expert_location"]
|
137
149
|
num_physical_experts = common["num_physical_experts"]
|
138
150
|
num_groups = model_config_for_expert_location.num_groups
|
@@ -168,6 +180,9 @@ class ExpertLocationMetadata:
|
|
168
180
|
ModelConfigForExpertLocation.from_model_config(model_config)
|
169
181
|
)
|
170
182
|
|
183
|
+
if model_config_for_expert_location is None:
|
184
|
+
return None
|
185
|
+
|
171
186
|
num_physical_experts = (
|
172
187
|
model_config_for_expert_location.num_logical_experts
|
173
188
|
+ server_args.ep_num_redundant_experts
|
@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
|
|
398
413
|
num_logical_experts: int
|
399
414
|
num_groups: Optional[int] = None
|
400
415
|
|
401
|
-
@staticmethod
|
402
|
-
def init_dummy():
|
403
|
-
return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
|
404
|
-
|
405
416
|
@staticmethod
|
406
417
|
def from_model_config(model_config: ModelConfig):
|
407
418
|
model_class, _ = get_model_architecture(model_config)
|
@@ -410,12 +421,12 @@ class ModelConfigForExpertLocation:
|
|
410
421
|
model_config.hf_config
|
411
422
|
)
|
412
423
|
else:
|
413
|
-
return
|
424
|
+
return None
|
414
425
|
|
415
426
|
|
416
427
|
def compute_initial_expert_location_metadata(
|
417
428
|
server_args: ServerArgs, model_config: ModelConfig
|
418
|
-
) -> ExpertLocationMetadata:
|
429
|
+
) -> Optional[ExpertLocationMetadata]:
|
419
430
|
data = server_args.init_expert_location
|
420
431
|
if data == "trivial":
|
421
432
|
return ExpertLocationMetadata.init_trivial(server_args, model_config)
|
@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
|
|
36
36
|
def init_new(cls, layer_id: int):
|
37
37
|
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
|
38
38
|
expert_location_metadata = get_global_expert_location_metadata()
|
39
|
+
assert expert_location_metadata is not None
|
39
40
|
|
40
41
|
if ep_dispatch_algorithm is None:
|
41
42
|
return None
|
@@ -50,6 +50,8 @@ class ExpertLocationUpdater:
|
|
50
50
|
torch.cuda.empty_cache()
|
51
51
|
|
52
52
|
old_expert_location_metadata = get_global_expert_location_metadata()
|
53
|
+
assert old_expert_location_metadata is not None
|
54
|
+
|
53
55
|
_update_expert_weights(
|
54
56
|
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
|
55
57
|
old_expert_location_metadata=old_expert_location_metadata,
|
@@ -17,6 +17,7 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
|
|
17
17
|
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
18
18
|
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
19
19
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
20
|
+
from sglang.srt.function_call.step3_detector import Step3Detector
|
20
21
|
|
21
22
|
logger = logging.getLogger(__name__)
|
22
23
|
|
@@ -39,6 +40,7 @@ class FunctionCallParser:
|
|
39
40
|
"kimi_k2": KimiK2Detector,
|
40
41
|
"qwen3_coder": Qwen3CoderDetector,
|
41
42
|
"glm45": Glm4MoeDetector,
|
43
|
+
"step3": Step3Detector,
|
42
44
|
}
|
43
45
|
|
44
46
|
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
@@ -0,0 +1,436 @@
|
|
1
|
+
import ast
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import re
|
5
|
+
from typing import Any, Dict, List, Optional
|
6
|
+
|
7
|
+
from sglang.srt.entrypoints.openai.protocol import Tool
|
8
|
+
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
9
|
+
from sglang.srt.function_call.core_types import (
|
10
|
+
StreamingParseResult,
|
11
|
+
ToolCallItem,
|
12
|
+
_GetInfoFunc,
|
13
|
+
)
|
14
|
+
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
15
|
+
|
16
|
+
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
def get_argument_type(
    func_name: str, arg_key: str, defined_tools: List[Tool]
) -> Optional[str]:
    """Look up the declared JSON-schema ``type`` of one tool argument.

    Args:
        func_name: Name of the function to look up among ``defined_tools``.
        arg_key: Name of the argument whose declared type is wanted.
        defined_tools: Tool definitions; each is expected to expose
            ``function.name`` and ``function.parameters``.

    Returns:
        The value stored under ``"type"`` in the argument's property schema,
        or ``None`` when the function is unknown, the argument is not declared,
        or no usable schema is available.
    """
    # When several tools share a name the last one wins, matching the
    # overwrite semantics of a name -> tool mapping.
    tool = None
    for candidate in defined_tools or ():
        if candidate.function.name == func_name:
            tool = candidate
    if tool is None:
        return None

    parameters = tool.function.parameters or {}
    # An explicit ``"properties": None`` is treated the same as a missing key.
    properties = parameters.get("properties") or {}
    if not isinstance(properties, dict):
        return None

    # A property schema is not guaranteed to be a mapping (JSON Schema also
    # allows booleans), so only read "type" from an actual dict.
    schema = properties.get(arg_key)
    if not isinstance(schema, dict):
        return None
    return schema.get("type")
|
30
|
+
|
31
|
+
|
32
|
+
def parse_arguments(value: str) -> tuple[Any, bool]:
    """Convert a raw parameter string into a Python value on a best-effort basis.

    The text is first decoded as JSON. If that fails it is evaluated as a
    Python literal, which accepts spellings JSON rejects (single-quoted
    strings, ``True``/``None``, tuples).

    Args:
        value: Raw text of the parameter.

    Returns:
        ``(parsed_value, True)`` when either parser accepted the text, or
        ``(value, False)`` with the input unchanged when neither did.
    """
    # Only parse failures are caught. A bare ``except`` would also swallow
    # KeyboardInterrupt and SystemExit.
    try:
        return json.loads(value), True
    except (ValueError, TypeError, RecursionError):
        # json.JSONDecodeError is a ValueError; fall through to literal parsing.
        pass

    try:
        return ast.literal_eval(value), True
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return value, False
|
42
|
+
|
43
|
+
|
44
|
+
class Step3Detector(BaseFormatDetector):
    """
    Detector for Step3 model function call format.

    The Step3 format wraps function calls in special delimiter tokens and
    expresses each invocation in steptml XML.

    Format Structure:
    ```
    <|tool_calls_begin|>
    <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="function_name">
    <steptml:parameter name="param1">value1</steptml:parameter>
    <steptml:parameter name="param2">value2</steptml:parameter>
    </steptml:invoke><|tool_call_end|>
    <|tool_calls_end|>
    ```

    The streaming path relies on ``_buffer``, ``current_tool_id``,
    ``prev_tool_call_arr``, ``streamed_args_for_tool``, ``_get_tool_indices``,
    ``_ends_with_partial_token`` and ``parse_base_json``, none of which are
    defined in this class.
    NOTE(review): presumably all of them come from BaseFormatDetector; confirm.
    """

    # Matches only the opening invoke tag, so the function name can be emitted
    # before the body of the invocation has arrived. Compiled once for reuse.
    _INVOKE_NAME_REGEX = re.compile(r'<steptml:invoke name="([^"]+)">')

    def __init__(self):
        """Set up the delimiter tokens, steptml regexes and streaming state."""
        super().__init__()
        self.bot_token = "<|tool_calls_begin|>"
        self.eot_token = "<|tool_calls_end|>"
        self.tool_call_begin = "<|tool_call_begin|>"
        self.tool_call_end = "<|tool_call_end|>"
        self.tool_sep = "<|tool_sep|>"

        # Regex for parsing steptml invocations. The body is matched with
        # ``.*?`` so an invocation without any parameters (an empty body) is
        # still recognised.
        self.invoke_regex = re.compile(
            r'<steptml:invoke name="([^"]+)">(.*?)</steptml:invoke>', re.DOTALL
        )
        # Parameter values may not contain "<" (the value group is ``[^<]*``).
        self.param_regex = re.compile(
            r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>', re.DOTALL
        )

        # Streaming state variables
        self._in_tool_block: bool = False  # between bot_token and eot_token
        self._tool_block_finished: bool = False  # eot_token has been consumed
        self._current_function_name: str = ""
        self._current_parameters: Dict[str, Any] = {}
        self._in_tool_call: bool = False  # tool_call_begin seen, end not yet
        self._function_name_sent: bool = False

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Step3 format tool call."""
        return self.bot_token in text

    def _parse_steptml_invoke(
        self, text: str, tools: Optional[List[Tool]] = None
    ) -> tuple[Optional[str], dict]:
        """Extract the function name and parameters from a steptml invocation.

        Args:
            text: Text expected to contain one ``<steptml:invoke>`` element.
            tools: Optional tool definitions used for schema-aware conversion.

        Returns:
            ``(function_name, parameters)``, or ``(None, {})`` when no complete
            invocation is found in ``text``.
        """
        invoke_match = self.invoke_regex.search(text)
        if not invoke_match:
            return None, {}

        func_name, params_text = invoke_match.groups()

        params = {}
        for param_match in self.param_regex.finditer(params_text):
            param_name = param_match.group(1)
            raw_value = param_match.group(2).strip()

            # With a schema, only arguments declared with a non-string type
            # are converted; strings and undeclared arguments stay verbatim.
            # Without a schema every value goes through generic parsing.
            keep_raw = False
            if tools:
                arg_type = get_argument_type(func_name, param_name, tools)
                keep_raw = not arg_type or arg_type == "string"

            if keep_raw:
                params[param_name] = raw_value
            else:
                parsed_value, _ = parse_arguments(raw_value)
                params[param_name] = parsed_value

        return func_name, params

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        Text without a complete tool block (begin and end token) is returned
        unchanged as normal text. Otherwise the text before and after the block
        is concatenated into ``normal_text`` and every well-formed
        ``function`` call inside the block is converted via ``parse_base_json``.
        """
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=text, calls=[])

        try:
            pre_text, rest = text.split(self.bot_token, 1)

            # If no end token, return everything as normal text
            if self.eot_token not in rest:
                return StreamingParseResult(normal_text=text, calls=[])

            tool_section, post_text = rest.split(self.eot_token, 1)

            # Find all individual tool calls using regex
            calls = []
            tool_call_pattern = (
                f"{re.escape(self.tool_call_begin)}(.*?){re.escape(self.tool_call_end)}"
            )

            for match in re.finditer(tool_call_pattern, tool_section, re.DOTALL):
                call_content = match.group(1)

                # Entries without a separator or with a type other than
                # "function" are skipped.
                if self.tool_sep not in call_content:
                    continue

                type_part, invoke_part = call_content.split(self.tool_sep, 1)
                if type_part.strip() != "function":
                    continue

                func_name, params = self._parse_steptml_invoke(invoke_part, tools)
                if func_name:
                    # Use parse_base_json to create the ToolCallItem
                    action = {"name": func_name, "arguments": params}
                    calls.extend(self.parse_base_json(action, tools))

            return StreamingParseResult(normal_text=pre_text + post_text, calls=calls)

        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # Return the original text if parsing fails
            return StreamingParseResult(normal_text=text)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for Step3 format.

        Appends ``new_text`` to the internal buffer and advances the state
        machine: plain text before the tool block, the tool block itself
        (delegating each call to ``_parse_partial_tool_call``), and plain text
        once the block has ended.
        """
        self._buffer += new_text

        # Build tool indices for validation (cached on first use).
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = self._get_tool_indices(tools)

        # If we've finished the tool block, everything is normal text
        if self._tool_block_finished:
            normal_text = self._buffer
            self._buffer = ""
            return StreamingParseResult(normal_text=normal_text)

        # Check if tool block hasn't started yet
        if not self._in_tool_block:
            bot_idx = self._buffer.find(self.bot_token)
            if bot_idx != -1:
                normal_text = self._buffer[:bot_idx]
                self._buffer = self._buffer[bot_idx + len(self.bot_token) :]
                self._in_tool_block = True
                return StreamingParseResult(normal_text=normal_text)

            # Hold the buffer back while its tail could still turn into the
            # begin token; otherwise flush it as normal text.
            if self._ends_with_partial_token(self._buffer, self.bot_token):
                return StreamingParseResult()  # Wait for more text

            normal_text = self._buffer
            self._buffer = ""
            return StreamingParseResult(normal_text=normal_text)

        # We're inside the tool block
        calls: List[ToolCallItem] = []

        # Check if tool block is ending
        eot_idx = self._buffer.find(self.eot_token)
        if eot_idx != -1:
            # Slice both sides of the end token *before* parsing the final
            # call: _parse_partial_tool_call trims self._buffer, which would
            # make eot_idx stale and drop the text after the end token.
            before_eot = self._buffer[:eot_idx]
            remaining = self._buffer[eot_idx + len(self.eot_token) :]

            # NOTE(review): a call that begins and ends inside this same chunk
            # (``_in_tool_call`` still False here) is not parsed; only an
            # already-open call is finalised.
            if self._in_tool_call:
                if self.tool_call_end in before_eot:
                    # Parse this final tool call
                    result = self._parse_partial_tool_call(tools)
                    calls.extend(result.calls)
                else:
                    # Incomplete tool call - log warning
                    logger.warning("Tool block ended with incomplete tool call")

            self._buffer = ""
            self._tool_block_finished = True

            # Reset any partial tool call state
            self._reset_streaming_state()

            return StreamingParseResult(normal_text=remaining, calls=calls)

        # Check if we're in a tool call or need to start one
        if not self._in_tool_call:
            begin_idx = self._buffer.find(self.tool_call_begin)
            if begin_idx == -1:
                # Wait for tool call to begin
                return StreamingParseResult()

            # Remove any content before tool call begin (shouldn't happen but be safe)
            self._buffer = self._buffer[begin_idx + len(self.tool_call_begin) :]
            self._in_tool_call = True
            self._function_name_sent = False
            self._current_function_name = ""
            self._current_parameters = {}

        # Parse partial tool call
        return self._parse_partial_tool_call(tools)

    def _parse_partial_tool_call(self, tools: List[Tool]) -> StreamingParseResult:
        """Parse the tool call currently at the head of the buffer.

        Emits, as they become available: one item carrying the function name
        (with empty parameters), incremental fragments of the JSON-encoded
        arguments, and a closing ``}`` once the call's end token is seen. When
        the call is complete it is removed from the buffer and
        ``current_tool_id`` is advanced.
        """
        calls: List[ToolCallItem] = []

        # Check if we have tool_sep (means we're past the type declaration)
        if self.tool_sep not in self._buffer:
            return StreamingParseResult(calls=calls)  # Wait for more text

        type_part, invoke_part = self._buffer.split(self.tool_sep, 1)
        if type_part.strip() != "function":
            # Invalid tool type, skip this tool call
            self._reset_streaming_state()
            return StreamingParseResult(calls=calls)

        # Restrict parsing to the current call so that a following call that is
        # already buffered cannot contribute its name or parameters.
        call_end = invoke_part.find(self.tool_call_end)
        if call_end != -1:
            invoke_part = invoke_part[:call_end]

        # Try to extract function name if not sent yet
        if not self._function_name_sent:
            name_match = self._INVOKE_NAME_REGEX.search(invoke_part)
            if not name_match:
                # Function name not complete yet
                return StreamingParseResult(calls=calls)

            func_name = name_match.group(1)

            # Validate function name
            if func_name not in self._tool_indices:
                logger.warning(f"Invalid function name: {func_name}")
                self._reset_streaming_state()
                return StreamingParseResult(calls=calls)

            self._current_function_name = func_name
            self._function_name_sent = True

            # Initialize tool tracking
            if self.current_tool_id == -1:
                self.current_tool_id = 0

            # Ensure tracking arrays are large enough
            while len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})
            while len(self.streamed_args_for_tool) <= self.current_tool_id:
                self.streamed_args_for_tool.append("")

            # Store tool call info
            self.prev_tool_call_arr[self.current_tool_id] = {
                "name": func_name,
                "arguments": {},
            }

            # Send tool name with empty parameters
            calls.append(
                ToolCallItem(
                    tool_index=self.current_tool_id,
                    name=func_name,
                    parameters="",
                )
            )

        # Parse parameters incrementally: collect every complete parameter
        # seen so far for this call.
        new_params = {}
        for param_match in self.param_regex.finditer(invoke_part):
            param_name = param_match.group(1)
            raw_value = param_match.group(2).strip()

            # Use schema-aware parsing
            arg_type = get_argument_type(
                self._current_function_name, param_name, tools
            )
            if arg_type and arg_type != "string":
                parsed_value, _ = parse_arguments(raw_value)
                new_params[param_name] = parsed_value
            else:
                new_params[param_name] = raw_value

        # Check if we have new parameters to stream
        if new_params != self._current_parameters:
            # The JSON is streamed without its closing brace; the brace is
            # sent separately when the call ends.
            new_open = json.dumps(new_params, ensure_ascii=False)[:-1]
            if not self._current_parameters:
                # First parameters - new_params is non-empty here, so this is
                # the opening brace plus content.
                diff = new_open
            else:
                # Subsequent parameters - send only the newly appended tail.
                old_open = json.dumps(self._current_parameters, ensure_ascii=False)[
                    :-1
                ]
                if new_open.startswith(old_open):
                    diff = new_open[len(old_open) :]
                else:
                    # Parameters changed in unexpected way - shouldn't happen in normal streaming
                    diff = ""

            if diff:
                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        parameters=diff,
                    )
                )
                self.streamed_args_for_tool[self.current_tool_id] += diff

            # Update current state
            self._current_parameters = new_params
            self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params

        # Check if tool call is complete
        end_idx = self._buffer.find(self.tool_call_end)
        if end_idx != -1:
            # Send closing brace if we've sent any parameters
            if self.streamed_args_for_tool[self.current_tool_id]:
                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        parameters="}",
                    )
                )
                self.streamed_args_for_tool[self.current_tool_id] += "}"

            # Remove the processed tool call from buffer
            self._buffer = self._buffer[end_idx + len(self.tool_call_end) :]

            # Reset state for next tool call
            self._reset_streaming_state()
            self.current_tool_id += 1

        return StreamingParseResult(calls=calls)

    def _reset_streaming_state(self):
        """Reset streaming state for the next tool call"""
        self._in_tool_call = False
        self._function_name_sent = False
        self._current_function_name = ""
        self._current_parameters = {}

    def supports_structural_tag(self) -> bool:
        """Return True if this detector supports structural tag format."""
        return False

    def structure_info(self) -> _GetInfoFunc:
        """Not available for this format; always raises NotImplementedError."""
        raise NotImplementedError()

    def build_ebnf(self, tools: List[Tool]) -> str:
        """
        Build EBNF grammar for Step3 tool call format.

        Delegates to ``EBNFComposer.build_ebnf`` with steptml-specific rule
        templates for a call and for a single parameter.
        """
        # Custom call rule for steptml format
        call_rule_fmt = (
            '"function" "<|tool_sep|>" "<steptml:invoke name=\\"{name}\\">" '
            '{arguments_rule} "</steptml:invoke>"'
        )

        # Custom key-value rule for steptml parameters
        key_value_rule_fmt = (
            '"<steptml:parameter name=\\"{key}\\">" {valrule} "</steptml:parameter>"'
        )

        return EBNFComposer.build_ebnf(
            tools,
            sequence_start_token=self.bot_token,
            sequence_end_token=self.eot_token,
            individual_call_start_token=self.tool_call_begin,
            individual_call_end_token=self.tool_call_end,
            tool_call_separator="",
            function_format="xml",
            call_rule_fmt=call_rule_fmt,
            key_value_rule_fmt=key_value_rule_fmt,
            key_value_separator="",
        )
|
@@ -41,6 +41,7 @@ from sglang.srt.configs import (
|
|
41
41
|
ExaoneConfig,
|
42
42
|
KimiVLConfig,
|
43
43
|
MultiModalityConfig,
|
44
|
+
Step3VLConfig,
|
44
45
|
)
|
45
46
|
from sglang.srt.configs.internvl import InternVLChatConfig
|
46
47
|
from sglang.srt.connector import create_remote_connector
|
@@ -54,6 +55,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
|
54
55
|
MultiModalityConfig.model_type: MultiModalityConfig,
|
55
56
|
KimiVLConfig.model_type: KimiVLConfig,
|
56
57
|
InternVLChatConfig.model_type: InternVLChatConfig,
|
58
|
+
Step3VLConfig.model_type: Step3VLConfig,
|
57
59
|
}
|
58
60
|
|
59
61
|
for name, cls in _CONFIG_REGISTRY.items():
|
@@ -165,7 +165,7 @@ def process_content_for_template_format(
|
|
165
165
|
new_msg["content"] = processed_content_parts
|
166
166
|
return new_msg
|
167
167
|
|
168
|
-
|
168
|
+
elif content_format == "string":
|
169
169
|
# String format: flatten to text only (for templates like DeepSeek)
|
170
170
|
text_parts = []
|
171
171
|
for chunk in msg_dict["content"]:
|
@@ -179,3 +179,6 @@ def process_content_for_template_format(
|
|
179
179
|
new_msg["content"] = " ".join(text_parts) if text_parts else ""
|
180
180
|
new_msg = {k: v for k, v in new_msg.items() if v is not None}
|
181
181
|
return new_msg
|
182
|
+
|
183
|
+
else:
|
184
|
+
raise ValueError(f"Invalid content format: {content_format}")
|