sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +124 -12
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +173 -2
- sglang/srt/layers/rotary_embedding.py +1 -3
- sglang/srt/layers/sampler.py +4 -4
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/llama.py +8 -3
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/server_args.py +13 -2
- sglang/srt/speculative/build_eagle_tree.py +486 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +420 -401
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -0
- sglang/test/runners.py +2 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/outlines_backend.py
CHANGED
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple, Union
 import interegular
 import torch
 from outlines.fsm.guide import RegexGuide
-from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel

@@ -29,6 +28,15 @@ from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarObject,
 )
 from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
+
+if is_hip_:
+    from outlines_core.fsm.json_schema import build_regex_from_schema
+else:
+    from outlines.fsm.json_schema import build_regex_from_schema
+

 logger = logging.getLogger(__name__)

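Both import paths expose the same schema-to-regex entry point; the change only selects the outlines_core implementation on ROCm (HIP) builds. As a rough sketch of what the backend does with it (illustrative only, not taken from the diff; exact outlines APIs vary by version):

import json
from outlines.fsm.json_schema import build_regex_from_schema

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
# The helper takes the schema as a JSON string and returns a regular
# expression matching exactly the strings that conform to the schema;
# the backend then compiles it into a RegexGuide for constrained decoding.
regex = build_regex_from_schema(json.dumps(schema))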
sglang/srt/custom_op.py
ADDED
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+
+class CustomOp(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_xpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_hpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        if _is_cuda:
+            return self.forward_cuda
+        elif _is_rocm:
+            return self.forward_hip
+        else:
+            return self.forward_native
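CustomOp replaces the previous dependency on vLLM's custom-op registry (note the deletion of sglang/srt/layers/custom_op_util.py in the file list above): each op resolves its fastest available implementation once, in __init__, instead of branching per call. A minimal usage sketch (the Relu2 subclass is a hypothetical example, not part of the package):

import torch
from sglang.srt.custom_op import CustomOp

class Relu2(CustomOp):  # hypothetical example op
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) ** 2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused kernel here; the sketch falls back.
        return self.forward_native(x)

op = Relu2()
y = op(torch.randn(4))  # routes to forward_cuda, forward_hip, or forward_native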
sglang/srt/entrypoints/engine.py
CHANGED
@@ -316,8 +316,8 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Check flashinfer version
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
-            "
-            "0.
+            "flashinfer_python",
+            "0.2.0.post2",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
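The dependency check now targets the flashinfer_python distribution and requires at least 0.2.0.post2. assert_pkg_version comes from sglang.srt.utils; a sketch of its semantics only (an approximation, not the actual implementation):

from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

def assert_pkg_version_sketch(pkg: str, min_version: str, message: str) -> None:
    # Fail fast when the package is missing or older than required.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {message}")
    if Version(installed) < Version(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is too old (need >= {min_version}). {message}"
        )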
sglang/srt/function_call_parser.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import logging
 import re
 from abc import ABC, abstractmethod
 from json import JSONDecodeError, JSONDecoder
@@ -8,6 +9,8 @@ import partial_json_parser
 from partial_json_parser.core.options import Allow
 from pydantic import BaseModel, Field

+logger = logging.getLogger(__name__)
+
 TOOLS_TAG_LIST = [
     "<|plugin|>",
     "<function=",
@@ -88,17 +91,43 @@ class BaseFormatDetector:
         self.bot_token = ""
         self.eot_token = ""

-    def parse_base_json(self, action:
-
-
-
-    )
-
-
-
-
-
-
+    def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
+        tool_indices = {
+            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
+        }
+        if not isinstance(action, list):
+            name = action.get("name")
+            if not name or name not in tool_indices:
+                logger.warning(f"Model attempted to call undefined function: {name}")
+                return []
+
+            return [
+                ToolCallItem(
+                    tool_index=tool_indices[name],
+                    name=name,
+                    parameters=json.dumps(
+                        action.get("parameters") or action.get("arguments", {}),
+                        ensure_ascii=False,
+                    ),
+                )
+            ]
+
+        results = []
+        for act in action:
+            name = act.get("name")
+            if name and name in tool_indices:
+                results.append(
+                    ToolCallItem(
+                        tool_index=tool_indices[name],
+                        name=name,
+                        parameters=json.dumps(
+                            act.get("parameters") or act.get("arguments", {}),
+                            ensure_ascii=False,
+                        ),
+                    )
+                )
+
+        return results

     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
         """
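The rewritten parse_base_json validates every proposed call against the declared tools and maps it back to the tool's index. Illustrative behavior, assuming a tools list whose first entry declares a get_weather function (Function and ToolCallItem are the pydantic models defined in this module):

detector.parse_base_json({"name": "get_weather", "arguments": {"city": "Paris"}}, tools)
# -> [ToolCallItem(tool_index=0, name="get_weather", parameters='{"city": "Paris"}')]

detector.parse_base_json({"name": "not_a_tool", "arguments": {}}, tools)
# -> [] after logging "Model attempted to call undefined function: not_a_tool"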
@@ -112,9 +141,7 @@ class BaseFormatDetector:
         self, new_text: str, tools: List[Function]
     ) -> StreamingParseResult:
         """
-        Streaming incremental parsing
-        We partially parse JSON within <tool_call>...</tool_call>, and handle
-        incremental argument output.
+        Streaming incremental parsing with tool validation.
         """
         # Append new text to buffer
         self._buffer += new_text
@@ -125,17 +152,19 @@ class BaseFormatDetector:
             new_text = new_text.replace(self.eot_token, "")
             return StreamingParseResult(normal_text=new_text)

-        #
-
-
-
+        # Build tool indices if not already built
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
         flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
         try:
             tool_call_arr = []
             is_complete = []
             try:
-                # depending on the prompt format the Llama model may or may not
-                # prefix the output with the <|python_tag|> token
                 start_idx = (
                     len(self.bot_token)
                     if current_text.startswith(self.bot_token)
@@ -149,8 +178,18 @@
                     _is_complete_json(current_text[start_idx : start_idx + end_idx])
                 )
                 start_idx += end_idx + len("; ")
-
-                #
+
+                # Validate tool name if present
+                if "name" in obj and obj["name"] not in self._tool_indices:
+                    # Invalid tool name - reset state
+                    self._buffer = ""
+                    self.current_tool_id = -1
+                    self.current_tool_name_sent = False
+                    if self.streamed_args_for_tool:
+                        self.streamed_args_for_tool.pop()
+                    return StreamingParseResult()
+
+                # Handle parameters/arguments consistency
                 if "parameters" in obj:
                     assert (
                         "arguments" not in obj
@@ -159,29 +198,17 @@
                 tool_call_arr.append(obj)

             except partial_json_parser.core.exceptions.MalformedJSON:
-                # not enough tokens to parse into JSON yet
                 return StreamingParseResult()

-            # select as the current tool call the one we're on the state at
-            current_tool_call: Dict = (
-                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
-            )
-
-            # case -- if no tokens have been streamed for the tool, e.g.
-            # only the array brackets, stream nothing
             if len(tool_call_arr) == 0:
                 return StreamingParseResult()

-
-
-
-                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
-            ):
+            current_tool_call: Dict = (
+                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+            )

-
-
-            # auto-generated due to JSON completions, but wasn't
-            # streamed to the client yet.
+            # Handle new tool in array
+            if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
                 if self.current_tool_id >= 0:
                     cur_arguments = current_tool_call.get("arguments")
                     if cur_arguments:
@@ -190,7 +217,6 @@
                         argument_diff = cur_args_json[sent:]

                         res = StreamingParseResult(
-                            normal_text=None,
                             calls=[
                                 ToolCallItem(
                                     tool_index=self.current_tool_id,
@@ -206,23 +232,20 @@
                         res = StreamingParseResult()
                 else:
                     res = StreamingParseResult()
-
+
                 self.current_tool_id = len(tool_call_arr) - 1
                 self.current_tool_name_sent = False
                 self.streamed_args_for_tool.append("")
-                print("starting on new tool %d", self.current_tool_id)
                 return res

-            #
-            # - otherwise send nothing
+            # Handle tool name
             elif not self.current_tool_name_sent:
                 function_name = current_tool_call.get("name")
-                if function_name:
+                if function_name and function_name in self._tool_indices:
                     res = StreamingParseResult(
-                        normal_text=None,
                         calls=[
                             ToolCallItem(
-                                tool_index=self.
+                                tool_index=self._tool_indices[function_name],
                                 name=function_name,
                                 parameters="",
                             )
@@ -232,8 +255,7 @@
                 else:
                     res = StreamingParseResult()

-            #
-            # arguments
+            # Handle streaming arguments
             else:
                 cur_arguments = current_tool_call.get("arguments")
                 res = StreamingParseResult()
@@ -250,13 +272,12 @@
                         argument_diff = cur_args_json[sent:]
                         self._buffer = ""
                         self.prev_tool_call_arr[self.current_tool_id].clear()
-                        self.current_tool_name_sent
+                        self.current_tool_name_sent = False
                         self.streamed_args_for_tool[self.current_tool_id] = ""

                     elif prev_arguments:
                         prev_args_json = json.dumps(prev_arguments)
                         if cur_args_json != prev_args_json:
-
                             prefix = _find_common_prefix(prev_args_json, cur_args_json)
                             argument_diff = prefix[sent:]

@@ -279,8 +300,7 @@
             return res

         except Exception as e:
-
-            # Skipping chunk as a result of tool streaming extraction error
+            logger.error(f"Error in parse_streaming_increment: {e}")
             return StreamingParseResult()


@@ -372,31 +392,38 @@ class Llama32Detector(BaseFormatDetector):
     Detector for Llama 3.2 models.
     Assumes function call format:
     <|python_tag|>{"name":"xxx", "arguments":{...}}
-    Does not require a closing tag "</python_tag|>",
-    relies on json.loads(...) success to determine if JSON is complete.
     """

     def __init__(self):
-        """
-        Initializes the detector with necessary state variables.
-        """
         super().__init__()
         self.bot_token = "<|python_tag|>"

     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
-        """
-        One-time parsing: Detects and parses tool calls in the provided text.
-
-        :param text: The complete text to parse.
-        :param tools: List of available tools.
-        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
-        """
-
+        """Parse function calls from text, handling multiple JSON objects."""
         if "<|python_tag|>" not in text:
            return []
-
-
-
+
+        _, action_text = text.split("<|python_tag|>")
+
+        # Split by semicolon and process each part
+        json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
+
+        all_actions = []
+        for part in json_parts:
+            try:
+                # Parse each individual JSON object
+                action = json.loads(part)
+                all_actions.append(action)
+            except json.JSONDecodeError as e:
+                logger.warning(f"Failed to parse JSON part: {part}")
+                logger.warning(f"JSON parse error: {str(e)}")
+                continue
+
+        # Only process if we found valid JSON objects
+        if all_actions:
+            return self.parse_base_json(all_actions, tools)
+
+        return []


 class MultiFormatParser:
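The new Llama32Detector.detect_and_parse handles several calls in one generation by splitting the text after the tag on semicolons. For example (hypothetical model output):

text = '<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}}; {"name": "get_time", "arguments": {}}'
# Each semicolon-separated chunk is json.loads'ed on its own; malformed
# chunks are logged and skipped, and the surviving dicts go through
# parse_base_json, which drops any name not declared in the tools list.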
sglang/srt/layers/activation.py
CHANGED
@@ -25,21 +25,18 @@ from sglang.srt.utils import is_cuda_available
 if is_cuda_available():
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

-from
-
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs

 logger = logging.getLogger(__name__)


-@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -53,7 +50,6 @@ class SiluAndMul(CustomOp):
         return out


-@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
@@ -76,6 +72,15 @@ class GeluAndMul(CustomOp):
         return out


+class QuickGELU(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.sigmoid(1.702 * x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
+        return self.forward_native(x)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.

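QuickGELU is the sigmoid approximation x * sigmoid(1.702 * x) of GELU, the variant used by CLIP-style vision encoders; until the sgl-kernel CUDA op lands, forward_cuda simply reuses the native path. A quick numerical check, independent of sglang:

import torch

x = torch.linspace(-3, 3, 7)
quick = x * torch.sigmoid(1.702 * x)      # QuickGELU.forward_native
exact = torch.nn.functional.gelu(x)       # exact erf-based GELU
print((quick - exact).abs().max())        # on the order of 2e-2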
sglang/srt/layers/attention/double_sparsity_backend.py
CHANGED
@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+            extend_attention_fwd,
             flash_decode_attention_fwd,
             flash_decode_sparse_attention_fwd,
         )
-        from sglang.srt.layers.attention.triton_ops.extend_attention import (
-            extend_attention_fwd,
-        )

         super().__init__()

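The import shuffle keeps all double-sparsity Triton kernels in one module, and the surrounding pattern is worth noting: the kernels are imported inside __init__ so that merely importing the backend module never initializes a CUDA context. A generic sketch of the idiom (module and kernel names hypothetical):

class SomeAttnBackend:
    def __init__(self):
        # Deferred import: Triton/CUDA state is only touched when a backend
        # instance is actually constructed, not when this file is imported
        # (which may happen in a parent process before fork/spawn).
        from my_pkg.triton_kernels import fused_attention_fwd  # hypothetical
        self._fwd = fused_attention_fwd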