sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,175 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3
|
+
import logging
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
try:
|
9
|
+
from mcp import ClientSession
|
10
|
+
from mcp.client.sse import sse_client
|
11
|
+
from mcp.types import ListToolsResult
|
12
|
+
except ImportError as e:
|
13
|
+
ClientSession = sse_client = ListToolsResult = e
|
14
|
+
|
15
|
+
from openai_harmony import ToolDescription, ToolNamespaceConfig
|
16
|
+
|
17
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
async def list_server_and_tools(server_url: str):
    """Connect to an MCP server over SSE and fetch its metadata and tool list.

    A short-lived client session is opened against ``server_url``; the MCP
    ``initialize`` handshake is performed first, then the tools are listed.

    Args:
        server_url: Full SSE endpoint URL of the MCP server.

    Returns:
        ``(initialize_response, list_tools_response)`` as returned by
        ``ClientSession.initialize()`` and ``ClientSession.list_tools()``.
    """
    async with sse_client(url=server_url) as streams:
        async with ClientSession(*streams) as session:
            init_result = await session.initialize()
            tools_result = await session.list_tools()
            return init_result, tools_result
|
28
|
+
|
29
|
+
|
30
|
+
def trim_schema(schema: dict) -> dict:
    """Convert an MCP-generated JSON Schema into Harmony's simplified variant.

    The schema is modified in place (and also returned):

    * ``"title"`` is removed.
    * ``"default": None`` is removed; any other default is kept.
    * ``"anyOf": [{"type": "a"}, {"type": "b"}]`` is flattened into
      ``"type": ["a", "b"]``. ``"null"`` is dropped when at least one other
      type is present, because Harmony just ignores it; an ``anyOf`` whose only
      option is ``"null"`` keeps ``["null"]`` instead of ending up untyped.
      If any alternative is not a ``{"type": ...}`` dict (e.g. a ``$ref``),
      the ``anyOf`` cannot be flattened and is left untouched.
    * Every dict-valued entry of ``"properties"`` is trimmed recursively;
      non-dict entries (e.g. boolean schemas) are kept as is.

    Args:
        schema: JSON Schema dict; mutated in place.

    Returns:
        The same ``schema`` object after trimming.
    """
    schema.pop("title", None)
    if "default" in schema and schema["default"] is None:
        del schema["default"]

    any_of = schema.get("anyOf")
    if isinstance(any_of, list) and all(
        isinstance(option, dict) and "type" in option for option in any_of
    ):
        all_types = [option["type"] for option in any_of]
        non_null_types = [t for t in all_types if t != "null"]
        # Only discard "null" when something else remains.
        schema["type"] = non_null_types or all_types
        del schema["anyOf"]

    if "properties" in schema:
        schema["properties"] = {
            key: trim_schema(value) if isinstance(value, dict) else value
            for key, value in schema["properties"].items()
        }
    return schema
|
53
|
+
|
54
|
+
|
55
|
+
def post_process_tools_description(
    list_tools_result: "ListToolsResult",
) -> "ListToolsResult":
    """Adapt an MCP ``list_tools`` result for use in a Harmony prompt.

    Every tool's ``inputSchema`` is passed through ``trim_schema``. Afterwards,
    tools whose ``annotations`` carry a falsy ``include_in_prompt`` attribute
    are removed (some tools, e.g. plain text-in/text-out Python, need no schema
    in the prompt). Tools without that attribute, including tools whose
    ``annotations`` is ``None``, are kept.

    The input object is mutated and also returned.
    """
    for mcp_tool in list_tools_result.tools:
        mcp_tool.inputSchema = trim_schema(mcp_tool.inputSchema)

    def _include_in_prompt(mcp_tool) -> bool:
        # Missing annotation means "include".
        return bool(getattr(mcp_tool.annotations, "include_in_prompt", True))

    list_tools_result.tools = list(filter(_include_in_prompt, list_tools_result.tools))
    return list_tools_result
|
71
|
+
|
72
|
+
|
73
|
+
class ToolServer(ABC):
    """Abstract interface for a collection of tools the model may call.

    Concrete servers report which tools they provide, return a description of
    each tool, and hand out an async context manager for using a tool.
    """

    @abstractmethod
    def has_tool(self, tool_name: str):
        """Return whether ``tool_name`` is provided by this server."""

    @abstractmethod
    def get_tool_description(self, tool_name: str):
        """Return the description of ``tool_name`` (implementation-defined type)."""

    @abstractmethod
    def get_tool_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
        """Return an async context manager that yields a handle for ``tool_name``."""
|
85
|
+
|
86
|
+
|
87
|
+
class MCPToolServer(ToolServer):
    """Tool server backed by one or more remote MCP servers reached over SSE.

    ``add_tool_server`` queries each configured MCP server once and records its
    tools as a Harmony ``ToolNamespaceConfig`` keyed by the name the server
    reports in its ``initialize`` response. ``get_tool_session`` later opens a
    fresh MCP client session to the server that registered that namespace.
    """

    def __init__(self):
        # Namespace name -> ToolNamespaceConfig describing that server's tools.
        self.harmony_tool_descriptions = {}
        # Namespace name -> SSE endpoint URL of the MCP server providing it.
        # Initialized here so lookups work even before any server is added.
        self.urls: dict[str, str] = {}

    async def add_tool_server(self, server_url: str):
        """Register every MCP server listed in ``server_url``.

        Servers are added on top of those registered by earlier calls. When a
        server reports a namespace name that is already registered, a warning
        is logged and that server is ignored entirely (the first registration
        keeps both its description and its URL).

        Args:
            server_url: Comma-separated ``host:port`` addresses. Each entry is
                stripped of surrounding whitespace, empty entries are skipped,
                and the rest are expanded to ``http://{entry}/sse``.
        """
        for raw_address in server_url.split(","):
            address = raw_address.strip()
            if not address:
                continue
            url = f"http://{address}/sse"
            initialize_response, list_tools_response = await list_server_and_tools(url)

            list_tools_response = post_process_tools_description(list_tools_response)

            tool_from_mcp = ToolNamespaceConfig(
                name=initialize_response.serverInfo.name,
                description=initialize_response.instructions,
                tools=[
                    ToolDescription.new(
                        name=tool.name,
                        description=tool.description,
                        parameters=tool.inputSchema,
                    )
                    for tool in list_tools_response.tools
                ],
            )
            # Check before storing anything so a duplicate cannot overwrite the
            # description while the URL still points at the first server.
            if tool_from_mcp.name in self.urls:
                logger.warning(
                    "Tool %s already exists. Ignoring duplicate tool server %s",
                    tool_from_mcp.name,
                    url,
                )
                continue
            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
            self.urls[tool_from_mcp.name] = url

    def has_tool(self, tool_name: str):
        """Return whether a namespace named ``tool_name`` has been registered."""
        return tool_name in self.harmony_tool_descriptions

    def get_tool_description(self, tool_name: str):
        """Return the ``ToolNamespaceConfig`` for ``tool_name``, or ``None``."""
        return self.harmony_tool_descriptions.get(tool_name)

    @asynccontextmanager
    async def get_tool_session(self, tool_name: str):
        """Open an initialized MCP ``ClientSession`` for ``tool_name``.

        Raises:
            RuntimeError: if ``tool_name`` has no registered server. Without
                this, the context manager would fail with the opaque
                "generator didn't yield" error.
        """
        url = self.urls.get(tool_name)
        if not url:
            logger.warning("Tool %s not found", tool_name)
            raise RuntimeError(f"Tool {tool_name} not found")
        async with sse_client(url=url) as streams, ClientSession(
            *streams
        ) as session:
            await session.initialize()
            yield session
|
141
|
+
|
142
|
+
|
143
|
+
class DemoToolServer(ToolServer):
    """In-process tool server exposing the built-in Harmony demo tools.

    ``HarmonyBrowserTool`` is registered as ``"browser"`` and
    ``HarmonyPythonTool`` as ``"python"``, each only when the constructed tool
    reports ``enabled``.
    """

    def __init__(self):
        # Imported here rather than at module level; NOTE(review): presumably
        # to avoid an import cycle or optional-dependency cost — confirm.
        from sglang.srt.entrypoints.tool import (
            HarmonyBrowserTool,
            HarmonyPythonTool,
            Tool,
        )

        self.tools: dict[str, Tool] = {}
        for key, tool_cls in (
            ("browser", HarmonyBrowserTool),
            ("python", HarmonyPythonTool),
        ):
            candidate = tool_cls()
            if candidate.enabled:
                self.tools[key] = candidate

    def has_tool(self, tool_name: str):
        """Return whether ``tool_name`` was enabled at construction time."""
        return tool_name in self.tools

    def get_tool_description(self, tool_name: str):
        """Return Harmony's built-in namespace config for an enabled tool.

        Returns ``None`` for tools that are not registered; raises
        ``ValueError`` for a registered name with no known config.
        """
        if tool_name not in self.tools:
            return None
        builders = {
            "browser": ToolNamespaceConfig.browser,
            "python": ToolNamespaceConfig.python,
        }
        build = builders.get(tool_name)
        if build is None:
            raise ValueError(f"Unknown tool {tool_name}")
        return build()

    @asynccontextmanager
    async def get_tool_session(self, tool_name: str):
        """Yield the registered tool object itself (``KeyError`` if missing)."""
        tool = self.tools[tool_name]
        yield tool
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from typing import TYPE_CHECKING, Any
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
# Avoid circular import.
|
9
|
+
from sglang.srt.entrypoints.context import ConversationContext
|
10
|
+
|
11
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
|
14
|
+
class Tool(ABC):
    """Abstract base for built-in tools invoked during a conversation."""

    @abstractmethod
    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the tool against ``context`` and return its output."""
|
19
|
+
|
20
|
+
|
21
|
+
class HarmonyBrowserTool(Tool):
    """Web-browsing tool backed by ``gpt_oss``'s ``SimpleBrowserTool`` with Exa.

    The tool is enabled only if the ``EXA_API_KEY`` environment variable is
    set and the optional ``gpt_oss`` package can be imported. Otherwise
    ``self.enabled`` is ``False`` and ``self.browser_tool`` is never created,
    so ``get_result`` and ``tool_config`` must not be used.
    """

    def __init__(self):
        self.enabled = True
        exa_api_key = os.getenv("EXA_API_KEY")
        if not exa_api_key:
            self.enabled = False
            # ``warning_once``/``info_once`` are not methods of the standard
            # ``logging.Logger``; use them when present (e.g. if another
            # library has added them) and fall back to the plain method.
            getattr(logger, "warning_once", logger.warning)(
                "EXA_API_KEY is not set, browsing is disabled"
            )
            return

        try:
            from gpt_oss.tools.simple_browser import SimpleBrowserTool
            from gpt_oss.tools.simple_browser.backend import ExaBackend
        except ImportError:
            self.enabled = False
            getattr(logger, "warning_once", logger.warning)(
                "gpt_oss is not installed, browsing is disabled"
            )
            return

        browser_backend = ExaBackend(source="web", api_key=exa_api_key)
        self.browser_tool = SimpleBrowserTool(backend=browser_backend)
        getattr(logger, "info_once", logger.info)("Browser tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the browser on the last conversation message.

        Args:
            context: Must be a ``HarmonyContext`` (asserted).

        Returns:
            List of every message yielded by ``browser_tool.process``.
        """
        from sglang.srt.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        tool_output_msgs = []
        async for msg in self.browser_tool.process(last_msg):
            tool_output_msgs.append(msg)
        return tool_output_msgs

    @property
    def tool_config(self) -> Any:
        """Tool config of the underlying browser tool (only when enabled)."""
        return self.browser_tool.tool_config
|
56
|
+
|
57
|
+
|
58
|
+
class HarmonyPythonTool(Tool):
    """Code-interpreter tool backed by ``gpt_oss.tools.python_docker``'s ``PythonTool``.

    ``self.enabled`` is ``False`` and ``self.python_tool`` is never created
    when the optional ``gpt_oss`` package cannot be imported; ``get_result``
    and ``tool_config`` must not be used in that case.
    """

    def __init__(self):
        self.enabled = True

        try:
            from gpt_oss.tools.python_docker.docker_tool import PythonTool
        except ImportError:
            self.enabled = False
            # ``warning_once``/``info_once`` are not methods of the standard
            # ``logging.Logger``; use them when present (e.g. if another
            # library has added them) and fall back to the plain method.
            getattr(logger, "warning_once", logger.warning)(
                "gpt_oss is not installed, code interpreter is disabled"
            )
            return

        self.python_tool = PythonTool()
        getattr(logger, "info_once", logger.info)("Code interpreter tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the interpreter on the last conversation message.

        Args:
            context: Must be a ``HarmonyContext`` (asserted).

        Returns:
            List of every message yielded by ``python_tool.process``.
        """
        from sglang.srt.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        tool_output_msgs = []
        async for msg in self.python_tool.process(last_msg):
            tool_output_msgs.append(msg)
        return tool_output_msgs

    @property
    def tool_config(self) -> Any:
        """Tool config of the underlying Python tool (only when enabled)."""
        return self.python_tool.tool_config
|
@@ -35,6 +35,7 @@ class ExpertLocationMetadata:
|
|
35
35
|
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
36
36
|
physical_to_logical_map_cpu: torch.Tensor
|
37
37
|
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
38
|
+
logical_to_all_physical_map_cpu: torch.Tensor # CPU copy for performance
|
38
39
|
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
39
40
|
# (layers, num_logical_experts)
|
40
41
|
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
|
@@ -221,6 +222,7 @@ class ExpertLocationMetadata:
|
|
221
222
|
physical_to_logical_map=physical_to_logical_map,
|
222
223
|
physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
|
223
224
|
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
225
|
+
logical_to_all_physical_map_cpu=logical_to_all_physical_map_padded.cpu(),
|
224
226
|
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
225
227
|
logical_to_rank_dispatch_physical_map=(
|
226
228
|
compute_logical_to_rank_dispatch_physical_map(
|
@@ -251,6 +253,7 @@ class ExpertLocationMetadata:
|
|
251
253
|
"physical_to_logical_map",
|
252
254
|
"physical_to_logical_map_cpu",
|
253
255
|
"logical_to_all_physical_map",
|
256
|
+
"logical_to_all_physical_map_cpu",
|
254
257
|
"logical_to_all_physical_map_num_valid",
|
255
258
|
"logical_to_rank_dispatch_physical_map",
|
256
259
|
]:
|
@@ -270,9 +273,10 @@ class ExpertLocationMetadata:
|
|
270
273
|
def logical_to_all_physical(
|
271
274
|
self, layer_id: int, logical_expert_id: int
|
272
275
|
) -> List[int]:
|
276
|
+
# Use CPU copy to avoid GPU→CPU sync on every call, which is expensive in update weights scenario
|
273
277
|
return [
|
274
278
|
physical_expert_id
|
275
|
-
for physical_expert_id in self.
|
279
|
+
for physical_expert_id in self.logical_to_all_physical_map_cpu[
|
276
280
|
layer_id, logical_expert_id
|
277
281
|
].tolist()
|
278
282
|
if physical_expert_id != -1
|
@@ -316,6 +316,7 @@ class EBNFComposer:
|
|
316
316
|
|
317
317
|
combined_args = "".join(rule_parts)
|
318
318
|
arguments_rule = args_template.format(arg_rules=combined_args)
|
319
|
+
arguments_rule = arguments_rule or '""'
|
319
320
|
|
320
321
|
# Add the function call rule and its arguments rule
|
321
322
|
ebnf_lines.append(
|
@@ -11,6 +11,7 @@ from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
|
11
11
|
from sglang.srt.function_call.core_types import ToolCallItem
|
12
12
|
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
13
13
|
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
|
14
|
+
from sglang.srt.function_call.gpt_oss_detector import GptOssDetector
|
14
15
|
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
15
16
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
16
17
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
@@ -41,6 +42,7 @@ class FunctionCallParser:
|
|
41
42
|
"qwen3_coder": Qwen3CoderDetector,
|
42
43
|
"glm45": Glm4MoeDetector,
|
43
44
|
"step3": Step3Detector,
|
45
|
+
"gpt-oss": GptOssDetector,
|
44
46
|
}
|
45
47
|
|
46
48
|
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
@@ -158,7 +158,7 @@ class Glm4MoeDetector(BaseFormatDetector):
|
|
158
158
|
individual_call_end_token=self.eot_token,
|
159
159
|
tool_call_separator="\\n",
|
160
160
|
function_format="xml",
|
161
|
-
call_rule_fmt='"{name}" "\\n" {arguments_rule} "\\n"',
|
161
|
+
call_rule_fmt='"{name}" "\\n" ( {arguments_rule} "\\n" )?',
|
162
162
|
key_value_rule_fmt='"<arg_key>{key}</arg_key>" "\\n" "<arg_value>" {valrule} "</arg_value>"',
|
163
163
|
key_value_separator="\\n",
|
164
164
|
)
|
@@ -0,0 +1,331 @@
|
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
import re
|
4
|
+
from typing import List
|
5
|
+
|
6
|
+
from sglang.srt.entrypoints.openai.protocol import Tool
|
7
|
+
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
8
|
+
from sglang.srt.function_call.core_types import (
|
9
|
+
StreamingParseResult,
|
10
|
+
ToolCallItem,
|
11
|
+
_GetInfoFunc,
|
12
|
+
)
|
13
|
+
|
14
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
15
|
+
|
16
|
+
|
17
|
+
class GptOssDetector(BaseFormatDetector):
    """
    Detector for gpt-oss (Harmony channel format) function calls.

    Supports two formats:
    1. Direct function call: <|channel|>commentary to={namespace.function}<|constrain|>json<|message|>{args}<|call|>
    2. Commentary with action plan: <|channel|>commentary<|message|>{content}<|end|>

    For parallel function calls, each call is self-contained and starts with its own channel:
    <|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location":"SF"}<|call|>
    <|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query":"SF attractions"}<|call|>

    Examples:
    Single: <|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location":"San Francisco"}<|call|>commentary
    Multiple: <|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location":"Paris"}<|call|>commentary<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query":"Paris tourism"}<|call|>
    With Action Plan: <|channel|>commentary<|message|>**Action plan**: 1. Do X 2. Do Y<|end|><|start|>assistant<|channel|>commentary to=functions.x<|constrain|>json<|message|>{"template": "basic_html", "path": "index.html"}<|call|>

    Only the last component of a dotted target is used as the function name
    (``functions.get_weather`` -> ``get_weather``).
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|start|>assistant<|channel|>commentary"
        self.eot_token = "<|call|>"
        # TODO: no clear indication how parallel tool call response format is
        self.tool_call_separator = ""
        # Opening of a commentary message addressed to a function. It appears
        # with or without the leading ``<|start|>assistant`` header (see the
        # "Single" example above).
        self._call_marker = "<|channel|>commentary to="

        # Pattern for complete function calls with to= parameter
        # Handles both <|call|> and <|call|>commentary endings
        # Also handles optional <|start|>assistant prefix and whitespace after function name
        self.function_call_pattern = re.compile(
            r"(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*"
            r"<\|constrain\|>json<\|message\|>(.*?)<\|call\|>(?:commentary)?",
            re.DOTALL,
        )

        # Pattern for streaming function calls (incomplete)
        # Also handles optional whitespace after function name
        self.streaming_pattern = re.compile(
            r"(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*"
            r"<\|constrain\|>json<\|message\|>(.*)",
            re.DOTALL,
        )

        # Pattern for commentary with action plan (no to= parameter)
        self.commentary_pattern = re.compile(
            r"<\|channel\|>commentary<\|message\|>(.*?)<\|end\|>",
            re.DOTALL,
        )

        # Any commentary message, addressed or not; compiled once instead of
        # on every detect_and_parse call.
        self._all_commentary_pattern = re.compile(
            r"<\|channel\|>commentary(?:\s+to=[^<]*)?<\|message\|>(.*?)(?:<\|end\|>|<\|call\|>)",
            re.DOTALL,
        )
        # Body of a ``final`` channel message, up to <|return|> or end of text.
        self._final_pattern = re.compile(
            r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|$)",
            re.DOTALL,
        )
        # Standalone assistant headers between messages.
        self._start_assistant_pattern = re.compile(r"<\|start\|>assistant(?!\w)")

        # Kept for compatibility; not read by this class.
        self._last_arguments = ""

    def has_tool_call(self, text: str) -> bool:
        """Return True if ``text`` contains a gpt-oss commentary/tool-call marker.

        Accepts both the ``<|start|>assistant<|channel|>commentary`` header and
        a bare ``<|channel|>commentary to=`` call (as in the "Single" example),
        matching the marker used by ``parse_streaming_increment``.
        """
        return self.bot_token in text or self._call_marker in text

    def _clean_trailing_text(self, raw_text: str):
        """Turn text following the last consumed message into user-facing text.

        ``<|start|>assistant`` headers are removed. If a ``final`` channel is
        present, the result is the stripped text before it joined by a space
        with its stripped body; otherwise the header-free text is returned
        without stripping.

        Returns:
            ``(text, final_content)``; ``final_content`` is the stripped body
            of the ``final`` channel, or ``None`` when there is no such channel.
        """
        cleaned = self._start_assistant_pattern.sub("", raw_text)
        final_match = self._final_pattern.search(cleaned)
        if final_match is None:
            return cleaned, None
        before_final = cleaned[: final_match.start()].strip()
        final_content = final_match.group(1).strip()
        combined = " ".join(part for part in (before_final, final_content) if part)
        return combined, final_content

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse every complete function call from a finished model output.

        Args:
            text: Full decoded model output.
            tools: Tools declared in the request; calls to other names are
                dropped with a warning.

        Returns:
            ``StreamingParseResult`` whose ``calls`` hold one ``ToolCallItem``
            per well-formed call to a declared tool (``tool_index`` counts from
            0 in order of appearance) and whose ``normal_text`` joins the
            bodies of plain commentary messages with the user-facing text that
            follows the last parsed message (``final`` channel body extracted).
            Text with no tool-call marker is returned unchanged.
        """
        if not self.has_tool_call(text):
            return StreamingParseResult(normal_text=text, calls=[])

        tool_indices = self._get_tool_indices(tools)
        calls = []
        normal_text_parts = []
        # (start, end) spans of complete tool-call messages, including ones
        # whose arguments fail to parse, so they are never treated as text.
        call_spans = []

        for match in self.function_call_pattern.finditer(text):
            call_spans.append(match.span())
            # "functions.get_weather" -> "get_weather"
            function_name = match.group(1).split(".")[-1]
            args_content = match.group(2)

            try:
                arguments = json.loads(args_content) if args_content.strip() else {}
            except json.JSONDecodeError:
                logger.warning(
                    "Failed to parse arguments of tool call %s: %r",
                    function_name,
                    args_content,
                )
                continue

            if function_name not in tool_indices:
                logger.warning("Model called undefined tool: %s", function_name)
                continue

            calls.append(
                ToolCallItem(
                    tool_index=len(calls),
                    name=function_name,
                    parameters=json.dumps(arguments, ensure_ascii=False),
                )
            )

        # End of the last message consumed as a call or as commentary. Tracking
        # commentary here too ensures (a) a final answer after an action plan
        # is not lost when there is no tool call, and (b) commentary after the
        # last call is not emitted a second time as raw trailing text.
        consumed_end = max((end for _, end in call_spans), default=0)
        for match in self._all_commentary_pattern.finditer(text):
            start, end = match.span()
            overlaps_call = any(
                start < call_end and call_start < end
                for call_start, call_end in call_spans
            )
            if overlaps_call:
                continue
            consumed_end = max(consumed_end, end)
            content = match.group(1).strip()
            if content:
                normal_text_parts.append(content)

        if 0 < consumed_end < len(text):
            trailing_text, _ = self._clean_trailing_text(text[consumed_end:])
            trailing_text = trailing_text.strip()
            if trailing_text:
                normal_text_parts.append(trailing_text)

        final_normal_text = " ".join(part for part in normal_text_parts if part).strip()
        return StreamingParseResult(normal_text=final_normal_text, calls=calls)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """Consume one streamed chunk and return whatever output became available.

        Chunks accumulate in ``self._buffer``. While no
        ``<|channel|>commentary to=`` marker is buffered, a complete plain
        commentary body, a ``final`` channel body, or else the raw chunk is
        returned as normal text. Once a call marker is buffered, the function
        name is emitted as soon as the call header is complete, and the full
        arguments are emitted in one item once ``<|call|>`` arrives; text after
        the call is returned as normal text and kept in the buffer.
        """
        self._buffer += new_text
        current_text = self._buffer

        has_call_marker = self._call_marker in current_text

        if not has_call_marker and current_text:
            # Check for commentary without function calls
            commentary_match = self.commentary_pattern.search(current_text)
            if commentary_match:
                self._buffer = current_text[commentary_match.end() :]
                return StreamingParseResult(
                    normal_text=commentary_match.group(1), calls=[]
                )

            # Check for final channel content
            final_match = self._final_pattern.search(current_text)
            if final_match:
                self._buffer = ""
                return StreamingParseResult(
                    normal_text=final_match.group(1).strip(), calls=[]
                )

            self._buffer = ""
            return StreamingParseResult(normal_text=new_text, calls=[])

        if not hasattr(self, "_tool_indices"):
            self._tool_indices = self._get_tool_indices(tools)

        calls = []
        try:
            match = self.streaming_pattern.search(current_text)
            if not match:
                return StreamingParseResult(normal_text="", calls=calls)

            function_name = match.group(1).split(".")[-1]

            # Initialize state if this is the first tool call
            if self.current_tool_id == -1:
                self.current_tool_id = 0
                self.prev_tool_call_arr = []
                self.streamed_args_for_tool = [""]

            # Ensure we have enough entries in tracking arrays
            while len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})
            while len(self.streamed_args_for_tool) <= self.current_tool_id:
                self.streamed_args_for_tool.append("")

            if not self.current_tool_name_sent:
                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        name=function_name,
                        parameters="",
                    )
                )
                self.current_tool_name_sent = True
                self.prev_tool_call_arr[self.current_tool_id] = {
                    "name": function_name,
                    "arguments": {},
                }
                self.streamed_args_for_tool[self.current_tool_id] = ""

            complete_match = self.function_call_pattern.search(current_text)
            if not complete_match:
                return StreamingParseResult(normal_text="", calls=calls)

            args_content = complete_match.group(2)
            try:
                parsed_args = json.loads(args_content)
                self.prev_tool_call_arr[self.current_tool_id]["arguments"] = parsed_args

                # Send the complete arguments once, as a JSON string.
                if not self.streamed_args_for_tool[self.current_tool_id]:
                    serialized_args = json.dumps(parsed_args, ensure_ascii=False)
                    calls.append(
                        ToolCallItem(
                            tool_index=self.current_tool_id,
                            name=None,
                            parameters=serialized_args,
                        )
                    )
                    self.streamed_args_for_tool[self.current_tool_id] = serialized_args
            except json.JSONDecodeError:
                logger.warning(
                    "Failed to parse arguments of tool call %s: %r",
                    function_name,
                    args_content,
                )

            # Drop the completed call from the buffer; keep what follows it.
            remaining_after_call, final_content = self._clean_trailing_text(
                current_text[complete_match.end() :]
            )
            self._buffer = remaining_after_call.strip()

            # Reset state for next tool call
            self.current_tool_name_sent = False
            self.current_tool_id += 1

            if final_content:
                final_text = final_content
            elif remaining_after_call:
                final_text = remaining_after_call
            else:
                final_text = ""
            return StreamingParseResult(normal_text=final_text, calls=calls)

        except Exception as e:
            logger.error("Error in parse_streaming_increment: %s", e)
            # The buffered text is returned below; clear it so later chunks do
            # not emit it again.
            self._buffer = ""
            return StreamingParseResult(normal_text=current_text, calls=[])

    def structure_info(self) -> _GetInfoFunc:
        """Not supported for this format."""
        raise NotImplementedError()

    def build_ebnf(self, tools: List[Tool]) -> str:
        """Not supported for this format."""
        raise NotImplementedError()
|