sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.16"
+__version__ = "0.1.18"
 
 # SGL API Components
 from sglang.api import (
@@ -24,6 +24,7 @@ from sglang.api import (
 
 # SGL Backends
 from sglang.backend.anthropic import Anthropic
+from sglang.backend.litellm import LiteLLM
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.backend.vertexai import VertexAI
@@ -35,6 +36,7 @@ from sglang.global_config import global_config
 __all__ = [
     "global_config",
     "Anthropic",
+    "LiteLLM",
     "OpenAI",
     "RuntimeEndpoint",
     "VertexAI",
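For reference (not part of the diff): the two visible effects of this change for downstream code are the bumped version string and the new top-level export, which can be checked as follows.

import sglang

print(sglang.__version__)          # "0.1.18" after upgrading
assert hasattr(sglang, "LiteLLM")  # backend newly exported in this release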
sglang/api.py
CHANGED
@@ -1,4 +1,4 @@
-"""
+"""Public APIs of the language."""
 
 import os
 import re
@@ -20,13 +20,13 @@ from sglang.lang.ir import (
 
 
 def function(
-    func: Optional[Callable] = None,
+    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
 ):
     if func:
-        return SglFunction(func,
+        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
 
     def decorator(func):
-        return SglFunction(func,
+        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
 
     return decorator
 
@@ -43,14 +43,14 @@ def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
 
-def flush_cache(backend: BaseBackend = None):
+def flush_cache(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return False
     return backend.flush_cache()
 
 
-def get_server_args(backend: BaseBackend = None):
+def get_server_args(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -158,7 +158,7 @@ def video(path: str, num_frames: int):
 
 def select(
     name: Optional[str] = None,
-    choices: List[str] = None,
+    choices: Optional[List[str]] = None,
     temperature: float = 0.0,
 ):
     assert choices is not None
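A hypothetical sketch (not part of the diff) of how the new num_api_spec_tokens argument is meant to be used; the program body and the budget of 64 tokens are illustrative, following the pattern hinted at elsewhere in this release (@function(num_api_spec_tokens=128)).

import sglang as sgl

@sgl.function(num_api_spec_tokens=64)  # reserve up to 64 speculated tokens per gen call
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=32))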
sglang/backend/anthropic.py
CHANGED
sglang/backend/litellm.py
ADDED
@@ -0,0 +1,90 @@
+from typing import Mapping, Optional
+
+from sglang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+
+try:
+    import litellm
+except ImportError as e:
+    litellm = e
+litellm.num_retries = 1
+
+
+class LiteLLM(BaseBackend):
+    def __init__(
+        self,
+        model_name,
+        chat_template=None,
+        api_key=None,
+        organization: Optional[str] = None,
+        base_url: Optional[str] = None,
+        timeout: Optional[float] = 600,
+        max_retries: Optional[int] = litellm.num_retries,
+        default_headers: Optional[Mapping[str, str]] = None,
+    ):
+        super().__init__()
+
+        if isinstance(litellm, Exception):
+            raise litellm
+
+        self.model_name = model_name
+
+        self.chat_template = chat_template or get_chat_template_by_model_path(
+            model_name
+        )
+
+        self.client_params = {
+            "api_key": api_key,
+            "organization": organization,
+            "base_url": base_url,
+            "timeout": timeout,
+            "max_retries": max_retries,
+            "default_headers": default_headers,
+        }
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            **self.client_params,
+            **sampling_params.to_anthropic_kwargs(),
+        )
+        comp = ret.choices[0].message.content
+
+        return comp, {}
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            stream=True,
+            **self.client_params,
+            **sampling_params.to_litellm_kwargs(),
+        )
+        for chunk in ret:
+            text = chunk.choices[0].delta.content
+            if text is not None:
+                yield text, {}
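A minimal usage sketch (not part of the diff) for the new LiteLLM backend; the model string is only an example, and anything accepted by litellm.completion should work.

import sglang as sgl
from sglang.backend.litellm import LiteLLM

sgl.set_default_backend(LiteLLM("gpt-3.5-turbo"))  # or sgl.LiteLLM(...), now re-exported at the top level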
sglang/backend/openai.py
CHANGED
@@ -1,5 +1,7 @@
+import dataclasses
 import logging
 import time
+import warnings
 from typing import Callable, List, Optional, Union
 
 import numpy as np
@@ -41,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [
 ]
 
 
+@dataclasses.dataclass
+class TokenUsage:
+    prompt_tokens: int
+    completion_tokens: int
+
+    def reset(self):
+        self.prompt_tokens = self.completion_tokens = 0
+
+
 class OpenAI(BaseBackend):
     def __init__(
         self,
@@ -80,40 +91,92 @@ class OpenAI(BaseBackend):
         else:
             self.is_chat_model = True
 
-        self.
+        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
+
+        # Usage
+        self.token_usage = TokenUsage(0, 0)
+
+        # API speculative execution
+        # TODO(ying): This does not support multi-threading (run_batch)
+        self.spec_kwargs = {}
+        self.spec_format = []
+        self.spec_max_num_tries = 3
 
     def get_chat_template(self):
         return self.chat_template
 
+    def _prepare_spec_execution(
+        self,
+        sampling_params: SglSamplingParams,
+        num_api_spec_tokens: int,
+        spec_var_name: str,
+    ):
+        if "max_tokens" not in self.spec_kwargs:
+            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
+        else:
+            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
+
+        params = sampling_params.to_openai_kwargs()
+        for key, value in params.items():
+            if key in ["stop"]:
+                continue
+            if key in ["max_tokens"]:
+                warnings.warn(
+                    "The parameter max_tokens will be overwritten by speculated number of tokens."
+                )
+                continue
+            if key not in self.spec_kwargs:
+                self.spec_kwargs[key] = value
+            else:
+                assert (
+                    value == self.spec_kwargs[key]
+                ), "sampling parameters should be consistent if turn on api speculative execution."
+        self.spec_format.append(
+            {"text": "", "stop": params["stop"], "name": spec_var_name}
+        )
+        return "", {}
+
     def generate(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        spec_var_name: str = None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if
-
-
-
+                if s.num_api_spec_tokens is None:
+                    if not s.text_.endswith(self.chat_prefix):
+                        raise RuntimeError(
+                            "This use case is not supported if api speculative execution is off. "
+                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
+                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
+                        )
+                    prompt = s.messages_
+                else:
+                    return self._prepare_spec_execution(
+                        sampling_params, s.num_api_spec_tokens, spec_var_name
                     )
-                prompt = s.messages_
             else:
                 prompt = s.text_
 
             kwargs = sampling_params.to_openai_kwargs()
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=prompt,
                 **kwargs,
             )
         elif sampling_params.dtype in [str, "str", "string"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            comp = openai_completion(
                client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_ + '"',
@@ -122,10 +185,14 @@ class OpenAI(BaseBackend):
             )
             comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_,
@@ -138,6 +205,63 @@ class OpenAI(BaseBackend):
 
         return comp, {}
 
+    def spec_fill(self, value: str):
+        assert self.is_chat_model
+        self.spec_format.append({"text": value, "stop": None, "name": None})
+
+    def spec_pattern_match(self, comp):
+        for i, term in enumerate(self.spec_format):
+            text = term["text"]
+            if text != "":
+                if comp.startswith(text):
+                    comp = comp[len(text) :]
+                else:
+                    return False
+            else:
+                pos = comp.find(term["stop"])
+                if pos != -1:
+                    term["text"] = comp[:pos]
+                    comp = comp[pos:]
+                else:
+                    if i == len(self.spec_format) - 1:
+                        term["text"] = comp
+                    else:
+                        return False
+        return True
+
+    def role_end_generate(
+        self,
+        s: StreamExecutor,
+    ):
+        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+            return
+
+        comp = ""
+        if not all(x["name"] is None for x in self.spec_format):
+            # TODO(ying): throw errors or warnings
+            for i in range(self.spec_max_num_tries):
+                comp = openai_completion(
+                    client=self.client,
+                    token_usage=self.token_usage,
+                    is_chat=self.is_chat_model,
+                    model=self.model_name,
+                    prompt=s.messages_,
+                    **self.spec_kwargs,
+                )
+                if self.spec_pattern_match(comp):
+                    break
+
+        for term in self.spec_format:
+            s.text_ += term["text"]
+            name = term["name"]
+            if name is not None:
+                s.variables[name] = term["text"]
+                s.meta_info[name] = {}
+                s.variable_event[name].set()
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+
     def generate_stream(
         self,
         s: StreamExecutor,
@@ -145,7 +269,7 @@ class OpenAI(BaseBackend):
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.
+                if not s.text_.endswith(self.chat_prefix):
                     raise RuntimeError(
                         "This use case is not supported. "
                         "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
@@ -157,6 +281,7 @@ class OpenAI(BaseBackend):
             kwargs = sampling_params.to_openai_kwargs()
             generator = openai_completion_stream(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=prompt,
@@ -202,6 +327,8 @@ class OpenAI(BaseBackend):
            )
            ret_str = ret.choices[0].text
            ret_token = self.tokenizer.encode(ret_str)[0]
+            self.token_usage.prompt_tokens += ret.usage.prompt_tokens
+            self.token_usage.completion_tokens = ret.usage.completion_tokens
 
            # TODO:
            # 1. return logits as the scores
@@ -231,7 +358,9 @@ class OpenAI(BaseBackend):
         return decision, scores, None, None
 
 
-def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
@@ -245,6 +374,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
                 comp = [c.text for c in ret.choices]
             else:
                 comp = ret.choices[0].text
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
@@ -258,16 +390,23 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
     return comp
 
 
-def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion_stream(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
                 if "stop" in kwargs and kwargs["stop"] is None:
                     kwargs.pop("stop")
                 generator = client.chat.completions.create(
-                    messages=prompt,
+                    messages=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     try:
                         content = ret.choices[0].delta.content
                     except IndexError:
@@ -275,11 +414,19 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
                     yield content or "", {}
             else:
                 generator = client.completions.create(
-                    prompt=prompt,
+                    prompt=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     content = ret.choices[0].text
                     yield content or "", {}
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
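For context, a hypothetical sketch (not part of the diff) of reading the new per-backend token accounting; the model name is illustrative.

from sglang.backend.openai import OpenAI

backend = OpenAI("gpt-3.5-turbo-instruct")
# ... run sgl programs against this backend ...
print(backend.token_usage.prompt_tokens, backend.token_usage.completion_tokens)
backend.token_usage.reset()  # counters accumulate across calls until cleared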
sglang/backend/runtime_endpoint.py
CHANGED
@@ -34,7 +34,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
@@ -50,7 +50,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def get_server_args(self):
         res = http_request(
@@ -58,6 +58,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
+        self._assert_success(res)
         return res.json()
 
     def get_chat_template(self):
@@ -71,7 +72,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def commit_lazy_operations(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -83,7 +84,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -95,7 +96,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def generate(
         self,
@@ -133,6 +134,8 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
+
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
@@ -167,7 +170,7 @@ class RuntimeEndpoint(BaseBackend):
         data["stream"] = True
         self._add_images(s, data)
 
-
+        res = http_request(
             self.base_url + "/generate",
             json=data,
             stream=True,
@@ -175,10 +178,11 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
         pos = 0
 
         incomplete_text = ""
-        for chunk in
+        for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":
@@ -211,7 +215,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
         prompt_len = res.json()["meta_info"]["prompt_tokens"]
 
         # Compute logprob
@@ -229,7 +233,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
         obj = res.json()
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
@@ -253,9 +257,13 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
             data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
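With _assert_success in place, a non-200 response from the server now raises RuntimeError carrying the JSON error body instead of failing silently later; a hypothetical sketch of the resulting behavior (the endpoint URL is illustrative).

import sglang as sgl

try:
    backend = sgl.RuntimeEndpoint("http://localhost:30000")
except RuntimeError as e:
    print("server returned an error:", e)  # e wraps res.json() from the failed request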