sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/function_call_parser.py
NEW
@@ -0,0 +1,494 @@

```python
import json
import re
from abc import ABC, abstractmethod
from json import JSONDecodeError, JSONDecoder
from typing import Any, Dict, List, Optional, Tuple

import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import BaseModel, Field

TOOLS_TAG_LIST = [
    "<|plugin|>",
    "<function=",
    "<tool_call>",
    "<|python_tag|>",
    "[TOOL_CALLS]",
]


class Function(BaseModel):
    """Function Tool Template."""

    description: Optional[str] = Field(default=None, examples=[None])
    name: Optional[str] = None
    parameters: Optional[object] = None


class ToolCallItem(BaseModel):
    """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""

    tool_index: int
    name: Optional[str] = None
    parameters: str  # JSON string


def _find_common_prefix(s1: str, s2: str) -> str:
    prefix = ""
    min_length = min(len(s1), len(s2))
    for i in range(0, min_length):
        if s1[i] == s2[i]:
            prefix += s1[i]
        else:
            break
    return prefix


def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    try:
        return (partial_json_parser.loads(input_str, flags), len(input_str))
    except JSONDecodeError as e:
        if "Extra data" in e.msg:
            dec = JSONDecoder()
            return dec.raw_decode(input_str)
        raise


def _is_complete_json(input_str: str) -> bool:
    try:
        json.loads(input_str)
        return True
    except JSONDecodeError:
        return False


class StreamingParseResult:
    """Result of streaming incremental parsing."""

    def __init__(
        self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
    ):
        self.normal_text = normal_text
        self.calls = calls or []


class BaseFormatDetector:
    """Base class providing two sets of interfaces: one-time and streaming incremental."""

    def __init__(self):
        # initialize properties used for state when parsing tool calls in
        self._buffer = ""
        # streaming mode
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = (
            []
        )  # map what has been streamed for each tool so far to a list
        self.bot_token = ""
        self.eot_token = ""

    def parse_base_json(self, action: Dict, tools: List[Function]):
        name, parameters = action["name"], json.dumps(
            action.get("parameters", action.get("arguments", {})),
            ensure_ascii=False,
        )
        tool_index = [tool.function.name for tool in tools].index(name)
        tool_call_item = ToolCallItem(
            tool_index=tool_index, name=name, parameters=parameters
        )
        calls = [tool_call_item]
        return calls

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        Parses the text in one go. Returns success=True if the format matches, otherwise False.
        Note that leftover_text here represents "content that this parser will not consume further".
        """
        action = json.loads(text)
        return self.parse_base_json(action, tools)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Function]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing, referencing the logic of Llama32Detector.
        We partially parse JSON within <tool_call>...</tool_call>, and handle
        incremental argument output.
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer
        if not (self.bot_token in current_text or current_text.startswith("{")):
            self._buffer = ""
            if self.eot_token in new_text:
                new_text = new_text.replace(self.eot_token, "")
            return StreamingParseResult(normal_text=new_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                # depending on the prompt format the Llama model may or may not
                # prefix the output with the <|python_tag|> token
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                while start_idx < len(current_text):
                    (obj, end_idx) = _partial_json_loads(
                        current_text[start_idx:], flags
                    )
                    is_complete.append(
                        _is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    start_idx += end_idx + len("; ")
                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if "parameters" in obj:
                        assert (
                            "arguments" not in obj
                        ), "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)

            except partial_json_parser.core.exceptions.MalformedJSON:
                # not enough tokens to parse into JSON yet
                return StreamingParseResult()

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return StreamingParseResult()

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (
                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
            ):

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        res = StreamingParseResult(
                            normal_text=None,
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff
                    else:
                        res = StreamingParseResult()
                else:
                    res = StreamingParseResult()
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                print("starting on new tool %d", self.current_tool_id)
                return res

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    res = StreamingParseResult(
                        normal_text=None,
                        calls=[
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        argument_diff = cur_args_json[sent:]
                        self._buffer = ""
                        self.prev_tool_call_arr[self.current_tool_id].clear()
                        self.current_tool_name_sent: bool = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""

                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:

                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        if not is_complete[self.current_tool_id]:
                            self.streamed_args_for_tool[
                                self.current_tool_id
                            ] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return res

        except Exception as e:
            print(e)
            # Skipping chunk as a result of tool streaming extraction error
            return StreamingParseResult()
```
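The streaming path above leans on `partial-json-parser`'s `Allow` bit mask: `Allow.ALL & ~Allow.STR` forbids dangling strings until the tool name has been emitted, after which `Allow.ALL` lets argument fragments through. The following standalone sketch (not part of the diff; the fragment is invented, and the exact recovered values can vary by `partial-json-parser` version) illustrates the difference:

```python
# Standalone illustration of the Allow flags used by
# BaseFormatDetector.parse_streaming_increment above.
import partial_json_parser
from partial_json_parser.core.options import Allow

fragment = '{"name": "get_weather", "arguments": {"location": "Bos'

# Name not yet streamed: masking out Allow.STR drops the half-finished
# string value `"Bos` instead of surfacing it early.
print(partial_json_parser.loads(fragment, Allow.ALL & ~Allow.STR))
# e.g. {'name': 'get_weather', 'arguments': {}}

# Name already streamed: Allow.ALL admits partial strings, which is what
# makes incremental argument diffs possible.
print(partial_json_parser.loads(fragment, Allow.ALL))
# e.g. {'name': 'get_weather', 'arguments': {'location': 'Bos'}}
```

The per-model detectors that specialize this base class follow.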
class Qwen25Detector(BaseFormatDetector):
|
288
|
+
"""
|
289
|
+
Detector for Qwen 2.5 models.
|
290
|
+
Assumes function call format:
|
291
|
+
<tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
|
292
|
+
"""
|
293
|
+
|
294
|
+
def __init__(self):
|
295
|
+
"""
|
296
|
+
Initializes the detector with necessary state variables.
|
297
|
+
"""
|
298
|
+
super().__init__()
|
299
|
+
self.bot_token = "<tool_call>"
|
300
|
+
self.eot_token = "</tool_call>"
|
301
|
+
|
302
|
+
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
303
|
+
"""
|
304
|
+
One-time parsing: Detects and parses tool calls in the provided text.
|
305
|
+
|
306
|
+
:param text: The complete text to parse.
|
307
|
+
:param tools: List of available tools.
|
308
|
+
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
309
|
+
"""
|
310
|
+
if "<tool_call>" not in text:
|
311
|
+
return []
|
312
|
+
pattern = r"<tool_call>(.*?)</tool_call>"
|
313
|
+
match_result_list = re.findall(pattern, text, re.DOTALL)
|
314
|
+
calls = []
|
315
|
+
for match_result in match_result_list:
|
316
|
+
match_result = json.loads(match_result)
|
317
|
+
calls.extend(self.parse_base_json(match_result, tools))
|
318
|
+
return calls
|
319
|
+
|
320
|
+
|
321
|
+
class MistralDetector(BaseFormatDetector):
|
322
|
+
"""
|
323
|
+
Detector for Mistral models.
|
324
|
+
Assumes function call format:
|
325
|
+
<|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|>
|
326
|
+
"""
|
327
|
+
|
328
|
+
def __init__(self):
|
329
|
+
"""
|
330
|
+
Initializes the detector with necessary state variables.
|
331
|
+
"""
|
332
|
+
super().__init__()
|
333
|
+
self.bot_token = "[TOOL_CALLS] ["
|
334
|
+
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
335
|
+
|
336
|
+
def _clean_text(self, text: str) -> str:
|
337
|
+
"""
|
338
|
+
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
|
339
|
+
for example,
|
340
|
+
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
|
341
|
+
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
|
342
|
+
The key pattern is [TOOL_CALLS] [...]
|
343
|
+
"""
|
344
|
+
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
|
345
|
+
if len(find_results) > 0:
|
346
|
+
return find_results[0]
|
347
|
+
else:
|
348
|
+
return ""
|
349
|
+
|
350
|
+
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
351
|
+
"""
|
352
|
+
One-time parsing: Detects and parses tool calls in the provided text.
|
353
|
+
|
354
|
+
:param text: The complete text to parse.
|
355
|
+
:param tools: List of available tools.
|
356
|
+
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
357
|
+
"""
|
358
|
+
text = self._clean_text(text)
|
359
|
+
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
360
|
+
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
361
|
+
calls = []
|
362
|
+
if len(raw_tool_calls) > 0:
|
363
|
+
raw_tool_call = raw_tool_calls[0]
|
364
|
+
function_call_arr = json.loads(raw_tool_call)
|
365
|
+
for match_result in function_call_arr:
|
366
|
+
calls.extend(self.parse_base_json(match_result, tools))
|
367
|
+
return calls
|
368
|
+
|
369
|
+
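As a quick check of the non-streaming path, here is a hypothetical call to `Qwen25Detector.detect_and_parse` (not from the diff). One wrinkle worth noting: `parse_base_json` indexes tools via `tool.function.name`, so despite the `List[Function]` type hints each entry must carry a `.function` attribute, as an OpenAI-style Tool object does; `SimpleNamespace` stands in for that wrapper here.

```python
from types import SimpleNamespace

# Hypothetical tool registry; only `.function.name` is consulted.
tools = [SimpleNamespace(function=Function(name="get_current_weather"))]

output = (
    'Sure, checking.<tool_call>{"name": "get_current_weather", '
    '"arguments": {"location": "Boston, MA"}}</tool_call>'
)

detector = Qwen25Detector()
for call in detector.detect_and_parse(output, tools):
    print(call.tool_index, call.name, call.parameters)
# 0 get_current_weather {"location": "Boston, MA"}
```

The remaining detector and the parser front-ends follow.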
```python
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.
    Assumes function call format:
      <|python_tag|>{"name":"xxx", "arguments":{...}}
    Does not require a closing tag "</python_tag|>",
    relies on json.loads(...) success to determine if JSON is complete.
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
        """

        if "<|python_tag|>" not in text:
            return []
        _, action = text.split("<|python_tag|>")
        action = json.loads(action)
        return self.parse_base_json(action, tools)


class MultiFormatParser:
    def __init__(self, detectors: List[BaseFormatDetector]):
        """
        :param detectors: A series of available Detector instances passed in
        """
        self.detectors = detectors

    def parse_once(self, text: str, tools: List[Function]):
        """
        One-time parsing: Loop through detectors until there are no new matches or text is exhausted
        Return: (final_text, all_calls)
        - final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text)
        - all_calls: All calls parsed by the Detectors
        """
        final_calls = []
        final_normal_text = text
        for detector in self.detectors:
            tool_call_list = detector.detect_and_parse(text, tools)
            if len(tool_call_list) > 0:  # parsed successfully
                final_calls = tool_call_list
                break

        # leftover_text is the normal text not consumed by any Detector
        return final_normal_text, final_calls

    def parse_streaming_increment(self, new_text: str, tools: List[Function]):
        """
        Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
        and merge their produced normal_text/calls to return.
        (The logic here can be "priority-based" or "parallel parsing" based on your needs)
        """
        final_normal_text = ""
        final_calls = []

        for detector in self.detectors:
            sp_result = detector.parse_streaming_increment(new_text, tools)
            # Merge normal_text and calls
            # If one sp_result contains result call, this should be a successful parse
            # If one sp_result only contains normal_text, this can either be a successful
            # parse or it is not using the desired parsing tool.
            if sp_result.normal_text:
                final_normal_text = sp_result.normal_text
            if sp_result.calls:
                final_calls.extend(sp_result.calls)
                final_normal_text = sp_result.normal_text
                break

        return final_normal_text, final_calls


class FunctionCallParser:
    """
    In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
    and returns the resulting normal_text and calls to the upper layer (or SSE).
    """

    ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
        "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
        "mistral": MistralDetector,
    }

    def __init__(self, tools: List[Function], tool_call_parser: str = None):
        detectors = []
        if tool_call_parser:
            detector_class = self.ToolCallParserEnum.get(tool_call_parser)
            if detector_class:
                detectors.append(detector_class())
            else:
                raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
        else:
            raise ValueError("Tool Call Parser Not Given!")

        self.multi_format_parser = MultiFormatParser(detectors)
        self.tools = tools

    def parse_non_stream(self, full_text: str):
        """
        Non-streaming call: one-time parsing
        """
        full_normal_text, calls = self.multi_format_parser.parse_once(
            full_text, self.tools
        )
        return full_normal_text, calls

    def parse_stream_chunk(self, chunk_text: str):
        """
        Streaming call: incremental parsing
        """
        normal_text, calls = self.multi_format_parser.parse_streaming_increment(
            chunk_text, self.tools
        )
        return normal_text, calls
```
sglang/srt/layers/activation.py
CHANGED
```diff
@@ -20,18 +20,18 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from sglang.srt.utils import …
+from sglang.srt.utils import is_cuda_available
 
-if …
-    from …
+if is_cuda_available():
+    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from vllm.…
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.custom_op import CustomOp
-
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
@@ -149,8 +149,8 @@ def get_act_fn(
     return act_fn
 
 
-if not …
+if not is_cuda_available():
     logger.info(
-        "…
+        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
    )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
```
(Removed lines marked `…` were truncated in the extracted diff and are not reconstructed here.)
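The hunks above gate the fused activation kernels behind the `is_cuda_available()` check, importing them from `sgl_kernel` on CUDA and falling back to the vLLM ops elsewhere. A minimal sketch of the same gating pattern (the fallback body below is illustrative, not the diff's code):

```python
import torch
import torch.nn.functional as F

from sglang.srt.utils import is_cuda_available

if is_cuda_available():
    # Fused CUDA kernel, as imported in the hunk above.
    from sgl_kernel import silu_and_mul
else:
    # Portable stand-in with the usual [..., 2 * d] -> [..., d] contract:
    # SiLU on the first half, elementwise product with the second half.
    def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up
```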
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
```diff
@@ -18,6 +18,7 @@ import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
@@ -62,9 +63,9 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
-            // …
+            // get_attention_tp_size(),
             num_kv_heads=model_runner.model_config.get_num_kv_heads(
-                …
+                get_attention_tp_size()
             ),
         )
         self.max_context_len = model_runner.model_config.context_len
@@ -147,7 +148,7 @@ class FlashInferAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.…
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -238,7 +239,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.…
+        if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
                 decode_wrappers.append(
@@ -307,7 +308,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.…
+        if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -453,10 +454,10 @@ class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
         self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // …
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            …
+            get_attention_tp_size()
         )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
@@ -625,10 +626,10 @@ class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
         self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // …
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            …
+            get_attention_tp_size()
         )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
```
(Removed lines marked `…` were truncated in the extracted diff and are not reconstructed here.)
sglang/srt/layers/attention/triton_backend.py
CHANGED
```diff
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -28,12 +29,9 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
-…
-…
-…
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
 
         self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
```
(Removed lines marked `…` were truncated in the extracted diff and are not reconstructed here.)
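Both attention backends now derive per-rank head counts from `get_attention_tp_size()` (exported by the new `sglang/srt/layers/dp_attention.py`) rather than from the raw tensor-parallel world size. The point of the indirection: under data-parallel attention the attention TP group is a slice of the full TP group, so each rank keeps more heads. Illustrative arithmetic only, under the assumption that the attention TP size works out to `tp_size // dp_size` (all values invented):

```python
num_attention_heads = 32
tp_size, dp_size = 8, 4

# Without DP attention, attention_tp_size == tp_size and nothing changes.
attention_tp_size = tp_size // dp_size  # assumed get_attention_tp_size() result: 2

heads_per_rank_old = num_attention_heads // tp_size            # 4
heads_per_rank_new = num_attention_heads // attention_tp_size  # 16
```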
|