sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -1,4 +1,61 @@
-__version__ = "0.1.14"
+__version__ = "0.1.21"
 
-
+# SGL API Components
+from sglang.api import (
+    Runtime,
+    assistant,
+    assistant_begin,
+    assistant_end,
+    flush_cache,
+    function,
+    gen,
+    gen_int,
+    gen_string,
+    get_server_args,
+    image,
+    select,
+    set_default_backend,
+    system,
+    user,
+    user_begin,
+    user_end,
+    video,
+)
+
+# SGL Backends
+from sglang.backend.anthropic import Anthropic
+from sglang.backend.litellm import LiteLLM
+from sglang.backend.openai import OpenAI
+from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.backend.vertexai import VertexAI
+
+# Global Configurations
 from sglang.global_config import global_config
+
+# public APIs management
+__all__ = [
+    "global_config",
+    "Anthropic",
+    "LiteLLM",
+    "OpenAI",
+    "RuntimeEndpoint",
+    "VertexAI",
+    "function",
+    "Runtime",
+    "set_default_backend",
+    "flush_cache",
+    "get_server_args",
+    "gen",
+    "gen_int",
+    "gen_string",
+    "image",
+    "video",
+    "select",
+    "system",
+    "user",
+    "assistant",
+    "user_begin",
+    "user_end",
+    "assistant_begin",
+    "assistant_end",
+]
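As context for the expanded export list, here is a minimal usage sketch of how the re-exported names are typically combined; the endpoint URL, question, and token budget below are illustrative placeholders, not taken from this diff:

```python
import sglang as sgl


@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))


# Point the default backend at a running SGLang server (python -m sglang.launch_server).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

state = qa.run(question="What is the capital of France?")
print(state["answer"])
print(sgl.get_server_args())  # newly exported helper; queries the default backend
```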
sglang/api.py
CHANGED
@@ -1,13 +1,10 @@
-"""Public
+"""Public APIs of the language."""
 
+import os
 import re
 from typing import Callable, List, Optional, Union
 
-from sglang.backend.anthropic import Anthropic
 from sglang.backend.base_backend import BaseBackend
-from sglang.backend.openai import OpenAI
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.backend.vertexai import VertexAI
 from sglang.global_config import global_config
 from sglang.lang.ir import (
     SglExpr,
@@ -18,23 +15,25 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglVideo,
 )
 
 
 def function(
-    func: Optional[Callable] = None,
+    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
 ):
     if func:
-        return SglFunction(func,
+        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
 
     def decorator(func):
-        return SglFunction(func,
+        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
 
     return decorator
 
 
 def Runtime(*args, **kwargs):
     # Avoid importing unnecessary dependency
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     from sglang.srt.server import Runtime
 
     return Runtime(*args, **kwargs)
@@ -44,14 +43,14 @@ def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
 
-def flush_cache(backend: BaseBackend = None):
+def flush_cache(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return False
     return backend.flush_cache()
 
 
-def get_server_args(backend: BaseBackend = None):
+def get_server_args(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -68,10 +67,16 @@ def gen(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
     regex: Optional[str] = None,
 ):
+    """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
+
     if choices:
         return SglSelect(name, choices, 0.0 if temperature is None else temperature)
 
@@ -92,6 +97,10 @@ def gen(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         dtype,
         regex,
     )
@@ -107,6 +116,10 @@ def gen_int(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -118,6 +131,10 @@ def gen_int(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         int,
         None,
     )
@@ -133,6 +150,10 @@ def gen_string(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -144,6 +165,10 @@ def gen_string(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         str,
         None,
     )
@@ -153,9 +178,13 @@ def image(expr: SglExpr):
     return SglImage(expr)
 
 
+def video(path: str, num_frames: int):
+    return SglVideo(path, num_frames)
+
+
 def select(
     name: Optional[str] = None,
-    choices: List[str] = None,
+    choices: Optional[List[str]] = None,
     temperature: float = 0.0,
 ):
     assert choices is not None
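To illustrate the new api.py surface, here is a sketch covering the `num_api_spec_tokens` argument on `@function`, the logprob-related arguments on `gen()`, and the new `video()` element; the backend setup, file path, and numeric values are placeholders:

```python
import sglang as sgl


@sgl.function(num_api_spec_tokens=128)  # enables API speculative execution (OpenAI chat models)
def profile(s):
    s += sgl.user("Invent a fictional city.")
    s += sgl.assistant(
        "Name: " + sgl.gen("name", stop="\n") + "\nClimate: " + sgl.gen("climate", stop="\n")
    )


@sgl.function
def answer_with_logprobs(s, question):
    s += question
    s += sgl.gen(
        "answer",
        max_tokens=16,
        return_logprob=True,
        logprob_start_len=0,
        top_logprobs_num=5,
        return_text_in_logprobs=True,
    )


@sgl.function
def describe_clip(s, path):
    s += sgl.user(sgl.video(path, num_frames=16) + "Describe this video clip.")
    s += sgl.assistant(sgl.gen("description", max_tokens=64))
```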
sglang/backend/anthropic.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
@@ -13,7 +14,7 @@ except ImportError as e:
 
 
 class Anthropic(BaseBackend):
-    def __init__(self, model_name):
+    def __init__(self, model_name, *args, **kwargs):
         super().__init__()
 
         if isinstance(anthropic, Exception):
@@ -21,6 +22,7 @@ class Anthropic(BaseBackend):
 
         self.model_name = model_name
         self.chat_template = get_chat_template("claude")
+        self.client = anthropic.Anthropic(*args, **kwargs)
 
     def get_chat_template(self):
         return self.chat_template
@@ -35,8 +37,14 @@ class Anthropic(BaseBackend):
         else:
             messages = [{"role": "user", "content": s.text_}]
 
-
+        if messages and messages[0]["role"] == "system":
+            system = messages.pop(0)["content"]
+        else:
+            system = ""
+
+        ret = self.client.messages.create(
             model=self.model_name,
+            system=system,
             messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         )
@@ -54,8 +62,14 @@ class Anthropic(BaseBackend):
         else:
             messages = [{"role": "user", "content": s.text_}]
 
-
+        if messages and messages[0]["role"] == "system":
+            system = messages.pop(0)["content"]
+        else:
+            system = ""
+
+        with self.client.messages.stream(
             model=self.model_name,
+            system=system,
             messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         ) as stream:
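The practical effect of the change above: a leading `sgl.system(...)` message is now popped from the chat history and sent through the Messages API's `system` field, and constructor arguments are forwarded to `anthropic.Anthropic`. A usage sketch, with an illustrative model name:

```python
import sglang as sgl

sgl.set_default_backend(sgl.Anthropic("claude-3-haiku-20240307"))


@sgl.function
def concise_answer(s, question):
    s += sgl.system("Answer in one short sentence.")  # now sent via the `system` field
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=128))


state = concise_answer.run(question="Why is the sky blue?")
print(state["answer"])
```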
sglang/backend/litellm.py
ADDED
@@ -0,0 +1,90 @@
+from typing import Mapping, Optional
+
+from sglang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+
+try:
+    import litellm
+except ImportError as e:
+    litellm = e
+litellm.num_retries = 1
+
+
+class LiteLLM(BaseBackend):
+    def __init__(
+        self,
+        model_name,
+        chat_template=None,
+        api_key=None,
+        organization: Optional[str] = None,
+        base_url: Optional[str] = None,
+        timeout: Optional[float] = 600,
+        max_retries: Optional[int] = litellm.num_retries,
+        default_headers: Optional[Mapping[str, str]] = None,
+    ):
+        super().__init__()
+
+        if isinstance(litellm, Exception):
+            raise litellm
+
+        self.model_name = model_name
+
+        self.chat_template = chat_template or get_chat_template_by_model_path(
+            model_name
+        )
+
+        self.client_params = {
+            "api_key": api_key,
+            "organization": organization,
+            "base_url": base_url,
+            "timeout": timeout,
+            "max_retries": max_retries,
+            "default_headers": default_headers,
+        }
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            **self.client_params,
+            **sampling_params.to_anthropic_kwargs(),
+        )
+        comp = ret.choices[0].message.content
+
+        return comp, {}
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            stream=True,
+            **self.client_params,
+            **sampling_params.to_litellm_kwargs(),
+        )
+        for chunk in ret:
+            text = chunk.choices[0].delta.content
+            if text is not None:
+                yield text, {}
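A sketch of wiring up the new LiteLLM backend using the constructor parameters shown above; the model string and client options are placeholders, and provider credentials come from the usual environment variables:

```python
import sglang as sgl

# Any provider/model string accepted by litellm can be used here.
backend = sgl.LiteLLM(
    "gpt-3.5-turbo",
    base_url=None,      # or e.g. a LiteLLM proxy URL
    timeout=120,
    max_retries=2,
)
sgl.set_default_backend(backend)


@sgl.function
def haiku(s, topic):
    s += sgl.user(f"Write a haiku about {topic}.")
    s += sgl.assistant(sgl.gen("poem", max_tokens=64))
```

Existing sgl programs can then target this backend through `set_default_backend` without other changes.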
sglang/backend/openai.py
CHANGED
@@ -1,8 +1,11 @@
+import dataclasses
 import logging
 import time
+import warnings
 from typing import Callable, List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
@@ -40,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [
 ]
 
 
+@dataclasses.dataclass
+class TokenUsage:
+    prompt_tokens: int
+    completion_tokens: int
+
+    def reset(self):
+        self.prompt_tokens = self.completion_tokens = 0
+
+
 class OpenAI(BaseBackend):
     def __init__(
         self,
@@ -79,40 +91,92 @@ class OpenAI(BaseBackend):
         else:
             self.is_chat_model = True
 
-        self.
+        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
+
+        # Usage
+        self.token_usage = TokenUsage(0, 0)
+
+        # API speculative execution
+        # TODO(ying): This does not support multi-threading (run_batch)
+        self.spec_kwargs = {}
+        self.spec_format = []
+        self.spec_max_num_tries = 3
 
     def get_chat_template(self):
         return self.chat_template
 
+    def _prepare_spec_execution(
+        self,
+        sampling_params: SglSamplingParams,
+        num_api_spec_tokens: int,
+        spec_var_name: str,
+    ):
+        if "max_tokens" not in self.spec_kwargs:
+            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
+        else:
+            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
+
+        params = sampling_params.to_openai_kwargs()
+        for key, value in params.items():
+            if key in ["stop"]:
+                continue
+            if key in ["max_tokens"]:
+                warnings.warn(
+                    "The parameter max_tokens will be overwritten by speculated number of tokens."
+                )
+                continue
+            if key not in self.spec_kwargs:
+                self.spec_kwargs[key] = value
+            else:
+                assert (
+                    value == self.spec_kwargs[key]
+                ), "sampling parameters should be consistent if turn on api speculative execution."
+        self.spec_format.append(
+            {"text": "", "stop": params["stop"], "name": spec_var_name}
+        )
+        return "", {}
+
     def generate(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        spec_var_name: str = None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if
-
-
-
+                if s.num_api_spec_tokens is None:
+                    if not s.text_.endswith(self.chat_prefix):
+                        raise RuntimeError(
+                            "This use case is not supported if api speculative execution is off. "
+                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
+                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
+                        )
+                    prompt = s.messages_
+                else:
+                    return self._prepare_spec_execution(
+                        sampling_params, s.num_api_spec_tokens, spec_var_name
                     )
-                prompt = s.messages_
             else:
                 prompt = s.text_
 
             kwargs = sampling_params.to_openai_kwargs()
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=prompt,
                 **kwargs,
             )
         elif sampling_params.dtype in [str, "str", "string"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_ + '"',
@@ -121,10 +185,14 @@ class OpenAI(BaseBackend):
             )
             comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_,
@@ -137,6 +205,63 @@ class OpenAI(BaseBackend):
 
         return comp, {}
 
+    def spec_fill(self, value: str):
+        assert self.is_chat_model
+        self.spec_format.append({"text": value, "stop": None, "name": None})
+
+    def spec_pattern_match(self, comp):
+        for i, term in enumerate(self.spec_format):
+            text = term["text"]
+            if text != "":
+                if comp.startswith(text):
+                    comp = comp[len(text) :]
+                else:
+                    return False
+            else:
+                pos = comp.find(term["stop"])
+                if pos != -1:
+                    term["text"] = comp[:pos]
+                    comp = comp[pos:]
+                else:
+                    if i == len(self.spec_format) - 1:
+                        term["text"] = comp
+                    else:
+                        return False
+        return True
+
+    def role_end_generate(
+        self,
+        s: StreamExecutor,
+    ):
+        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+            return
+
+        comp = ""
+        if not all(x["name"] is None for x in self.spec_format):
+            # TODO(ying): throw errors or warnings
+            for i in range(self.spec_max_num_tries):
+                comp = openai_completion(
+                    client=self.client,
+                    token_usage=self.token_usage,
+                    is_chat=self.is_chat_model,
+                    model=self.model_name,
+                    prompt=s.messages_,
+                    **self.spec_kwargs,
+                )
+                if self.spec_pattern_match(comp):
+                    break
+
+        for term in self.spec_format:
+            s.text_ += term["text"]
+            name = term["name"]
+            if name is not None:
+                s.variables[name] = term["text"]
+                s.meta_info[name] = {}
+                s.variable_event[name].set()
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+
     def generate_stream(
         self,
         s: StreamExecutor,
@@ -144,7 +269,7 @@ class OpenAI(BaseBackend):
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.
+                if not s.text_.endswith(self.chat_prefix):
                     raise RuntimeError(
                         "This use case is not supported. "
                         "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
@@ -156,6 +281,7 @@ class OpenAI(BaseBackend):
         kwargs = sampling_params.to_openai_kwargs()
         generator = openai_completion_stream(
             client=self.client,
+            token_usage=self.token_usage,
             is_chat=self.is_chat_model,
             model=self.model_name,
             prompt=prompt,
@@ -201,6 +327,8 @@ class OpenAI(BaseBackend):
         )
         ret_str = ret.choices[0].text
         ret_token = self.tokenizer.encode(ret_str)[0]
+        self.token_usage.prompt_tokens += ret.usage.prompt_tokens
+        self.token_usage.completion_tokens = ret.usage.completion_tokens
 
         # TODO:
         # 1. return logits as the scores
@@ -227,10 +355,12 @@ class OpenAI(BaseBackend):
             prompt_tokens.append(ret_token)
 
         decision = choices[np.argmax(scores)]
-        return decision, scores,
+        return decision, scores, None, None
 
 
-def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
@@ -244,6 +374,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
                 comp = [c.text for c in ret.choices]
             else:
                 comp = ret.choices[0].text
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
@@ -257,16 +390,23 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
     return comp
 
 
-def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion_stream(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
                 if "stop" in kwargs and kwargs["stop"] is None:
                     kwargs.pop("stop")
                 generator = client.chat.completions.create(
-                    messages=prompt,
+                    messages=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     try:
                         content = ret.choices[0].delta.content
                     except IndexError:
@@ -274,11 +414,19 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
                     yield content or "", {}
             else:
                 generator = client.completions.create(
-                    prompt=prompt,
+                    prompt=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     content = ret.choices[0].text
                     yield content or "", {}
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
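To show how the two additions above fit together, here is a sketch of reading the new `TokenUsage` counters and of API speculative execution, where `@function(num_api_spec_tokens=...)` lets one chat completion be requested and then split across several `sgl.gen()` calls; the model name, token budget, and prompts are illustrative:

```python
import sglang as sgl

backend = sgl.OpenAI("gpt-3.5-turbo")
sgl.set_default_backend(backend)


@sgl.function(num_api_spec_tokens=256)
def character_card(s, name):
    s += sgl.system("You fill out character cards tersely.")
    s += sgl.user(f"Describe {name}.")
    s += sgl.assistant(
        "Occupation: " + sgl.gen("job", stop="\n")
        + "\nCatchphrase: " + sgl.gen("quote", stop="\n")
    )


state = character_card.run(name="a retired lighthouse keeper")
print(state["job"], state["quote"])

# Usage is accumulated on the backend across calls and can be reset.
print("prompt tokens:", backend.token_usage.prompt_tokens)
print("completion tokens:", backend.token_usage.completion_tokens)
backend.token_usage.reset()
```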