sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff shows the changes between two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/__init__.py +57 -2
- sglang/api.py +8 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +83 -2
- sglang/lang/interpreter.py +92 -35
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +6 -4
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +10 -2
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +27 -3
- sglang/srt/managers/router/infer_batch.py +97 -48
- sglang/srt/managers/router/manager.py +11 -8
- sglang/srt/managers/router/model_rpc.py +169 -90
- sglang/srt/managers/router/model_runner.py +110 -166
- sglang/srt/managers/router/radix_cache.py +89 -51
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +110 -33
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +11 -0
- sglang/srt/models/commandr.py +372 -0
- sglang/srt/models/dbrx.py +412 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +24 -25
- sglang/srt/models/llama2.py +25 -26
- sglang/srt/models/llava.py +8 -10
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +29 -33
- sglang/srt/models/qwen.py +34 -25
- sglang/srt/models/qwen2.py +25 -26
- sglang/srt/models/stablelm.py +26 -26
- sglang/srt/models/yivl.py +3 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +91 -456
- sglang/srt/server_args.py +79 -49
- sglang/srt/utils.py +212 -47
- sglang/srt/weight_utils.py +417 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- sglang/utils.py +77 -26
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -1,4 +1,59 @@
-__version__ = "0.1.14"
+__version__ = "0.1.16"
 
-
+# SGL API Components
+from sglang.api import (
+    Runtime,
+    assistant,
+    assistant_begin,
+    assistant_end,
+    flush_cache,
+    function,
+    gen,
+    gen_int,
+    gen_string,
+    get_server_args,
+    image,
+    select,
+    set_default_backend,
+    system,
+    user,
+    user_begin,
+    user_end,
+    video,
+)
+
+# SGL Backends
+from sglang.backend.anthropic import Anthropic
+from sglang.backend.openai import OpenAI
+from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.backend.vertexai import VertexAI
+
+# Global Configurations
 from sglang.global_config import global_config
+
+# public APIs management
+__all__ = [
+    "global_config",
+    "Anthropic",
+    "OpenAI",
+    "RuntimeEndpoint",
+    "VertexAI",
+    "function",
+    "Runtime",
+    "set_default_backend",
+    "flush_cache",
+    "get_server_args",
+    "gen",
+    "gen_int",
+    "gen_string",
+    "image",
+    "video",
+    "select",
+    "system",
+    "user",
+    "assistant",
+    "user_begin",
+    "user_end",
+    "assistant_begin",
+    "assistant_end",
+]
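With 0.1.16 the package root re-exports the frontend primitives and the backend classes directly, so user code no longer needs to import from submodules. Below is a minimal usage sketch of the flattened namespace; the endpoint URL and prompt are placeholders, not part of the diff.

```python
import sglang as sgl

@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

# Any of the re-exported backends (OpenAI, Anthropic, VertexAI, RuntimeEndpoint)
# can be installed as the default backend.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

state = qa.run(question="What does SGLang do?")
print(state["answer"])
```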
sglang/api.py
CHANGED
@@ -1,13 +1,10 @@
-"""Public API"""
+"""Some Public API Definitions"""
 
+import os
 import re
 from typing import Callable, List, Optional, Union
 
-from sglang.backend.anthropic import Anthropic
 from sglang.backend.base_backend import BaseBackend
-from sglang.backend.openai import OpenAI
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.backend.vertexai import VertexAI
 from sglang.global_config import global_config
 from sglang.lang.ir import (
     SglExpr,
@@ -18,6 +15,7 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglVideo,
 )
 
 
@@ -35,6 +33,7 @@ def function(
 
 def Runtime(*args, **kwargs):
     # Avoid importing unnecessary dependency
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     from sglang.srt.server import Runtime
 
     return Runtime(*args, **kwargs)
@@ -153,6 +152,10 @@ def image(expr: SglExpr):
     return SglImage(expr)
 
 
+def video(path: str, num_frames: int):
+    return SglVideo(path, num_frames)
+
+
 def select(
     name: Optional[str] = None,
     choices: List[str] = None,
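The new `video(path, num_frames)` helper mirrors `image()`: it wraps a file path and a frame count into an `SglVideo` expression that the interpreter later encodes and appends to the prompt. A hedged sketch of how it would be composed into a program follows; the file name and prompt text are placeholders.

```python
import sglang as sgl

@sgl.function
def describe_clip(s, clip_path):
    # sgl.video() is the new primitive added in this release; the rest follows
    # the same pattern as sgl.image().
    s += sgl.user(sgl.video(clip_path, num_frames=16) + "Describe what happens in this clip.")
    s += sgl.assistant(sgl.gen("description", max_tokens=128))
```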
sglang/backend/anthropic.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
@@ -13,7 +14,7 @@ except ImportError as e:
 
 
 class Anthropic(BaseBackend):
-    def __init__(self, model_name):
+    def __init__(self, model_name, *args, **kwargs):
         super().__init__()
 
         if isinstance(anthropic, Exception):
@@ -21,6 +22,7 @@ class Anthropic(BaseBackend):
 
         self.model_name = model_name
         self.chat_template = get_chat_template("claude")
+        self.client = anthropic.Anthropic(*args, **kwargs)
 
     def get_chat_template(self):
         return self.chat_template
@@ -35,8 +37,14 @@ class Anthropic(BaseBackend):
         else:
             messages = [{"role": "user", "content": s.text_}]
 
-
+        if messages and messages[0]["role"] == "system":
+            system = messages.pop(0)["content"]
+        else:
+            system = ""
+
+        ret = self.client.messages.create(
             model=self.model_name,
+            system=system,
             messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         )
@@ -54,10 +62,16 @@ class Anthropic(BaseBackend):
         else:
             messages = [{"role": "user", "content": s.text_}]
 
-
+        if messages and messages[0]["role"] == "system":
+            system = messages.pop(0)["content"]
+        else:
+            system = ""
+
+        with self.client.messages.stream(
             model=self.model_name,
+            system=system,
             messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         ) as stream:
             for text in stream.text_stream:
-                yield text, {}
+                yield text, {}
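The Anthropic backend now builds a reusable `anthropic.Anthropic` client and, because the Messages API takes the system prompt as a separate `system` argument rather than as a chat message, it pops a leading system message off the list before calling `messages.create` / `messages.stream`. A standalone sketch of that extraction step is shown below; the helper name is illustrative, not from the diff.

```python
def split_system_message(messages):
    """Pop a leading system message so it can be passed via the `system` kwarg."""
    if messages and messages[0]["role"] == "system":
        system = messages.pop(0)["content"]
    else:
        system = ""
    return system, messages

system, messages = split_system_message([
    {"role": "system", "content": "Answer in one word."},
    {"role": "user", "content": "Capital of France?"},
])
assert system == "Answer in one word."
assert messages == [{"role": "user", "content": "Capital of France?"}]
```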
sglang/backend/openai.py
CHANGED
@@ -3,6 +3,7 @@ import time
 from typing import Callable, List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
@@ -227,7 +228,7 @@ class OpenAI(BaseBackend):
             prompt_tokens.append(ret_token)
 
         decision = choices[np.argmax(scores)]
-        return decision, scores,
+        return decision, scores, None, None
 
 
 def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
sglang/backend/runtime_endpoint.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Union
 
 import numpy as np
 import requests
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
@@ -73,9 +74,11 @@ class RuntimeEndpoint(BaseBackend):
         assert res.status_code == 200
 
     def commit_lazy_operations(self, s: StreamExecutor):
+        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+        self._add_images(s, data)
         res = http_request(
             self.base_url + "/generate",
-            json=
+            json=data,
             auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
@@ -104,6 +107,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 **sampling_params.to_srt_kwargs(),
             },
         }
@@ -112,6 +116,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 "dtype": "int",
                 **sampling_params.to_srt_kwargs(),
             },
@@ -142,6 +147,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 **sampling_params.to_srt_kwargs(),
             },
         }
@@ -150,6 +156,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 "dtype": "int",
                 **sampling_params.to_srt_kwargs(),
             },
@@ -224,13 +231,19 @@
         )
         assert res.status_code == 200
         obj = res.json()
-
+        normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-
+        decision = choices[np.argmax(normalized_prompt_logprobs)]
+        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
+        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
 
-
-
+        return (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        )
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
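`select()` on the runtime endpoint now returns a 4-tuple: the chosen string, the normalized prompt log probabilities used for the argmax, and the per-token prefill/decode log probabilities reported by the server. The interpreter (see the interpreter.py diff below) stores the last three under the variable's meta info. A hedged sketch of reading them back follows, assuming the usual `ProgramState.get_meta_info` accessor; the variable name and choices are placeholders.

```python
import sglang as sgl

@sgl.function
def route(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.select("tool", choices=["search", "calculator"]))

state = route.run(question="What is 13 * 7?")
meta = state.get_meta_info("tool")
print(state["tool"])
print(meta["normalized_prompt_logprobs"])  # one logprob per choice
```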
sglang/backend/vertexai.py
CHANGED
sglang/global_config.py
CHANGED
@@ -12,10 +12,11 @@ class GlobalConfig:
 
         # Output configs
         self.skip_special_tokens_in_output = True
+        self.spaces_between_special_tokens_in_out = True
 
         # Optimization configs
         self.eager_fill_image = False
-        self.
+        self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
 
@@ -24,5 +25,8 @@ class GlobalConfig:
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"
 
+        # Request dependency time due to network delay
+        self.request_dependency_time = 0.03
+
 
 global_config = GlobalConfig()
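The new knobs are plain attributes on the `global_config` singleton and can be adjusted before running programs. A short sketch of toggling them:

```python
from sglang import global_config

# New output flag: keep spaces between special tokens during detokenization.
global_config.spaces_between_special_tokens_in_out = True

# New optimization flag: pre-cache the shared prefix of a batch by tracing the program.
global_config.enable_precache_with_tracing = False  # disable if tracing is undesirable

# New scheduling field: extra wait time that accounts for network delay
# between dependent requests.
global_config.request_dependency_time = 0.03
```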
sglang/lang/chat_template.py
CHANGED
@@ -162,6 +162,28 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="llama-3-instruct",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|start_header_id|>system<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "user": (
+                "<|start_header_id|>user<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "assistant": (
+                "<|start_header_id|>assistant<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+        },
+        stop_str=("<|eot_id|>",),
+    )
+)
+
 # Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
 register_chat_template(
     ChatTemplate(
@@ -192,6 +214,44 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="dbrx-instruct",
+        default_system_prompt="You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "<|im_end|>"),
+            "user": ("\n<|im_start|>user\n", "<|im_end|>"),
+            "assistant": ("\n<|im_start|>assistant\n", "<|im_end|>"),
+        },
+        stop_str=("<|im_end|>",),
+    )
+)
+
+register_chat_template(
+    ChatTemplate(
+        name="c4ai-command-r",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
+                "<|END_OF_TURN_TOKEN|>",
+            ),
+            "user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
+            "assistant": (
+                "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+                "<|END_OF_TURN_TOKEN|>",
+            ),
+        },
+        style=ChatTemplateStyle.PLAIN,
+    )
+)
+
+
+@register_chat_template_matching_function
+def match_dbrx(model_path: str):
+    if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
+        return get_chat_template("dbrx-instruct")
+
 
 @register_chat_template_matching_function
 def match_vicuna(model_path: str):
@@ -199,6 +259,8 @@ def match_vicuna(model_path: str):
         return get_chat_template("vicuna_v1.1")
     if "llava-v1.5" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
+    if "llava-next-video-7b" in model_path.lower():
+        return get_chat_template("vicuna_v1.1")
 
 
 @register_chat_template_matching_function
@@ -214,21 +276,33 @@ def match_llama2_chat(model_path: str):
         return get_chat_template("llama-2-chat")
 
 
+@register_chat_template_matching_function
+def match_llama3_instruct(model_path: str):
+    model_path = model_path.lower()
+    if "llama-3" in model_path and "instruct" in model_path:
+        return get_chat_template("llama-3-instruct")
+
+
 @register_chat_template_matching_function
 def match_chat_ml(model_path: str):
+    # import pdb;pdb.set_trace()
     model_path = model_path.lower()
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
-    if
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+    ):
         return get_chat_template("chatml-llava")
 
 
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
-    if "yi" in model_path:
+    if "yi" in model_path and "llava" not in model_path:
         return get_chat_template("yi")
 
 
@@ -239,6 +313,13 @@ def match_gemma_it(model_path: str):
         return get_chat_template("gemma-it")
 
 
+@register_chat_template_matching_function
+def match_c4ai_command_r(model_path: str):
+    model_path = model_path.lower()
+    if "c4ai-command-r" in model_path:
+        return get_chat_template("c4ai-command-r")
+
+
 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
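Each new template is registered by name and paired with a matching function, so models whose path contains the right substrings (for example "llama-3" plus "instruct") pick it up automatically. A small sketch of retrieving a registered template by name and inspecting its role markers, using only the fields shown in the diff above:

```python
from sglang.lang.chat_template import get_chat_template

tmpl = get_chat_template("llama-3-instruct")
prefix, suffix = tmpl.role_prefix_and_suffix["user"]
print(repr(prefix))   # '<|start_header_id|>user<|end_header_id|>\n\n'
print(repr(suffix))   # '<|eot_id|>'
print(tmpl.stop_str)  # ('<|eot_id|>',)
```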
sglang/lang/interpreter.py
CHANGED
@@ -1,6 +1,7 @@
 """The interpreter that executes SGL programs"""
 
 import asyncio
+import contextvars
 import multiprocessing
 import queue
 import threading
@@ -10,6 +11,7 @@ from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import tqdm
+
 from sglang.global_config import global_config
 from sglang.lang.ir import (
     SglCommitLazy,
@@ -26,8 +28,9 @@ from sglang.lang.ir import (
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
+    SglVideo,
 )
-from sglang.utils import encode_image_base64
+from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
 
 
 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -84,9 +87,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
 
-    #
-    if len(batch_arguments) > 1:
-
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)
 
     # Run all programs
     if num_threads == "auto":
@@ -152,21 +155,12 @@ def run_program_batch(
     return rets
 
 
-def
-
-    # TODO: handle multiple backends
-    from sglang.lang.tracer import extract_prefix_by_tracing
-
-    prefix = extract_prefix_by_tracing(program, backend)
-    if prefix and len(prefix) > 64:
-        prefix_rid = backend.cache_prefix(prefix)
-        program.pin_prefix_rid = prefix_rid
-        return prefix_rid
-    return None
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing
 
-
-
-
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)
 
 
 class StreamExecutor:
@@ -193,6 +187,7 @@ class StreamExecutor:
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
+        self.error = None
 
         # For completion
         self.text_ = ""  # The full text
@@ -217,7 +212,13 @@
         self.use_thread = use_thread
         if self.use_thread:
             self.queue = queue.Queue()
-
+
+            def _run_worker_in_context():
+                self._thread_worker_func()
+
+            self.worker = threading.Thread(
+                target=contextvars.copy_context().run, args=(_run_worker_in_context,)
+            )
             self.worker.start()
 
         # For streaming
@@ -248,17 +249,24 @@
     def set_var(self, name, value):
         self.variables[name] = value
 
-    def get_meta_info(self, name):
+    def get_meta_info(self, name, timeout=None):
         if name in self.variable_event:
-            self.variable_event[name].wait()
+            got = self.variable_event[name].wait(timeout)
+            if not got:
+                raise TimeoutError(f"Timeout while waiting for event '{name}'")
         ret = self.meta_info.get(name, None)
         return ret
 
-    def fork(
-        self
-
+    def fork(
+        self,
+        size: int = 1,
+        position_ids_offset: Optional[List[int]] = None,
+    ):
+        if size > 1:
+            self.submit(SglCommitLazy())
 
-
+        self.sync()
+        size = int(size)
 
         exes = [
             StreamExecutor(
@@ -268,14 +276,15 @@
                 self.chat_template,
                 self.stream,
             )
-            for _ in range(
+            for _ in range(size)
         ]
-        for i in range(
+        for i in range(size):
             exes[i].variables = dict(self.variables)
             exes[i].text_ = str(self.text_)
             exes[i].messages_ = list(self.messages_)
             exes[i].cur_role = self.cur_role
             exes[i].fork_start_text_pos = len(self.text_)
+            exes[i].images_ = list(self.images_)
 
         return exes
 
@@ -294,17 +303,39 @@
             self.backend.end_program(self)
 
     def _thread_worker_func(self):
+        error = None
+
         while True:
             expr = self.queue.get()
             if expr is None:
                 self.queue.task_done()
                 break
 
-
+            try:
+                self._execute(expr)
+            except Exception as e:
+                # print(f"Error in stream_executor: {get_exception_traceback()}")
+                error = e
+                break
             self.queue.task_done()
             if self.stream_text_event:
                 self.stream_text_event.set()
 
+        # Clean the queue and events
+        if error is not None:
+            try:
+                while True:
+                    self.queue.task_done()
+                    self.queue.get_nowait()
+            except queue.Empty:
+                pass
+            for name in self.variable_event:
+                self.variable_event[name].set()
+            if self.stream_var_event:
+                for name in self.stream_var_event:
+                    self.stream_var_event[name].set()
+            self.error = error
+
         if self.stream_text_event:
             self.stream_text_event.set()
 
@@ -331,6 +362,8 @@
             self._execute_role_end(other)
         elif isinstance(other, SglImage):
             self._execute_image(other)
+        elif isinstance(other, SglVideo):
+            self._execute_video(other)
         elif isinstance(other, SglVariable):
             self._execute_variable(other)
         elif isinstance(other, SglVarScopeBegin):
@@ -367,6 +400,16 @@
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token
 
+    def _execute_video(self, expr: SglVideo):
+        path = expr.path
+        num_frames = expr.num_frames
+
+        base64_data = encode_video_base64(path, num_frames)
+
+        self.images_.append((path, base64_data))
+        self.cur_images.append((path, base64_data))
+        self.text_ += self.chat_template.image_token
+
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)
 
@@ -454,15 +497,19 @@
             self.stream_var_event[name].set()
 
     def _execute_select(self, expr: SglSelect):
-
-
-
+        (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        ) = self.backend.select(self, expr.choices, expr.temperature)
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
             self.meta_info[name] = {
-                "
-                "
+                "normalized_prompt_logprobs": normalized_prompt_logprobs,
+                "prefill_token_logprobs": prefill_token_logprobs,
+                "decode_token_logprobs": decode_token_logprobs,
             }
             self.variable_event[name].set()
         self.text_ += decision
@@ -634,8 +681,12 @@ class ProgramState:
             yield
         self.stream_executor.submit(SglVarScopeEnd(name))
 
-    def fork(
-
+    def fork(
+        self,
+        size: int = 1,
+        position_ids_offset: Optional[List[int]] = None,
+    ):
+        stream_executors = self.stream_executor.fork(size, position_ids_offset)
         states = [ProgramState(x) for x in stream_executors]
         state_group = ProgramStateGroup(states, self)
         return state_group
@@ -657,6 +708,9 @@ class ProgramState:
     def sync(self):
         return self.stream_executor.sync()
 
+    def error(self):
+        return self.stream_executor.error
+
     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
             prev = 0
@@ -745,6 +799,9 @@ class ProgramState:
     def __setitem__(self, name, value):
         self.set_var(name, value)
 
+    def __contains__(self, name):
+        return name in self.stream_executor.variables
+
     def __del__(self):
         self.stream_executor.end()
 
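Taken together, `fork()` now accepts an explicit size, commits pending lazy operations and syncs before copying state (including images), and any exception raised in the worker thread is captured and exposed through the new `ProgramState.error()` accessor instead of being silently lost. A hedged usage sketch follows; the prompts are placeholders and the backend is assumed to be configured elsewhere.

```python
import sglang as sgl

@sgl.function
def brainstorm(s, topic):
    s += "Suggest ideas about " + topic + ".\n"
    forks = s.fork(2)  # commits lazy ops and syncs before forking
    forks[0] += "Idea A: " + sgl.gen("a", max_tokens=24)
    forks[1] += "Idea B: " + sgl.gen("b", max_tokens=24)
    forks.join()

state = brainstorm.run(topic="unit testing")
state.sync()
if state.error() is not None:
    # Errors from the background worker thread are now surfaced here.
    raise state.error()
print(state.text())
```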