sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105):
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,174 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from contextlib import AbstractAsyncContextManager, asynccontextmanager
6
+ from typing import Any
7
+
8
+ logger = logging.getLogger(__name__)
9
+ try:
10
+ from mcp import ClientSession
11
+ from mcp.client.sse import sse_client
12
+ from mcp.types import ListToolsResult
13
+ except ImportError:
14
+ logger.warning("Ignoring mcp import error")
15
+
16
+ from openai_harmony import ToolDescription, ToolNamespaceConfig
17
+
18
+
19
async def list_server_and_tools(server_url: str):
    """Open one MCP session over SSE and fetch the handshake and tool list.

    Args:
        server_url: Full SSE endpoint URL of the MCP server.

    Returns:
        A ``(initialize_response, list_tools_response)`` tuple holding the
        results of ``session.initialize()`` and ``session.list_tools()``.
    """
    async with sse_client(url=server_url) as streams:
        async with ClientSession(*streams) as session:
            init_result = await session.initialize()
            tools_result = await session.list_tools()
    return init_result, tools_result
27
+
28
+
29
def trim_schema(schema: dict) -> dict:
    """Turn an MCP-generated JSON Schema into Harmony's variant, in place.

    - Drops ``title`` keys and ``default`` entries whose value is ``None``.
    - Collapses ``"anyOf": [{"type": "t1"}, {"type": "t2"}, ...]`` into
      ``"type": ["t1", "t2", ...]``. When more than one type is present the
      ``"null"`` type is removed, since Harmony just ignores it; a union whose
      only type is ``"null"`` keeps it rather than collapsing to ``[]``.
    - Leaves ``anyOf`` untouched when any branch is not a plain
      ``{"type": ...}`` dict (e.g. a ``$ref``), instead of raising KeyError.
    - Recurses into dict-valued ``properties``. Non-dict sub-schemas (JSON
      Schema allows ``true``/``false``) are kept as-is.

    Args:
        schema: JSON Schema dict. It is mutated.

    Returns:
        The same ``schema`` object, for convenient chaining.
    """
    schema.pop("title", None)
    if "default" in schema and schema["default"] is None:
        del schema["default"]

    branches = schema.get("anyOf")
    if isinstance(branches, list) and all(
        isinstance(branch, dict) and "type" in branch for branch in branches
    ):
        all_types = [branch["type"] for branch in branches]
        non_null_types = [t for t in all_types if t != "null"]
        # Only drop "null" when something else remains.
        schema["type"] = non_null_types or all_types
        del schema["anyOf"]

    if "properties" in schema:
        schema["properties"] = {
            key: trim_schema(value) if isinstance(value, dict) else value
            for key, value in schema["properties"].items()
        }
    return schema
52
+
53
+
54
def post_process_tools_description(
    list_tools_result: "ListToolsResult",
) -> "ListToolsResult":
    """Adapt an MCP ``list_tools`` result for Harmony, mutating it in place.

    Every tool's ``inputSchema`` is rewritten with ``trim_schema``. Tools whose
    ``annotations`` expose a falsy ``include_in_prompt`` attribute are then
    dropped, because some tool schemas do not need to be part of the prompt
    (e.g. simple text-in/text-out for Python). Tools lacking the attribute are
    kept.

    Returns:
        The same ``list_tools_result`` object.
    """
    for entry in list_tools_result.tools:
        entry.inputSchema = trim_schema(entry.inputSchema)

    kept = []
    for entry in list_tools_result.tools:
        if getattr(entry.annotations, "include_in_prompt", True):
            kept.append(entry)
    list_tools_result.tools = kept

    return list_tools_result
70
+
71
+
72
class ToolServer(ABC):
    """Abstract registry of tools that the serving layer can describe and call."""

    @abstractmethod
    def has_tool(self, tool_name: str):
        """Return whether a tool named ``tool_name`` is available."""

    @abstractmethod
    def get_tool_description(self, tool_name: str):
        """Return the description for ``tool_name``.

        The implementations in this module return a Harmony
        ``ToolNamespaceConfig``, or ``None`` for an unknown tool.
        """

    @abstractmethod
    def get_tool_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
        """Return an async context manager that yields a session for ``tool_name``."""
84
+
85
+
86
class MCPToolServer(ToolServer):
    """Tool server that proxies tools exposed by remote MCP servers over SSE.

    ``add_tool_server`` queries each MCP server once. It caches the server's
    tools as a Harmony ``ToolNamespaceConfig`` keyed by the server's reported
    name. ``get_tool_session`` opens a fresh MCP session on every use.
    """

    def __init__(self):
        self.harmony_tool_descriptions = {}
        # Also created here, not only in add_tool_server, so a lookup made
        # before any server is registered reports "not found" instead of
        # raising AttributeError.
        self.urls: dict[str, str] = {}

    async def add_tool_server(self, server_url: str):
        """Register MCP servers from a comma-separated list of addresses.

        Each entry (normally ``host:port``) has surrounding whitespace removed
        and is expanded to ``http://<entry>/sse``. Empty entries are skipped.
        Any previously registered servers are discarded.

        If two servers report the same name, the first one is kept (both its
        description and its URL). The duplicate is logged and ignored.
        """
        self.harmony_tool_descriptions = {}
        self.urls = {}
        for entry in server_url.split(","):
            entry = entry.strip()
            if not entry:
                # Avoid building a bogus URL such as "http:///sse" from a
                # trailing or doubled comma.
                continue
            url = f"http://{entry}/sse"
            initialize_response, list_tools_response = await list_server_and_tools(url)

            list_tools_response = post_process_tools_description(list_tools_response)

            tool_from_mcp = ToolNamespaceConfig(
                name=initialize_response.serverInfo.name,
                description=initialize_response.instructions,
                tools=[
                    ToolDescription.new(
                        name=tool.name,
                        description=tool.description,
                        parameters=tool.inputSchema,
                    )
                    for tool in list_tools_response.tools
                ],
            )
            # Check for duplicates before storing anything, so the cached
            # description and the session URL always come from the same
            # server.
            if tool_from_mcp.name in self.urls:
                logger.warning(
                    "Tool %s already exists. Ignoring duplicate tool server %s",
                    tool_from_mcp.name,
                    url,
                )
                continue
            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
            self.urls[tool_from_mcp.name] = url

    def has_tool(self, tool_name: str):
        return tool_name in self.harmony_tool_descriptions

    def get_tool_description(self, tool_name: str):
        return self.harmony_tool_descriptions.get(tool_name)

    @asynccontextmanager
    async def get_tool_session(self, tool_name: str):
        """Yield an initialized MCP ``ClientSession`` for ``tool_name``.

        Raises:
            RuntimeError: if ``tool_name`` has no registered server.
        """
        url = self.urls.get(tool_name)
        if not url:
            logger.warning("Tool %s not found", tool_name)
            # An async context manager must either yield or raise. Returning
            # without yielding would surface as an opaque
            # "generator didn't yield" RuntimeError.
            raise RuntimeError(f"Tool {tool_name} not found")
        async with sse_client(url=url) as streams, ClientSession(
            *streams
        ) as session:
            await session.initialize()
            yield session
140
+
141
+
142
class DemoToolServer(ToolServer):
    """In-process tool server for the built-in browser and python demo tools.

    Each tool is instantiated once at construction. It is registered only if
    it reports itself as ``enabled``.
    """

    def __init__(self):
        from sglang.srt.entrypoints.tool import (
            HarmonyBrowserTool,
            HarmonyPythonTool,
            Tool,
        )

        self.tools: dict[str, Tool] = {}
        for registered_name, tool_cls in (
            ("browser", HarmonyBrowserTool),
            ("python", HarmonyPythonTool),
        ):
            candidate = tool_cls()
            if candidate.enabled:
                self.tools[registered_name] = candidate

    def has_tool(self, tool_name: str):
        registered = self.tools
        return tool_name in registered

    def get_tool_description(self, tool_name: str):
        """Return the Harmony namespace config for a registered tool, else None."""
        if tool_name in self.tools:
            if tool_name == "browser":
                return ToolNamespaceConfig.browser()
            if tool_name == "python":
                return ToolNamespaceConfig.python()
            raise ValueError(f"Unknown tool {tool_name}")
        return None

    @asynccontextmanager
    async def get_tool_session(self, tool_name: str):
        """Yield the registered tool object itself as the session."""
        session = self.tools[tool_name]
        yield session
@@ -0,0 +1,87 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import logging
3
+ import os
4
+ from abc import ABC, abstractmethod
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ if TYPE_CHECKING:
8
+ # Avoid circular import.
9
+ from sglang.srt.entrypoints.context import ConversationContext
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class Tool(ABC):
    """Interface for a built-in tool that produces output from a conversation."""

    @abstractmethod
    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the tool against ``context`` and return its output."""
19
+
20
+
21
class HarmonyBrowserTool(Tool):
    """Web-browsing tool backed by gpt_oss's ``SimpleBrowserTool``.

    It uses an ``ExaBackend`` with ``source="web"``. ``enabled`` is False when
    ``EXA_API_KEY`` is unset or the ``gpt_oss`` package cannot be imported.
    """

    # Messages already emitted through _log_once; shared class-wide on
    # purpose. The module logger comes from logging.getLogger(), and the
    # stdlib Logger has no warning_once/info_once, so "log once" is
    # implemented here.
    _logged_once = set()

    @classmethod
    def _log_once(cls, level: int, msg: str) -> None:
        """Log ``msg`` at ``level`` the first time this class sees it."""
        if msg in cls._logged_once:
            return
        cls._logged_once.add(msg)
        logger.log(level, msg)

    def __init__(self):
        self.enabled = True
        exa_api_key = os.getenv("EXA_API_KEY")
        if not exa_api_key:
            self.enabled = False
            self._log_once(
                logging.WARNING, "EXA_API_KEY is not set, browsing is disabled"
            )
            return

        try:
            from gpt_oss.tools.simple_browser import SimpleBrowserTool
            from gpt_oss.tools.simple_browser.backend import ExaBackend
        except ImportError:
            self.enabled = False
            self._log_once(
                logging.WARNING, "gpt_oss is not installed, browsing is disabled"
            )
            return

        browser_backend = ExaBackend(source="web", api_key=exa_api_key)
        self.browser_tool = SimpleBrowserTool(backend=browser_backend)
        self._log_once(logging.INFO, "Browser tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Feed the last message of a HarmonyContext to the browser tool.

        Returns:
            The list of messages produced by ``browser_tool.process``.
        """
        from sglang.srt.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        return [msg async for msg in self.browser_tool.process(last_msg)]

    @property
    def tool_config(self) -> Any:
        return self.browser_tool.tool_config
56
+
57
+
58
class HarmonyPythonTool(Tool):
    """Code-interpreter tool backed by gpt_oss's ``PythonTool``.

    The tool comes from ``gpt_oss.tools.python_docker``. ``enabled`` is False
    when that module cannot be imported.
    """

    # Messages already emitted through _log_once; shared class-wide on
    # purpose. The module logger comes from logging.getLogger(), and the
    # stdlib Logger has no warning_once/info_once, so "log once" is
    # implemented here.
    _logged_once = set()

    @classmethod
    def _log_once(cls, level: int, msg: str) -> None:
        """Log ``msg`` at ``level`` the first time this class sees it."""
        if msg in cls._logged_once:
            return
        cls._logged_once.add(msg)
        logger.log(level, msg)

    def __init__(self):
        self.enabled = True

        try:
            from gpt_oss.tools.python_docker.docker_tool import PythonTool
        except ImportError:
            self.enabled = False
            self._log_once(
                logging.WARNING,
                "gpt_oss is not installed, code interpreter is disabled",
            )
            return

        self.python_tool = PythonTool()
        self._log_once(logging.INFO, "Code interpreter tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Feed the last message of a HarmonyContext to the python tool.

        Returns:
            The list of messages produced by ``python_tool.process``.
        """
        from sglang.srt.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        return [msg async for msg in self.python_tool.process(last_msg)]

    @property
    def tool_config(self) -> Any:
        return self.python_tool.tool_config
@@ -35,6 +35,7 @@ class ExpertLocationMetadata:
35
35
  physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
36
36
  physical_to_logical_map_cpu: torch.Tensor
37
37
  logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
38
+ logical_to_all_physical_map_cpu: torch.Tensor # CPU copy for performance
38
39
  logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
39
40
  # (layers, num_logical_experts)
40
41
  logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
@@ -221,6 +222,7 @@ class ExpertLocationMetadata:
221
222
  physical_to_logical_map=physical_to_logical_map,
222
223
  physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
223
224
  logical_to_all_physical_map=logical_to_all_physical_map_padded,
225
+ logical_to_all_physical_map_cpu=logical_to_all_physical_map_padded.cpu(),
224
226
  logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
225
227
  logical_to_rank_dispatch_physical_map=(
226
228
  compute_logical_to_rank_dispatch_physical_map(
@@ -251,6 +253,7 @@ class ExpertLocationMetadata:
251
253
  "physical_to_logical_map",
252
254
  "physical_to_logical_map_cpu",
253
255
  "logical_to_all_physical_map",
256
+ "logical_to_all_physical_map_cpu",
254
257
  "logical_to_all_physical_map_num_valid",
255
258
  "logical_to_rank_dispatch_physical_map",
256
259
  ]:
@@ -270,9 +273,10 @@ class ExpertLocationMetadata:
270
273
  def logical_to_all_physical(
271
274
  self, layer_id: int, logical_expert_id: int
272
275
  ) -> List[int]:
276
+ # Use CPU copy to avoid GPU→CPU sync on every call, which is expensive in update weights scenario
273
277
  return [
274
278
  physical_expert_id
275
- for physical_expert_id in self.logical_to_all_physical_map[
279
+ for physical_expert_id in self.logical_to_all_physical_map_cpu[
276
280
  layer_id, logical_expert_id
277
281
  ].tolist()
278
282
  if physical_expert_id != -1
@@ -0,0 +1,130 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ """Harmony tool call parser for processing tool calls in harmony models."""
15
+
16
+ import uuid
17
+ from typing import List, Optional, Tuple
18
+
19
+ from sglang.srt.entrypoints.openai.protocol import (
20
+ ChatMessage,
21
+ FunctionResponse,
22
+ ToolCall,
23
+ )
24
+
25
+
26
class HarmonyToolCallParser:
    """Parser for extracting tool calls from harmony model outputs.

    A function call is recognised by a message on the ``commentary`` channel
    whose recipient starts with ``functions.``. The function name is the part
    after the last ``.``.
    """

    @staticmethod
    def _is_function_call(channel, recipient) -> bool:
        """Return True if ``channel``/``recipient`` address a function call."""
        return bool(
            channel == "commentary"
            and recipient
            and recipient.startswith("functions.")
        )

    def extract_tool_calls_from_message(self, msg) -> Optional["ToolCall"]:
        """
        Extract tool call from a single message if it's a tool call.

        Args:
            msg: The harmony message (reads ``channel``, ``recipient`` and
                ``content[0].text``).

        Returns:
            ToolCall if the message is a tool call, None otherwise. The
            arguments default to ``"{}"`` when the message has no content.
        """
        if not self._is_function_call(msg.channel, msg.recipient):
            return None
        function_name = msg.recipient.split(".")[-1]
        arguments = msg.content[0].text if msg.content else "{}"
        return ToolCall(
            id=f"call_{uuid.uuid4().hex[:24]}",
            function=FunctionResponse(
                name=function_name,
                arguments=arguments,
            ),
        )

    def process_streaming_chunk(
        self,
        harmony_parser,
        index: int,
        tool_call_trackers: dict,
        stream_buffers: dict,
    ) -> Tuple[Optional[dict], bool, Optional[str]]:
        """
        Process a streaming chunk for tool calls.

        Args:
            harmony_parser: The harmony parser instance (reads
                ``current_channel``, ``current_recipient`` and
                ``last_content_delta``).
            index: The choice index.
            tool_call_trackers: Per-choice state ``{"count", "current_function"}``,
                mutated in place.
            stream_buffers: Per-call buffers keyed ``"<index>_<function>"``
                holding ``{"index", "content"}``, mutated in place.

        Returns:
            Tuple of (tool_call_data, is_tool_call, delta). ``tool_call_data``
            is None when the chunk is not part of a tool call, and
            ``is_tool_call`` is always a real bool.
        """
        is_tool_call = self._is_function_call(
            harmony_parser.current_channel, harmony_parser.current_recipient
        )
        delta = harmony_parser.last_content_delta or ""

        if not is_tool_call:
            # Leaving function-call mode: forget the active function, so a
            # later call to the same function starts a new tool call instead
            # of being merged into the previous one.
            tracker = tool_call_trackers.get(index)
            if tracker is not None:
                tracker["current_function"] = None
            return None, False, delta

        function_name = harmony_parser.current_recipient.split(".")[-1]
        tool_call_tracker = tool_call_trackers.setdefault(
            index, {"count": 0, "current_function": None}
        )
        tool_call_key = f"{index}_{function_name}"

        if tool_call_tracker["current_function"] != function_name:
            # New tool call started: assign the next per-choice index.
            tool_call_tracker["current_function"] = function_name
            tool_call_index = tool_call_tracker["count"]
            tool_call_tracker["count"] += 1
            stream_buffers[tool_call_key] = {
                "index": tool_call_index,
                "content": "",
            }
            tool_call_data = {
                "id": f"call_{uuid.uuid4().hex[:24]}",
                "index": tool_call_index,
                "function_name": function_name,
                "arguments": delta,
                "is_first_chunk": True,
            }
        else:
            # Subsequent chunk of the same tool call.
            tool_call_index = stream_buffers[tool_call_key]["index"]
            tool_call_data = {
                "id": None,
                "index": tool_call_index,
                "function_name": None,
                "arguments": delta,
                "is_first_chunk": False,
            }

        stream_buffers[tool_call_key]["content"] += delta
        return tool_call_data, True, delta
@@ -14,10 +14,11 @@
14
14
  """Utilities for Huggingface Transformers."""
15
15
 
16
16
  import contextlib
17
+ import json
17
18
  import os
18
19
  import warnings
19
20
  from pathlib import Path
20
- from typing import Dict, Optional, Type, Union
21
+ from typing import Any, Dict, Optional, Type, Union
21
22
 
22
23
  import torch
23
24
  from huggingface_hub import snapshot_download
@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
62
63
  AutoConfig.register(name, cls)
63
64
 
64
65
 
65
def download_from_hf(
    model_path: str,
    allow_patterns: Optional[Union[str, list]] = None,
):
    """Return a local path for ``model_path``, downloading it if necessary.

    An existing local path is returned unchanged. Otherwise the repo is
    fetched with ``snapshot_download``, limited to ``allow_patterns``. When
    ``allow_patterns`` is None or empty, it falls back to
    ``["*.json", "*.bin", "*.model"]``.
    """
    if not os.path.exists(model_path):
        patterns = (
            allow_patterns if allow_patterns else ["*.json", "*.bin", "*.model"]
        )
        return snapshot_download(model_path, allow_patterns=patterns)
    return model_path
70
77
 
71
78
 
72
79
  def get_hf_text_config(config: PretrainedConfig):
@@ -171,6 +178,26 @@ def get_generation_config(
171
178
  return None
172
179
 
173
180
 
181
+ # Qwen-1M related
182
def get_sparse_attention_config(
    model: str,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
    """Load a model's sparse-attention config file (Qwen-1M related).

    Args:
        model: Local model directory, or a Hugging Face repo id. For a repo
            id, only ``*.json`` files are downloaded first.
        sparse_attention_config_filename: Config file name inside the model
            directory.

    Returns:
        The parsed JSON config, or ``{}`` if the file does not exist.
    """
    if not os.path.isdir(model):
        # Download the config files (JSON only).
        model = download_from_hf(model, allow_patterns=["*.json"])

    config_file = os.path.join(model, sparse_attention_config_filename)
    if not os.path.exists(config_file):
        return {}

    # Read explicitly as UTF-8 rather than the platform locale encoding.
    with open(config_file, encoding="utf-8") as f:
        return json.load(f)
199
+
200
+
174
201
  # Models don't use the same configuration key for determining the maximum
175
202
  # context length. Store them here so we can sanely check them.
176
203
  # NOTE: The ordering here is important. Some models have two of these and we
@@ -9,6 +9,8 @@ import logging
9
9
  import jinja2
10
10
  import transformers.utils.chat_template_utils as hf_chat_utils
11
11
 
12
+ from sglang.srt.utils import ImageData
13
+
12
14
  logger = logging.getLogger(__name__)
13
15
 
14
16
  # ============================================================================
@@ -140,7 +142,12 @@ def process_content_for_template_format(
140
142
  chunk_type = chunk.get("type")
141
143
 
142
144
  if chunk_type == "image_url":
143
- image_data.append(chunk["image_url"]["url"])
145
+ image_data.append(
146
+ ImageData(
147
+ url=chunk["image_url"]["url"],
148
+ detail=chunk["image_url"].get("detail", "auto"),
149
+ )
150
+ )
144
151
  if chunk.get("modalities"):
145
152
  modalities.append(chunk.get("modalities"))
146
153
  # Normalize to simple 'image' type for template compatibility
@@ -720,11 +720,6 @@ class AiterIndicesUpdaterPrefill:
720
720
  self.req_to_token = model_runner.req_to_token_pool.req_to_token
721
721
  self.update = self.update_single_wrapper
722
722
 
723
- # get the last index of the pool
724
- self.pool_size = (
725
- model_runner.token_to_kv_pool.size + model_runner.token_to_kv_pool.page_size
726
- ) - 1
727
-
728
723
  self.kv_indices = None
729
724
  self.max_q_len = 0
730
725
  self.max_kv_len = 0
@@ -769,9 +764,8 @@ class AiterIndicesUpdaterPrefill:
769
764
  # but the 0 location will be made nan (noqa) in cuda graph capture mode
770
765
  # this will cause the output tensor value becomes nan
771
766
  # WA is to assure that last index of pool not changed
772
- kv_indices = torch.full(
773
- (paged_kernel_lens_sum + 128,),
774
- self.pool_size,
767
+ kv_indices = torch.empty(
768
+ paged_kernel_lens_sum + 256,
775
769
  dtype=torch.int32,
776
770
  device=req_pool_indices.device,
777
771
  )
@@ -785,6 +779,9 @@ class AiterIndicesUpdaterPrefill:
785
779
  self.req_to_token.shape[1],
786
780
  )
787
781
 
782
+ token_num = kv_indptr[-1]
783
+ kv_indices[token_num:] = kv_indices[0]
784
+
788
785
  self.max_kv_len = torch.max(paged_kernel_lens).item()
789
786
 
790
787
  extend_lens = seq_lens - prefix_lens