sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/triton_backend.py +64 -16
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +43 -10
- sglang/srt/models/llama.py +8 -3
- sglang/srt/speculative/build_eagle_tree.py +482 -102
- sglang/srt/speculative/eagle_utils.py +80 -50
- sglang/version.py +1 -1
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +2 -2
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +16 -16
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
|
|
1
1
|
import json
|
2
|
+
import logging
|
2
3
|
import re
|
3
4
|
from abc import ABC, abstractmethod
|
4
5
|
from json import JSONDecodeError, JSONDecoder
|
@@ -8,6 +9,8 @@ import partial_json_parser
|
|
8
9
|
from partial_json_parser.core.options import Allow
|
9
10
|
from pydantic import BaseModel, Field
|
10
11
|
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
11
14
|
TOOLS_TAG_LIST = [
|
12
15
|
"<|plugin|>",
|
13
16
|
"<function=",
|
@@ -88,17 +91,43 @@ class BaseFormatDetector:
|
|
88
91
|
self.bot_token = ""
|
89
92
|
self.eot_token = ""
|
90
93
|
|
91
|
-
def parse_base_json(self, action:
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
)
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
94
|
+
def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
|
95
|
+
tool_indices = {
|
96
|
+
tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
|
97
|
+
}
|
98
|
+
if not isinstance(action, list):
|
99
|
+
name = action.get("name")
|
100
|
+
if not name or name not in tool_indices:
|
101
|
+
logger.warning(f"Model attempted to call undefined function: {name}")
|
102
|
+
return []
|
103
|
+
|
104
|
+
return [
|
105
|
+
ToolCallItem(
|
106
|
+
tool_index=tool_indices[name],
|
107
|
+
name=name,
|
108
|
+
parameters=json.dumps(
|
109
|
+
action.get("parameters") or action.get("arguments", {}),
|
110
|
+
ensure_ascii=False,
|
111
|
+
),
|
112
|
+
)
|
113
|
+
]
|
114
|
+
|
115
|
+
results = []
|
116
|
+
for act in action:
|
117
|
+
name = act.get("name")
|
118
|
+
if name and name in tool_indices:
|
119
|
+
results.append(
|
120
|
+
ToolCallItem(
|
121
|
+
tool_index=tool_indices[name],
|
122
|
+
name=name,
|
123
|
+
parameters=json.dumps(
|
124
|
+
act.get("parameters") or act.get("arguments", {}),
|
125
|
+
ensure_ascii=False,
|
126
|
+
),
|
127
|
+
)
|
128
|
+
)
|
129
|
+
|
130
|
+
return results
|
102
131
|
|
103
132
|
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
104
133
|
"""
|
@@ -112,9 +141,7 @@ class BaseFormatDetector:
|
|
112
141
|
self, new_text: str, tools: List[Function]
|
113
142
|
) -> StreamingParseResult:
|
114
143
|
"""
|
115
|
-
Streaming incremental parsing
|
116
|
-
We partially parse JSON within <tool_call>...</tool_call>, and handle
|
117
|
-
incremental argument output.
|
144
|
+
Streaming incremental parsing with tool validation.
|
118
145
|
"""
|
119
146
|
# Append new text to buffer
|
120
147
|
self._buffer += new_text
|
@@ -125,17 +152,19 @@ class BaseFormatDetector:
|
|
125
152
|
new_text = new_text.replace(self.eot_token, "")
|
126
153
|
return StreamingParseResult(normal_text=new_text)
|
127
154
|
|
128
|
-
#
|
129
|
-
|
130
|
-
|
131
|
-
|
155
|
+
# Build tool indices if not already built
|
156
|
+
if not hasattr(self, "_tool_indices"):
|
157
|
+
self._tool_indices = {
|
158
|
+
tool.function.name: i
|
159
|
+
for i, tool in enumerate(tools)
|
160
|
+
if tool.function and tool.function.name
|
161
|
+
}
|
162
|
+
|
132
163
|
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
133
164
|
try:
|
134
165
|
tool_call_arr = []
|
135
166
|
is_complete = []
|
136
167
|
try:
|
137
|
-
# depending on the prompt format the Llama model may or may not
|
138
|
-
# prefix the output with the <|python_tag|> token
|
139
168
|
start_idx = (
|
140
169
|
len(self.bot_token)
|
141
170
|
if current_text.startswith(self.bot_token)
|
@@ -149,8 +178,18 @@ class BaseFormatDetector:
|
|
149
178
|
_is_complete_json(current_text[start_idx : start_idx + end_idx])
|
150
179
|
)
|
151
180
|
start_idx += end_idx + len("; ")
|
152
|
-
|
153
|
-
#
|
181
|
+
|
182
|
+
# Validate tool name if present
|
183
|
+
if "name" in obj and obj["name"] not in self._tool_indices:
|
184
|
+
# Invalid tool name - reset state
|
185
|
+
self._buffer = ""
|
186
|
+
self.current_tool_id = -1
|
187
|
+
self.current_tool_name_sent = False
|
188
|
+
if self.streamed_args_for_tool:
|
189
|
+
self.streamed_args_for_tool.pop()
|
190
|
+
return StreamingParseResult()
|
191
|
+
|
192
|
+
# Handle parameters/arguments consistency
|
154
193
|
if "parameters" in obj:
|
155
194
|
assert (
|
156
195
|
"arguments" not in obj
|
@@ -159,29 +198,17 @@ class BaseFormatDetector:
|
|
159
198
|
tool_call_arr.append(obj)
|
160
199
|
|
161
200
|
except partial_json_parser.core.exceptions.MalformedJSON:
|
162
|
-
# not enough tokens to parse into JSON yet
|
163
201
|
return StreamingParseResult()
|
164
202
|
|
165
|
-
# select as the current tool call the one we're on the state at
|
166
|
-
current_tool_call: Dict = (
|
167
|
-
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
|
168
|
-
)
|
169
|
-
|
170
|
-
# case -- if no tokens have been streamed for the tool, e.g.
|
171
|
-
# only the array brackets, stream nothing
|
172
203
|
if len(tool_call_arr) == 0:
|
173
204
|
return StreamingParseResult()
|
174
205
|
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
|
179
|
-
):
|
206
|
+
current_tool_call: Dict = (
|
207
|
+
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
|
208
|
+
)
|
180
209
|
|
181
|
-
|
182
|
-
|
183
|
-
# auto-generated due to JSON completions, but wasn't
|
184
|
-
# streamed to the client yet.
|
210
|
+
# Handle new tool in array
|
211
|
+
if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
|
185
212
|
if self.current_tool_id >= 0:
|
186
213
|
cur_arguments = current_tool_call.get("arguments")
|
187
214
|
if cur_arguments:
|
@@ -190,7 +217,6 @@ class BaseFormatDetector:
|
|
190
217
|
argument_diff = cur_args_json[sent:]
|
191
218
|
|
192
219
|
res = StreamingParseResult(
|
193
|
-
normal_text=None,
|
194
220
|
calls=[
|
195
221
|
ToolCallItem(
|
196
222
|
tool_index=self.current_tool_id,
|
@@ -206,23 +232,20 @@ class BaseFormatDetector:
|
|
206
232
|
res = StreamingParseResult()
|
207
233
|
else:
|
208
234
|
res = StreamingParseResult()
|
209
|
-
|
235
|
+
|
210
236
|
self.current_tool_id = len(tool_call_arr) - 1
|
211
237
|
self.current_tool_name_sent = False
|
212
238
|
self.streamed_args_for_tool.append("")
|
213
|
-
print("starting on new tool %d", self.current_tool_id)
|
214
239
|
return res
|
215
240
|
|
216
|
-
#
|
217
|
-
# - otherwise send nothing
|
241
|
+
# Handle tool name
|
218
242
|
elif not self.current_tool_name_sent:
|
219
243
|
function_name = current_tool_call.get("name")
|
220
|
-
if function_name:
|
244
|
+
if function_name and function_name in self._tool_indices:
|
221
245
|
res = StreamingParseResult(
|
222
|
-
normal_text=None,
|
223
246
|
calls=[
|
224
247
|
ToolCallItem(
|
225
|
-
tool_index=self.
|
248
|
+
tool_index=self._tool_indices[function_name],
|
226
249
|
name=function_name,
|
227
250
|
parameters="",
|
228
251
|
)
|
@@ -232,8 +255,7 @@ class BaseFormatDetector:
|
|
232
255
|
else:
|
233
256
|
res = StreamingParseResult()
|
234
257
|
|
235
|
-
#
|
236
|
-
# arguments
|
258
|
+
# Handle streaming arguments
|
237
259
|
else:
|
238
260
|
cur_arguments = current_tool_call.get("arguments")
|
239
261
|
res = StreamingParseResult()
|
@@ -250,13 +272,12 @@ class BaseFormatDetector:
|
|
250
272
|
argument_diff = cur_args_json[sent:]
|
251
273
|
self._buffer = ""
|
252
274
|
self.prev_tool_call_arr[self.current_tool_id].clear()
|
253
|
-
self.current_tool_name_sent
|
275
|
+
self.current_tool_name_sent = False
|
254
276
|
self.streamed_args_for_tool[self.current_tool_id] = ""
|
255
277
|
|
256
278
|
elif prev_arguments:
|
257
279
|
prev_args_json = json.dumps(prev_arguments)
|
258
280
|
if cur_args_json != prev_args_json:
|
259
|
-
|
260
281
|
prefix = _find_common_prefix(prev_args_json, cur_args_json)
|
261
282
|
argument_diff = prefix[sent:]
|
262
283
|
|
@@ -279,8 +300,7 @@ class BaseFormatDetector:
|
|
279
300
|
return res
|
280
301
|
|
281
302
|
except Exception as e:
|
282
|
-
|
283
|
-
# Skipping chunk as a result of tool streaming extraction error
|
303
|
+
logger.error(f"Error in parse_streaming_increment: {e}")
|
284
304
|
return StreamingParseResult()
|
285
305
|
|
286
306
|
|
@@ -372,31 +392,38 @@ class Llama32Detector(BaseFormatDetector):
|
|
372
392
|
Detector for Llama 3.2 models.
|
373
393
|
Assumes function call format:
|
374
394
|
<|python_tag|>{"name":"xxx", "arguments":{...}}
|
375
|
-
Does not require a closing tag "</python_tag|>",
|
376
|
-
relies on json.loads(...) success to determine if JSON is complete.
|
377
395
|
"""
|
378
396
|
|
379
397
|
def __init__(self):
|
380
|
-
"""
|
381
|
-
Initializes the detector with necessary state variables.
|
382
|
-
"""
|
383
398
|
super().__init__()
|
384
399
|
self.bot_token = "<|python_tag|>"
|
385
400
|
|
386
401
|
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
387
|
-
"""
|
388
|
-
One-time parsing: Detects and parses tool calls in the provided text.
|
389
|
-
|
390
|
-
:param text: The complete text to parse.
|
391
|
-
:param tools: List of available tools.
|
392
|
-
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
393
|
-
"""
|
394
|
-
|
402
|
+
"""Parse function calls from text, handling multiple JSON objects."""
|
395
403
|
if "<|python_tag|>" not in text:
|
396
404
|
return []
|
397
|
-
|
398
|
-
|
399
|
-
|
405
|
+
|
406
|
+
_, action_text = text.split("<|python_tag|>")
|
407
|
+
|
408
|
+
# Split by semicolon and process each part
|
409
|
+
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
410
|
+
|
411
|
+
all_actions = []
|
412
|
+
for part in json_parts:
|
413
|
+
try:
|
414
|
+
# Parse each individual JSON object
|
415
|
+
action = json.loads(part)
|
416
|
+
all_actions.append(action)
|
417
|
+
except json.JSONDecodeError as e:
|
418
|
+
logger.warning(f"Failed to parse JSON part: {part}")
|
419
|
+
logger.warning(f"JSON parse error: {str(e)}")
|
420
|
+
continue
|
421
|
+
|
422
|
+
# Only process if we found valid JSON objects
|
423
|
+
if all_actions:
|
424
|
+
return self.parse_base_json(all_actions, tools)
|
425
|
+
|
426
|
+
return []
|
400
427
|
|
401
428
|
|
402
429
|
class MultiFormatParser:
|
@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|
17
17
|
def __init__(self, model_runner: ModelRunner):
|
18
18
|
# Lazy import to avoid the initialization of cuda context
|
19
19
|
from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
|
20
|
+
extend_attention_fwd,
|
20
21
|
flash_decode_attention_fwd,
|
21
22
|
flash_decode_sparse_attention_fwd,
|
22
23
|
)
|
23
|
-
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
24
|
-
extend_attention_fwd,
|
25
|
-
)
|
26
24
|
|
27
25
|
super().__init__()
|
28
26
|
|
@@ -37,6 +37,9 @@ class TritonAttnBackend(AttentionBackend):
|
|
37
37
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
38
38
|
)
|
39
39
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
40
|
+
self.qo_indptr = torch.zeros(
|
41
|
+
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
42
|
+
)
|
40
43
|
|
41
44
|
self.num_head = (
|
42
45
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
@@ -54,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
|
|
54
57
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
55
58
|
"""Init auxiliary variables for triton attention backend."""
|
56
59
|
|
60
|
+
bs = forward_batch.batch_size
|
61
|
+
kv_indptr = self.kv_indptr
|
62
|
+
|
57
63
|
if forward_batch.forward_mode.is_decode():
|
58
64
|
attn_logits = torch.empty(
|
59
65
|
(
|
@@ -68,31 +74,62 @@ class TritonAttnBackend(AttentionBackend):
|
|
68
74
|
|
69
75
|
max_extend_len = None
|
70
76
|
|
71
|
-
kv_indptr = self.kv_indptr
|
72
|
-
bs = len(forward_batch.req_pool_indices)
|
73
77
|
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
74
78
|
kv_indptr = kv_indptr[: bs + 1]
|
75
79
|
kv_indices = torch.empty(
|
76
|
-
forward_batch.seq_lens_sum, dtype=torch.int32, device=
|
80
|
+
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
|
77
81
|
)
|
78
82
|
create_flashinfer_kv_indices_triton[(bs,)](
|
79
|
-
|
83
|
+
self.req_to_token,
|
80
84
|
forward_batch.req_pool_indices,
|
81
85
|
forward_batch.seq_lens,
|
82
86
|
kv_indptr,
|
83
87
|
None,
|
84
88
|
kv_indices,
|
85
|
-
|
89
|
+
self.req_to_token.stride(0),
|
86
90
|
)
|
87
91
|
|
92
|
+
qo_indptr = None
|
93
|
+
custom_mask = None
|
94
|
+
mask_offsets = None
|
88
95
|
else:
|
96
|
+
kv_indptr[1 : bs + 1] = torch.cumsum(
|
97
|
+
forward_batch.extend_prefix_lens, dim=0
|
98
|
+
)
|
99
|
+
kv_indptr = kv_indptr[: bs + 1]
|
100
|
+
kv_indices = torch.empty(
|
101
|
+
forward_batch.extend_prefix_lens.sum().item(),
|
102
|
+
dtype=torch.int32,
|
103
|
+
device=self.device,
|
104
|
+
)
|
105
|
+
create_flashinfer_kv_indices_triton[(bs,)](
|
106
|
+
self.req_to_token,
|
107
|
+
forward_batch.req_pool_indices,
|
108
|
+
forward_batch.extend_prefix_lens,
|
109
|
+
kv_indptr,
|
110
|
+
None,
|
111
|
+
kv_indices,
|
112
|
+
self.req_to_token.stride(0),
|
113
|
+
)
|
114
|
+
|
115
|
+
qo_indptr = self.qo_indptr
|
116
|
+
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
|
117
|
+
qo_indptr = qo_indptr[: bs + 1]
|
118
|
+
custom_mask = None
|
119
|
+
mask_offsets = None
|
120
|
+
|
89
121
|
attn_logits = None
|
90
122
|
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
91
123
|
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
124
|
+
self.forward_metadata = (
|
125
|
+
attn_logits,
|
126
|
+
max_extend_len,
|
127
|
+
kv_indptr,
|
128
|
+
kv_indices,
|
129
|
+
qo_indptr,
|
130
|
+
custom_mask,
|
131
|
+
mask_offsets,
|
132
|
+
)
|
96
133
|
|
97
134
|
def init_cuda_graph_state(self, max_bs: int):
|
98
135
|
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
@@ -144,6 +181,9 @@ class TritonAttnBackend(AttentionBackend):
|
|
144
181
|
None,
|
145
182
|
kv_indptr,
|
146
183
|
kv_indices,
|
184
|
+
None,
|
185
|
+
None,
|
186
|
+
None,
|
147
187
|
)
|
148
188
|
|
149
189
|
def init_forward_metadata_replay_cuda_graph(
|
@@ -197,7 +237,15 @@ class TritonAttnBackend(AttentionBackend):
|
|
197
237
|
layer, forward_batch.out_cache_loc, k, v
|
198
238
|
)
|
199
239
|
|
200
|
-
|
240
|
+
(
|
241
|
+
_,
|
242
|
+
max_extend_len,
|
243
|
+
kv_indptr,
|
244
|
+
kv_indices,
|
245
|
+
qo_indptr,
|
246
|
+
custom_mask,
|
247
|
+
mask_offsets,
|
248
|
+
) = self.forward_metadata
|
201
249
|
self.extend_attention_fwd(
|
202
250
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
203
251
|
k.contiguous(),
|
@@ -205,11 +253,11 @@ class TritonAttnBackend(AttentionBackend):
|
|
205
253
|
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
206
254
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
207
255
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
256
|
+
qo_indptr,
|
257
|
+
kv_indptr,
|
258
|
+
kv_indices,
|
259
|
+
custom_mask,
|
260
|
+
mask_offsets,
|
213
261
|
max_extend_len,
|
214
262
|
layer.scaling,
|
215
263
|
layer.logit_cap,
|
@@ -235,7 +283,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
235
283
|
else:
|
236
284
|
o = torch.empty_like(q)
|
237
285
|
|
238
|
-
attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
|
286
|
+
attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
|
239
287
|
|
240
288
|
if save_kv_cache:
|
241
289
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|