sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/tool_server.py (new file)

```diff
@@ -0,0 +1,174 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import logging
+from abc import ABC, abstractmethod
+from contextlib import AbstractAsyncContextManager, asynccontextmanager
+from typing import Any
+
+logger = logging.getLogger(__name__)
+try:
+    from mcp import ClientSession
+    from mcp.client.sse import sse_client
+    from mcp.types import ListToolsResult
+except ImportError:
+    logger.warning("Ignoring mcp import error")
+
+from openai_harmony import ToolDescription, ToolNamespaceConfig
+
+
+async def list_server_and_tools(server_url: str):
+
+    async with sse_client(url=server_url) as streams, ClientSession(
+        *streams
+    ) as session:
+        initialize_response = await session.initialize()
+        list_tools_response = await session.list_tools()
+        return initialize_response, list_tools_response
+
+
+def trim_schema(schema: dict) -> dict:
+    # Turn JSON Schema from MCP generated into Harmony's variant.
+    if "title" in schema:
+        del schema["title"]
+    if "default" in schema and schema["default"] is None:
+        del schema["default"]
+    if "anyOf" in schema:
+        # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
+        # into "type": ["type-1", "type-2"]
+        # if there's more than 1 types, also remove "null" type as Harmony will
+        # just ignore it
+        types = [
+            type_dict["type"]
+            for type_dict in schema["anyOf"]
+            if type_dict["type"] != "null"
+        ]
+        schema["type"] = types
+        del schema["anyOf"]
+    if "properties" in schema:
+        schema["properties"] = {
+            k: trim_schema(v) for k, v in schema["properties"].items()
+        }
+    return schema
+
+
+def post_process_tools_description(
+    list_tools_result: "ListToolsResult",
+) -> "ListToolsResult":
+    # Adapt the MCP tool result for Harmony
+    for tool in list_tools_result.tools:
+        tool.inputSchema = trim_schema(tool.inputSchema)
+
+    # Some tools schema don't need to be part of the prompt (e.g. simple text
+    # in text out for Python)
+    list_tools_result.tools = [
+        tool
+        for tool in list_tools_result.tools
+        if getattr(tool.annotations, "include_in_prompt", True)
+    ]
+
+    return list_tools_result
+
+
+class ToolServer(ABC):
+
+    @abstractmethod
+    def has_tool(self, tool_name: str):
+        pass
+
+    @abstractmethod
+    def get_tool_description(self, tool_name: str):
+        pass
+
+    @abstractmethod
+    def get_tool_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: ...
+
+
+class MCPToolServer(ToolServer):
+
+    def __init__(self):
+        self.harmony_tool_descriptions = {}
+
+    async def add_tool_server(self, server_url: str):
+        tool_urls = server_url.split(",")
+        self.harmony_tool_descriptions = {}
+        self.urls: dict[str, str] = {}
+        for url in tool_urls:
+            url = f"http://{url}/sse"
+            initialize_response, list_tools_response = await list_server_and_tools(url)
+
+            list_tools_response = post_process_tools_description(list_tools_response)
+
+            tool_from_mcp = ToolNamespaceConfig(
+                name=initialize_response.serverInfo.name,
+                description=initialize_response.instructions,
+                tools=[
+                    ToolDescription.new(
+                        name=tool.name,
+                        description=tool.description,
+                        parameters=tool.inputSchema,
+                    )
+                    for tool in list_tools_response.tools
+                ],
+            )
+            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
+            if tool_from_mcp.name not in self.urls:
+                self.urls[tool_from_mcp.name] = url
+            else:
+                logger.warning(
+                    "Tool %s already exists. Ignoring duplicate tool server %s",
+                    tool_from_mcp.name,
+                    url,
+                )
+
+    def has_tool(self, tool_name: str):
+        return tool_name in self.harmony_tool_descriptions
+
+    def get_tool_description(self, tool_name: str):
+        return self.harmony_tool_descriptions.get(tool_name)
+
+    @asynccontextmanager
+    async def get_tool_session(self, tool_name: str):
+        url = self.urls.get(tool_name)
+        if url:
+            async with sse_client(url=url) as streams, ClientSession(
+                *streams
+            ) as session:
+                await session.initialize()
+                yield session
+        else:
+            logger.warning("Tool %s not found", tool_name)
+
+
+class DemoToolServer(ToolServer):
+
+    def __init__(self):
+        from sglang.srt.entrypoints.tool import (
+            HarmonyBrowserTool,
+            HarmonyPythonTool,
+            Tool,
+        )
+
+        self.tools: dict[str, Tool] = {}
+        browser_tool = HarmonyBrowserTool()
+        if browser_tool.enabled:
+            self.tools["browser"] = browser_tool
+        python_tool = HarmonyPythonTool()
+        if python_tool.enabled:
+            self.tools["python"] = python_tool
+
+    def has_tool(self, tool_name: str):
+        return tool_name in self.tools
+
+    def get_tool_description(self, tool_name: str):
+        if tool_name not in self.tools:
+            return None
+        if tool_name == "browser":
+            return ToolNamespaceConfig.browser()
+        elif tool_name == "python":
+            return ToolNamespaceConfig.python()
+        else:
+            raise ValueError(f"Unknown tool {tool_name}")
+
+    @asynccontextmanager
+    async def get_tool_session(self, tool_name: str):
+        yield self.tools[tool_name]
```
sglang/srt/entrypoints/tool.py (new file)

```diff
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    # Avoid circular import.
+    from sglang.srt.entrypoints.context import ConversationContext
+
+logger = logging.getLogger(__name__)
+
+
+class Tool(ABC):
+
+    @abstractmethod
+    async def get_result(self, context: "ConversationContext") -> Any:
+        pass
+
+
+class HarmonyBrowserTool(Tool):
+
+    def __init__(self):
+        self.enabled = True
+        exa_api_key = os.getenv("EXA_API_KEY")
+        if not exa_api_key:
+            self.enabled = False
+            logger.warning_once("EXA_API_KEY is not set, browsing is disabled")
+            return
+
+        try:
+            from gpt_oss.tools.simple_browser import SimpleBrowserTool
+            from gpt_oss.tools.simple_browser.backend import ExaBackend
+        except ImportError:
+            self.enabled = False
+            logger.warning_once("gpt_oss is not installed, browsing is disabled")
+            return
+
+        browser_backend = ExaBackend(source="web", api_key=exa_api_key)
+        self.browser_tool = SimpleBrowserTool(backend=browser_backend)
+        logger.info_once("Browser tool initialized")
+
+    async def get_result(self, context: "ConversationContext") -> Any:
+        from sglang.srt.entrypoints.context import HarmonyContext
+
+        assert isinstance(context, HarmonyContext)
+        last_msg = context.messages[-1]
+        tool_output_msgs = []
+        async for msg in self.browser_tool.process(last_msg):
+            tool_output_msgs.append(msg)
+        return tool_output_msgs
+
+    @property
+    def tool_config(self) -> Any:
+        return self.browser_tool.tool_config
+
+
+class HarmonyPythonTool(Tool):
+
+    def __init__(self):
+        self.enabled = True
+
+        try:
+            from gpt_oss.tools.python_docker.docker_tool import PythonTool
+        except ImportError:
+            self.enabled = False
+            logger.warning_once(
+                "gpt_oss is not installed, code interpreter is disabled"
+            )
+            return
+
+        self.python_tool = PythonTool()
+        logger.info_once("Code interpreter tool initialized")
+
+    async def get_result(self, context: "ConversationContext") -> Any:
+        from sglang.srt.entrypoints.context import HarmonyContext
+
+        assert isinstance(context, HarmonyContext)
+        last_msg = context.messages[-1]
+        tool_output_msgs = []
+        async for msg in self.python_tool.process(last_msg):
+            tool_output_msgs.append(msg)
+        return tool_output_msgs
+
+    @property
+    def tool_config(self) -> Any:
+        return self.python_tool.tool_config
```
sglang/srt/eplb/expert_location.py

```diff
@@ -35,6 +35,7 @@ class ExpertLocationMetadata:
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
     physical_to_logical_map_cpu: torch.Tensor
     logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
+    logical_to_all_physical_map_cpu: torch.Tensor  # CPU copy for performance
     logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
     # (layers, num_logical_experts)
     logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
@@ -221,6 +222,7 @@ class ExpertLocationMetadata:
             physical_to_logical_map=physical_to_logical_map,
             physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
             logical_to_all_physical_map=logical_to_all_physical_map_padded,
+            logical_to_all_physical_map_cpu=logical_to_all_physical_map_padded.cpu(),
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
             logical_to_rank_dispatch_physical_map=(
                 compute_logical_to_rank_dispatch_physical_map(
@@ -251,6 +253,7 @@ class ExpertLocationMetadata:
             "physical_to_logical_map",
             "physical_to_logical_map_cpu",
             "logical_to_all_physical_map",
+            "logical_to_all_physical_map_cpu",
             "logical_to_all_physical_map_num_valid",
             "logical_to_rank_dispatch_physical_map",
         ]:
@@ -270,9 +273,10 @@ class ExpertLocationMetadata:
     def logical_to_all_physical(
         self, layer_id: int, logical_expert_id: int
     ) -> List[int]:
+        # Use CPU copy to avoid GPU→CPU sync on every call, which is expensive in update weights scenario
        return [
             physical_expert_id
-            for physical_expert_id in self.logical_to_all_physical_map[
+            for physical_expert_id in self.logical_to_all_physical_map_cpu[
                 layer_id, logical_expert_id
             ].tolist()
             if physical_expert_id != -1
```
sglang/srt/function_call/harmony_tool_parser.py (new file)

```diff
@@ -0,0 +1,130 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Harmony tool call parser for processing tool calls in harmony models."""
+
+import uuid
+from typing import List, Optional, Tuple
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatMessage,
+    FunctionResponse,
+    ToolCall,
+)
+
+
+class HarmonyToolCallParser:
+    """Parser for extracting tool calls from harmony model outputs."""
+
+    def extract_tool_calls_from_message(self, msg) -> Optional[ToolCall]:
+        """
+        Extract tool call from a single message if it's a tool call.
+
+        Args:
+            msg: The harmony message
+
+        Returns:
+            ToolCall if the message is a tool call, None otherwise
+        """
+        if (
+            msg.channel == "commentary"
+            and msg.recipient
+            and msg.recipient.startswith("functions.")
+        ):
+            function_name = msg.recipient.split(".")[-1]
+            arguments = msg.content[0].text if msg.content else "{}"
+
+            return ToolCall(
+                id=f"call_{uuid.uuid4().hex[:24]}",
+                function=FunctionResponse(
+                    name=function_name,
+                    arguments=arguments,
+                ),
+            )
+        return None
+
+    def process_streaming_chunk(
+        self,
+        harmony_parser,
+        index: int,
+        tool_call_trackers: dict,
+        stream_buffers: dict,
+    ) -> Tuple[Optional[dict], bool, Optional[str]]:
+        """
+        Process a streaming chunk for tool calls.
+
+        Args:
+            harmony_parser: The harmony parser instance
+            index: The choice index
+            tool_call_trackers: Dict tracking tool calls per choice
+            stream_buffers: Dict for buffering content
+
+        Returns:
+            Tuple of (tool_call_data, is_tool_call, delta)
+        """
+        # Check if we're in a tool call
+        is_tool_call = (
+            harmony_parser.current_channel == "commentary"
+            and harmony_parser.current_recipient
+            and harmony_parser.current_recipient.startswith("functions.")
+        )
+
+        delta = harmony_parser.last_content_delta or ""
+        tool_call_data = None
+
+        if is_tool_call:
+            # Handle tool call streaming
+            function_name = harmony_parser.current_recipient.split(".")[-1]
+
+            # Track tool call indices per choice
+            if index not in tool_call_trackers:
+                tool_call_trackers[index] = {"count": 0, "current_function": None}
+
+            # Check if we just started a new tool call
+            tool_call_tracker = tool_call_trackers[index]
+            if tool_call_tracker["current_function"] != function_name:
+                # New tool call started
+                tool_call_tracker["current_function"] = function_name
+                tool_call_index = tool_call_tracker["count"]
+                tool_call_tracker["count"] += 1
+
+                # Store the tool call index for this function
+                tool_call_key = f"{index}_{function_name}"
+                stream_buffers[tool_call_key] = {
+                    "index": tool_call_index,
+                    "content": "",
+                }
+
+                tool_call_data = {
+                    "id": f"call_{uuid.uuid4().hex[:24]}",
+                    "index": tool_call_index,
+                    "function_name": function_name,
+                    "arguments": delta,
+                    "is_first_chunk": True,
+                }
+            else:
+                # Subsequent chunks for the same tool call
+                tool_call_key = f"{index}_{function_name}"
+                tool_call_index = stream_buffers[tool_call_key]["index"]
+
+                tool_call_data = {
+                    "id": None,
+                    "index": tool_call_index,
+                    "function_name": None,
+                    "arguments": delta,
+                    "is_first_chunk": False,
+                }
+
+            stream_buffers[tool_call_key]["content"] += delta
+
+        return tool_call_data, is_tool_call, delta
```
sglang/srt/hf_transformers_utils.py

```diff
@@ -14,10 +14,11 @@
 """Utilities for Huggingface Transformers."""
 
 import contextlib
+import json
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union
 
 import torch
 from huggingface_hub import snapshot_download
@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
     AutoConfig.register(name, cls)
 
 
-def download_from_hf(model_path: str):
+def download_from_hf(
+    model_path: str,
+    allow_patterns: Optional[Union[str, list]] = None,
+):
     if os.path.exists(model_path):
         return model_path
 
-    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
+    if not allow_patterns:
+        allow_patterns = ["*.json", "*.bin", "*.model"]
+
+    return snapshot_download(model_path, allow_patterns=allow_patterns)
 
 
 def get_hf_text_config(config: PretrainedConfig):
@@ -171,6 +178,26 @@ def get_generation_config(
     return None
 
 
+# Qwen-1M related
+def get_sparse_attention_config(
+    model: str,
+    sparse_attention_config_filename: str = "sparse_attention_config.json",
+) -> Dict[str, Any]:
+    is_local = os.path.isdir(model)
+    if not is_local:
+        # Download the config files.
+        model = download_from_hf(model, allow_patterns=["*.json"])
+
+    config_file = os.path.join(model, sparse_attention_config_filename)
+    if not os.path.exists(config_file):
+        return {}
+
+    # Load the sparse attention config.
+    with open(config_file) as f:
+        config = json.load(f)
+    return config
+
+
 # Models don't use the same configuration key for determining the maximum
 # context length.  Store them here so we can sanely check them.
 # NOTE: The ordering here is important.  Some models have two of these and we
```
sglang/srt/jinja_template_utils.py

```diff
@@ -9,6 +9,8 @@ import logging
 import jinja2
 import transformers.utils.chat_template_utils as hf_chat_utils
 
+from sglang.srt.utils import ImageData
+
 logger = logging.getLogger(__name__)
 
 # ============================================================================
@@ -140,7 +142,12 @@ def process_content_for_template_format(
         chunk_type = chunk.get("type")
 
         if chunk_type == "image_url":
-            image_data.append(chunk["image_url"]["url"])
+            image_data.append(
+                ImageData(
+                    url=chunk["image_url"]["url"],
+                    detail=chunk["image_url"].get("detail", "auto"),
+                )
+            )
             if chunk.get("modalities"):
                 modalities.append(chunk.get("modalities"))
             # Normalize to simple 'image' type for template compatibility
```
sglang/srt/layers/attention/aiter_backend.py

```diff
@@ -720,11 +720,6 @@ class AiterIndicesUpdaterPrefill:
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.update = self.update_single_wrapper
 
-        # get the last index of the pool
-        self.pool_size = (
-            model_runner.token_to_kv_pool.size + model_runner.token_to_kv_pool.page_size
-        ) - 1
-
         self.kv_indices = None
         self.max_q_len = 0
         self.max_kv_len = 0
@@ -769,9 +764,8 @@ class AiterIndicesUpdaterPrefill:
         # but the 0 location will be made nan (noqa) in cuda graph capture mode
         # this will cause the output tensor value becomes nan
         # WA is to assure that last index of pool not changed
-        kv_indices = torch.
-
-            self.pool_size,
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum + 256,
             dtype=torch.int32,
             device=req_pool_indices.device,
         )
@@ -785,6 +779,9 @@ class AiterIndicesUpdaterPrefill:
             self.req_to_token.shape[1],
         )
 
+        token_num = kv_indptr[-1]
+        kv_indices[token_num:] = kv_indices[0]
+
         self.max_kv_len = torch.max(paged_kernel_lens).item()
 
         extend_lens = seq_lens - prefix_lens
```

(The first two removed lines of the second hunk were truncated by the diff renderer; the elided text is left as-is.)