sglang 0.1.22__py3-none-any.whl → 0.1.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/bench_serving.py +243 -25
- sglang/global_config.py +3 -2
- sglang/lang/interpreter.py +1 -0
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/radix_attention.py +38 -49
- sglang/srt/managers/controller/cuda_graph_runner.py +58 -16
- sglang/srt/managers/controller/infer_batch.py +51 -22
- sglang/srt/managers/controller/model_runner.py +58 -4
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +9 -11
- sglang/srt/memory_pool.py +13 -5
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/llama2.py +19 -10
- sglang/srt/server.py +26 -1
- sglang/srt/server_args.py +12 -6
- sglang/srt/utils.py +93 -1
- sglang/version.py +1 -0
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/METADATA +10 -6
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/RECORD +25 -36
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/WHEEL +1 -1
- sglang/backend/__init__.py +0 -0
- sglang/backend/anthropic.py +0 -77
- sglang/backend/base_backend.py +0 -80
- sglang/backend/litellm.py +0 -90
- sglang/backend/openai.py +0 -438
- sglang/backend/runtime_endpoint.py +0 -283
- sglang/backend/vertexai.py +0 -149
- sglang/bench.py +0 -627
- sglang/srt/managers/controller/dp_worker.py +0 -113
- sglang/srt/openai_api/api_adapter.py +0 -432
- sglang/srt/openai_api/openai_api_adapter.py +0 -431
- sglang/srt/openai_api/openai_protocol.py +0 -207
- sglang/srt/openai_api_adapter.py +0 -411
- sglang/srt/openai_protocol.py +0 -207
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/LICENSE +0 -0
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/top_level.txt +0 -0
sglang/backend/base_backend.py
DELETED
@@ -1,80 +0,0 @@
-from typing import Callable, List, Optional, Union
-
-from sglang.lang.chat_template import get_chat_template
-from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglSamplingParams
-
-
-class BaseBackend:
-    def __init__(self) -> None:
-        self.support_concate_and_append = False
-        self.chat_template = get_chat_template("default")
-
-    def get_model_name(self):
-        raise NotImplementedError()
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    def cache_prefix(self, prefix_str: str):
-        pass
-
-    def uncache_prefix(self, rid: str):
-        pass
-
-    def end_request(self, rid: Union[str, List[str]]):
-        pass
-
-    def begin_program(self, s: StreamExecutor):
-        pass
-
-    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
-        pass
-
-    def commit_lazy_operations(self, s: StreamExecutor):
-        pass
-
-    def fork_program(
-        self,
-        src: StreamExecutor,
-        dst: List[StreamExecutor],
-        position_ids_offset: Optional[List[int]] = None,
-    ):
-        pass
-
-    def fill_image(self, s: StreamExecutor):
-        pass
-
-    def generate(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        raise NotImplementedError()
-
-    def generate_stream(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        raise NotImplementedError()
-
-    def select(
-        self,
-        s: StreamExecutor,
-        choices: List[str],
-        temperature: float,
-    ):
-        raise NotImplementedError()
-
-    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
-        raise NotImplementedError()
-
-    def shutdown(self):
-        pass
-
-    def flush_cache(self):
-        pass
-
-    def get_server_args(self):
-        pass
sglang/backend/litellm.py
DELETED
@@ -1,90 +0,0 @@
-from typing import Mapping, Optional
-
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
-from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglSamplingParams
-
-try:
-    import litellm
-except ImportError as e:
-    litellm = e
-litellm.num_retries = 1
-
-
-class LiteLLM(BaseBackend):
-    def __init__(
-        self,
-        model_name,
-        chat_template=None,
-        api_key=None,
-        organization: Optional[str] = None,
-        base_url: Optional[str] = None,
-        timeout: Optional[float] = 600,
-        max_retries: Optional[int] = litellm.num_retries,
-        default_headers: Optional[Mapping[str, str]] = None,
-    ):
-        super().__init__()
-
-        if isinstance(litellm, Exception):
-            raise litellm
-
-        self.model_name = model_name
-
-        self.chat_template = chat_template or get_chat_template_by_model_path(
-            model_name
-        )
-
-        self.client_params = {
-            "api_key": api_key,
-            "organization": organization,
-            "base_url": base_url,
-            "timeout": timeout,
-            "max_retries": max_retries,
-            "default_headers": default_headers,
-        }
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    def generate(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        if s.messages_:
-            messages = s.messages_
-        else:
-            messages = [{"role": "user", "content": s.text_}]
-
-        ret = litellm.completion(
-            model=self.model_name,
-            messages=messages,
-            **self.client_params,
-            **sampling_params.to_anthropic_kwargs(),
-        )
-        comp = ret.choices[0].message.content
-
-        return comp, {}
-
-    def generate_stream(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        if s.messages_:
-            messages = s.messages_
-        else:
-            messages = [{"role": "user", "content": s.text_}]
-
-        ret = litellm.completion(
-            model=self.model_name,
-            messages=messages,
-            stream=True,
-            **self.client_params,
-            **sampling_params.to_litellm_kwargs(),
-        )
-        for chunk in ret:
-            text = chunk.choices[0].delta.content
-            if text is not None:
-                yield text, {}
sglang/backend/openai.py
DELETED
@@ -1,438 +0,0 @@
-import dataclasses
-import logging
-import time
-import warnings
-from typing import Callable, List, Optional, Union
-
-import numpy as np
-
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
-from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglSamplingParams
-
-try:
-    import openai
-    import tiktoken
-except ImportError as e:
-    openai = tiktoken = e
-
-
-logger = logging.getLogger("openai")
-
-
-def create_logit_bias_int(tokenizer):
-    """Get logit bias for integer numbers."""
-    int_token_ids = []
-
-    tokens = tokenizer._mergeable_ranks
-    for token, token_id in tokens.items():
-        s = tokenizer.decode([token_id])
-        if all([c.isdigit() for c in s]) or s in [" "]:
-            int_token_ids.append(token_id)
-            if len(int_token_ids) >= 300:  # OpenAI API limit
-                break
-    special_tokens = tokenizer._special_tokens
-    mask = {t: 100 for t in int_token_ids[:299]}
-    mask[special_tokens["<|endoftext|>"]] = 100
-    return mask
-
-
-INSTRUCT_MODEL_NAMES = [
-    "gpt-3.5-turbo-instruct",
-]
-
-
-@dataclasses.dataclass
-class TokenUsage:
-    prompt_tokens: int
-    completion_tokens: int
-
-    def reset(self):
-        self.prompt_tokens = self.completion_tokens = 0
-
-
-class OpenAI(BaseBackend):
-    def __init__(
-        self,
-        model_name: str,
-        is_chat_model: Optional[bool] = None,
-        chat_template: Optional[ChatTemplate] = None,
-        is_azure: bool = False,
-        *args,
-        **kwargs,
-    ):
-        super().__init__()
-
-        if isinstance(openai, Exception):
-            raise openai
-
-        if is_azure:
-            self.client = openai.AzureOpenAI(*args, **kwargs)
-        else:
-            self.client = openai.OpenAI(*args, **kwargs)
-
-        self.model_name = model_name
-        try:
-            self.tokenizer = tiktoken.encoding_for_model(model_name)
-        except KeyError:
-            self.tokenizer = tiktoken.get_encoding("cl100k_base")
-        self.logit_bias_int = create_logit_bias_int(self.tokenizer)
-
-        self.chat_template = chat_template or get_chat_template_by_model_path(
-            model_name
-        )
-
-        if is_chat_model is not None:
-            self.is_chat_model = is_chat_model
-        else:
-            if model_name in INSTRUCT_MODEL_NAMES:
-                self.is_chat_model = False
-            else:
-                self.is_chat_model = True
-
-        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
-
-        # Usage
-        self.token_usage = TokenUsage(0, 0)
-
-        # API speculative execution
-        # TODO(ying): This does not support multi-threading (run_batch)
-        self.spec_kwargs = {}
-        self.spec_format = []
-        self.spec_max_num_tries = 3
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    def _prepare_spec_execution(
-        self,
-        sampling_params: SglSamplingParams,
-        num_api_spec_tokens: int,
-        spec_var_name: str,
-    ):
-        if "max_tokens" not in self.spec_kwargs:
-            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
-        else:
-            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
-
-        params = sampling_params.to_openai_kwargs()
-        for key, value in params.items():
-            if key in ["stop"]:
-                continue
-            if key in ["max_tokens"]:
-                warnings.warn(
-                    "The parameter max_tokens will be overwritten by speculated number of tokens."
-                )
-                continue
-            if key not in self.spec_kwargs:
-                self.spec_kwargs[key] = value
-            else:
-                assert (
-                    value == self.spec_kwargs[key]
-                ), "sampling parameters should be consistent if turn on api speculative execution."
-        self.spec_format.append(
-            {"text": "", "stop": params["stop"], "name": spec_var_name}
-        )
-        return "", {}
-
-    def generate(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-        spec_var_name: str = None,
-    ):
-        if sampling_params.dtype is None:
-            if self.is_chat_model:
-                if s.num_api_spec_tokens is None:
-                    if not s.text_.endswith(self.chat_prefix):
-                        raise RuntimeError(
-                            "This use case is not supported if api speculative execution is off. "
-                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
-                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
-                        )
-                    prompt = s.messages_
-                else:
-                    return self._prepare_spec_execution(
-                        sampling_params, s.num_api_spec_tokens, spec_var_name
-                    )
-            else:
-                prompt = s.text_
-
-            kwargs = sampling_params.to_openai_kwargs()
-            comp = openai_completion(
-                client=self.client,
-                token_usage=self.token_usage,
-                is_chat=self.is_chat_model,
-                model=self.model_name,
-                prompt=prompt,
-                **kwargs,
-            )
-        elif sampling_params.dtype in [str, "str", "string"]:
-            assert (
-                not self.is_chat_model
-            ), "constrained type not supported on chat model"
-            kwargs = sampling_params.to_openai_kwargs()
-            kwargs.pop("stop")
-            comp = openai_completion(
-                client=self.client,
-                token_usage=self.token_usage,
-                is_chat=self.is_chat_model,
-                model=self.model_name,
-                prompt=s.text_ + '"',
-                stop='"',
-                **kwargs,
-            )
-            comp = '"' + comp + '"'
-        elif sampling_params.dtype in [int, "int"]:
-            assert (
-                not self.is_chat_model
-            ), "constrained type not supported on chat model"
-            kwargs = sampling_params.to_openai_kwargs()
-            kwargs.pop("stop")
-            comp = openai_completion(
-                client=self.client,
-                token_usage=self.token_usage,
-                is_chat=self.is_chat_model,
-                model=self.model_name,
-                prompt=s.text_,
-                logit_bias=self.logit_bias_int,
-                stop=[" "],
-                **kwargs,
-            )
-        else:
-            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
-
-        return comp, {}
-
-    def spec_fill(self, value: str):
-        assert self.is_chat_model
-        self.spec_format.append({"text": value, "stop": None, "name": None})
-
-    def spec_pattern_match(self, comp):
-        for i, term in enumerate(self.spec_format):
-            text = term["text"]
-            if text != "":
-                if comp.startswith(text):
-                    comp = comp[len(text) :]
-                else:
-                    return False
-            else:
-                pos = comp.find(term["stop"])
-                if pos != -1:
-                    term["text"] = comp[:pos]
-                    comp = comp[pos:]
-                else:
-                    if i == len(self.spec_format) - 1:
-                        term["text"] = comp
-                    else:
-                        return False
-        return True
-
-    def role_end_generate(
-        self,
-        s: StreamExecutor,
-    ):
-        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
-            return
-
-        comp = ""
-        if not all(x["name"] is None for x in self.spec_format):
-            # TODO(ying): throw errors or warnings
-            for i in range(self.spec_max_num_tries):
-                comp = openai_completion(
-                    client=self.client,
-                    token_usage=self.token_usage,
-                    is_chat=self.is_chat_model,
-                    model=self.model_name,
-                    prompt=s.messages_,
-                    **self.spec_kwargs,
-                )
-                if self.spec_pattern_match(comp):
-                    break
-
-        for term in self.spec_format:
-            s.text_ += term["text"]
-            name = term["name"]
-            if name is not None:
-                s.variables[name] = term["text"]
-                s.meta_info[name] = {}
-                s.variable_event[name].set()
-
-        self.spec_kwargs = {}
-        self.spec_format = []
-
-    def generate_stream(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        if sampling_params.dtype is None:
-            if self.is_chat_model:
-                if not s.text_.endswith(self.chat_prefix):
-                    raise RuntimeError(
-                        "This use case is not supported. "
-                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
-                    )
-                prompt = s.messages_
-            else:
-                prompt = s.text_
-
-            kwargs = sampling_params.to_openai_kwargs()
-            generator = openai_completion_stream(
-                client=self.client,
-                token_usage=self.token_usage,
-                is_chat=self.is_chat_model,
-                model=self.model_name,
-                prompt=prompt,
-                **kwargs,
-            )
-            return generator
-        else:
-            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
-
-    def select(
-        self,
-        s: StreamExecutor,
-        choices: List[str],
-        temperature: float,
-    ):
-        if self.is_chat_model:
-            raise NotImplementedError(
-                "select/choices is not supported for chat models. "
-                "Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
-            )
-
-        n_choices = len(choices)
-        token_ids = [self.tokenizer.encode(x) for x in choices]
-        scores = [0] * n_choices
-        valid = [len(x) > 0 for x in token_ids]
-        prompt_tokens = self.tokenizer.encode(s.text_)
-
-        max_len = max([len(x) for x in token_ids])
-        for step in range(max_len):
-            # Build logit bias
-            logit_bias = {}
-            for i in range(n_choices):
-                if valid[i]:
-                    logit_bias[token_ids[i][step]] = 100
-
-            # Call API
-            ret = self.client.completions.create(
-                model=self.model_name,
-                prompt=prompt_tokens,
-                logit_bias=logit_bias,
-                max_tokens=1,
-                temperature=temperature,
-            )
-            ret_str = ret.choices[0].text
-            ret_token = self.tokenizer.encode(ret_str)[0]
-            self.token_usage.prompt_tokens += ret.usage.prompt_tokens
-            self.token_usage.completion_tokens = ret.usage.completion_tokens
-
-            # TODO:
-            # 1. return logits as the scores
-            # 2. compute logits of the full choice
-            # 3. consider chunk-based decoding
-
-            # Update valid
-            hit = False
-            for i in range(n_choices):
-                if valid[i]:
-                    if step == len(token_ids[i]) - 1:
-                        valid[i] = False
-
-                    if ret_token == token_ids[i][step]:
-                        scores[i] += 1
-                        hit = True
-                    else:
-                        valid[i] = False
-            assert hit
-
-            if np.sum(valid) <= 1:
-                break
-
-            prompt_tokens.append(ret_token)
-
-        decision = choices[np.argmax(scores)]
-        return decision, scores, None, None
-
-
-def openai_completion(
-    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
-):
-    for attempt in range(retries):
-        try:
-            if is_chat:
-                if "stop" in kwargs and kwargs["stop"] is None:
-                    kwargs.pop("stop")
-                ret = client.chat.completions.create(messages=prompt, **kwargs)
-                comp = ret.choices[0].message.content
-            else:
-                ret = client.completions.create(prompt=prompt, **kwargs)
-                if isinstance(prompt, (list, tuple)):
-                    comp = [c.text for c in ret.choices]
-                else:
-                    comp = ret.choices[0].text
-
-            token_usage.prompt_tokens += ret.usage.prompt_tokens
-            token_usage.completion_tokens += ret.usage.completion_tokens
-            break
-        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
-            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
-            time.sleep(5)
-            if attempt == retries - 1:
-                raise e
-        except Exception as e:
-            logger.error(f"RuntimeError {e}.")
-            raise e
-
-    return comp
-
-
-def openai_completion_stream(
-    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
-):
-    for attempt in range(retries):
-        try:
-            if is_chat:
-                if "stop" in kwargs and kwargs["stop"] is None:
-                    kwargs.pop("stop")
-                generator = client.chat.completions.create(
-                    messages=prompt,
-                    stream=True,
-                    stream_options={"include_usage": True},
-                    **kwargs,
-                )
-                for ret in generator:
-                    if len(ret.choices) == 0:
-                        continue
-                    try:
-                        content = ret.choices[0].delta.content
-                    except IndexError:
-                        content = None
-                    yield content or "", {}
-            else:
-                generator = client.completions.create(
-                    prompt=prompt,
-                    stream=True,
-                    stream_options={"include_usage": True},
-                    **kwargs,
-                )
-                for ret in generator:
-                    if len(ret.choices) == 0:
-                        continue
-                    content = ret.choices[0].text
-                    yield content or "", {}
-
-            token_usage.prompt_tokens += ret.usage.prompt_tokens
-            token_usage.completion_tokens += ret.usage.completion_tokens
-            break
-        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
-            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
-            time.sleep(5)
-            if attempt == retries - 1:
-                raise e
-        except Exception as e:
-            logger.error(f"RuntimeError {e}.")
-            raise e