sglang 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +14 -0
- sglang/backend/anthropic.py +18 -12
- sglang/backend/base_backend.py +6 -0
- sglang/backend/openai.py +41 -12
- sglang/backend/runtime_endpoint.py +57 -6
- sglang/lang/chat_template.py +47 -26
- sglang/lang/interpreter.py +15 -2
- sglang/lang/ir.py +1 -1
- sglang/srt/constrained/__init__.py +23 -1
- sglang/srt/constrained/fsm_cache.py +14 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -1
- sglang/srt/layers/extend_attention.py +7 -6
- sglang/srt/layers/radix_attention.py +2 -10
- sglang/srt/layers/token_attention.py +12 -4
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/router/infer_batch.py +6 -2
- sglang/srt/managers/router/model_rpc.py +45 -32
- sglang/srt/managers/router/model_runner.py +40 -25
- sglang/srt/managers/tokenizer_manager.py +2 -0
- sglang/srt/model_config.py +12 -5
- sglang/srt/models/gemma.py +340 -0
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llava.py +2 -4
- sglang/srt/models/mixtral.py +5 -5
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +5 -5
- sglang/srt/models/stablelm.py +293 -0
- sglang/srt/server.py +111 -47
- sglang/srt/server_args.py +44 -9
- sglang/srt/utils.py +1 -0
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +15 -12
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/METADATA +16 -6
- sglang-0.1.14.dist-info/RECORD +64 -0
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/WHEEL +1 -1
- sglang/srt/models/gpt_neox.py +0 -274
- sglang-0.1.12.dist-info/RECORD +0 -63
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/LICENSE +0 -0
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
sglang/api.py
CHANGED
@@ -44,6 +44,20 @@ def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
 
+def flush_cache(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return False
+    return backend.flush_cache()
+
+
+def get_server_args(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return None
+    return backend.get_server_args()
+
+
 def gen(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
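The two new helpers simply forward to the active backend. A minimal usage sketch, assuming a locally running sglang server (the address is illustrative, not taken from the diff):

from sglang.api import flush_cache, get_server_args
from sglang.backend.runtime_endpoint import RuntimeEndpoint

backend = RuntimeEndpoint("http://localhost:30000")  # illustrative address

# With an explicit backend argument no global default is needed; without one,
# both helpers fall back to global_config.default_backend.
print(flush_cache(backend))      # True if the server accepted /flush_cache
print(get_server_args(backend))  # dict of the server's launch arguments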
sglang/backend/anthropic.py
CHANGED
@@ -30,13 +30,17 @@ class Anthropic(BaseBackend):
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
     ):
-        prompt = s.text_
-        ret = anthropic.Anthropic().completions.create(
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = anthropic.Anthropic().messages.create(
             model=self.model_name,
-            prompt=prompt,
+            messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         )
-        comp = ret.completion
+        comp = ret.content[0].text
 
         return comp, {}
 
@@ -45,13 +49,15 @@ class Anthropic(BaseBackend):
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
     ):
-        prompt = s.text_
-        generator = anthropic.Anthropic().completions.create(
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        with anthropic.Anthropic().messages.stream(
             model=self.model_name,
-            prompt=prompt,
-            stream=True,
+            messages=messages,
             **sampling_params.to_anthropic_kwargs(),
-        )
-
-        for ret in generator:
-            yield ret.completion, {}
+        ) as stream:
+            for text in stream.text_stream:
+                yield text, {}
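The backend now targets Anthropic's Messages API rather than the legacy Completions API. A hedged, standalone sketch of the two call shapes used above (model name and prompts are illustrative):

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

# Non-streaming: the reply arrives as content blocks; block 0 carries the text.
ret = client.messages.create(
    model="claude-3-opus-20240229",  # illustrative model name
    max_tokens=256,
    messages=[{"role": "user", "content": "Say hello."}],
)
print(ret.content[0].text)

# Streaming: the SDK's context manager exposes incremental text via text_stream.
with client.messages.stream(
    model="claude-3-opus-20240229",
    max_tokens=256,
    messages=[{"role": "user", "content": "Count to three."}],
) as stream:
    for text in stream.text_stream:
        print(text, end="")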
sglang/backend/base_backend.py
CHANGED
sglang/backend/openai.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
 
 import numpy as np
 from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template
+from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -41,23 +41,45 @@ INSTRUCT_MODEL_NAMES = [
 
 
 class OpenAI(BaseBackend):
-    def __init__(self, model_name, *args, **kwargs):
+    def __init__(
+        self,
+        model_name: str,
+        is_chat_model: Optional[bool] = None,
+        chat_template: Optional[ChatTemplate] = None,
+        is_azure: bool = False,
+        *args,
+        **kwargs,
+    ):
         super().__init__()
 
         if isinstance(openai, Exception):
             raise openai
 
-        self.client = openai.OpenAI(*args, **kwargs)
+        if is_azure:
+            self.client = openai.AzureOpenAI(*args, **kwargs)
+        else:
+            self.client = openai.OpenAI(*args, **kwargs)
+
         self.model_name = model_name
-        self.tokenizer = tiktoken.encoding_for_model(model_name)
+        try:
+            self.tokenizer = tiktoken.encoding_for_model(model_name)
+        except KeyError:
+            self.tokenizer = tiktoken.get_encoding("cl100k_base")
         self.logit_bias_int = create_logit_bias_int(self.tokenizer)
 
-        if model_name in INSTRUCT_MODEL_NAMES:
-            self.is_chat_model = False
+        self.chat_template = chat_template or get_chat_template_by_model_path(
+            model_name
+        )
+
+        if is_chat_model is not None:
+            self.is_chat_model = is_chat_model
         else:
-            self.is_chat_model = True
+            if model_name in INSTRUCT_MODEL_NAMES:
+                self.is_chat_model = False
+            else:
+                self.is_chat_model = True
 
-        self.chat_template = get_chat_template("default")
+        self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
 
     def get_chat_template(self):
         return self.chat_template
@@ -69,7 +91,7 @@ class OpenAI(BaseBackend):
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith("ASSISTANT:"):
+                if not s.text_.endswith(self.chat_begin_str):
                     raise RuntimeError(
                         "This use case is not supported. "
                         "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
@@ -122,7 +144,11 @@ class OpenAI(BaseBackend):
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-
+                if not s.text_.endswith(self.chat_begin_str):
+                    raise RuntimeError(
+                        "This use case is not supported. "
+                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+                    )
                 prompt = s.messages_
             else:
                 prompt = s.text_
@@ -137,7 +163,7 @@ class OpenAI(BaseBackend):
             )
             return generator
         else:
-            raise ValueError(f"Unknown dtype: {dtype}")
+            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
 
     def select(
         self,
@@ -241,7 +267,10 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
             messages=prompt, stream=True, **kwargs
         )
         for ret in generator:
-            content = ret.choices[0].delta.content
+            try:
+                content = ret.choices[0].delta.content
+            except IndexError:
+                content = None
             yield content or "", {}
     else:
         generator = client.completions.create(
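The reworked constructor lets callers override chat-model detection and the template, and route through Azure. A hedged construction sketch (endpoint and version strings are placeholders, not values from the diff):

from sglang.backend.openai import OpenAI

# Plain OpenAI: chat-vs-instruct and the chat template are inferred from the
# model name unless overridden explicitly.
backend = OpenAI("gpt-3.5-turbo-instruct", is_chat_model=False)

# Azure OpenAI: remaining *args/**kwargs are forwarded to openai.AzureOpenAI.
azure_backend = OpenAI(
    "gpt-4",
    is_azure=True,
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
    api_version="2024-02-01",                               # placeholder
)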
sglang/backend/runtime_endpoint.py
CHANGED
@@ -12,15 +12,26 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
 
 
 class RuntimeEndpoint(BaseBackend):
-    def __init__(self, base_url, auth_token=None):
+    def __init__(
+        self,
+        base_url: str,
+        auth_token: Optional[str] = None,
+        api_key: Optional[str] = None,
+        verify: Optional[str] = None,
+    ):
         super().__init__()
         self.support_concate_and_append = True
 
         self.base_url = base_url
         self.auth_token = auth_token
+        self.api_key = api_key
+        self.verify = verify
 
         res = http_request(
-            self.base_url + "/get_model_info", auth_token=self.auth_token
+            self.base_url + "/get_model_info",
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
         )
         assert res.status_code == 200
         self.model_info = res.json()
@@ -32,6 +43,22 @@ class RuntimeEndpoint(BaseBackend):
     def get_model_name(self):
         return self.model_info["model_path"]
 
+    def flush_cache(self):
+        res = http_request(
+            self.base_url + "/flush_cache",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.status_code == 200
+
+    def get_server_args(self):
+        res = http_request(
+            self.base_url + "/get_server_args",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.json()
+
     def get_chat_template(self):
         return self.chat_template
 
@@ -40,6 +67,8 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
             auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
         )
         assert res.status_code == 200
 
@@ -48,6 +77,8 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
             auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
         )
         assert res.status_code == 200
 
@@ -55,7 +86,11 @@ class RuntimeEndpoint(BaseBackend):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         self._add_images(s, data)
         res = http_request(
-            self.base_url + "/generate", json=data, auth_token=self.auth_token
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
         )
         assert res.status_code == 200
 
@@ -87,7 +122,11 @@ class RuntimeEndpoint(BaseBackend):
         self._add_images(s, data)
 
         res = http_request(
-            self.base_url + "/generate", json=data, auth_token=self.auth_token
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
         )
         obj = res.json()
         comp = obj["text"]
@@ -126,6 +165,8 @@ class RuntimeEndpoint(BaseBackend):
             json=data,
             stream=True,
             auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
         )
         pos = 0
 
@@ -157,7 +198,11 @@ class RuntimeEndpoint(BaseBackend):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         self._add_images(s, data)
         res = http_request(
-            self.base_url + "/generate", json=data, auth_token=self.auth_token
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
         )
         assert res.status_code == 200
         prompt_len = res.json()["meta_info"]["prompt_tokens"]
@@ -171,7 +216,11 @@ class RuntimeEndpoint(BaseBackend):
         }
         self._add_images(s, data)
         res = http_request(
-            self.base_url + "/generate", json=data, auth_token=self.auth_token
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
         )
         assert res.status_code == 200
         obj = res.json()
@@ -188,6 +237,8 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
             auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
         )
         assert res.status_code == 200
 
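A hedged sketch of the expanded constructor and the two new endpoints (URL, token, key, and CA-bundle path are illustrative):

from sglang.backend.runtime_endpoint import RuntimeEndpoint

backend = RuntimeEndpoint(
    "https://sglang.example.com:30000",  # illustrative URL
    auth_token="my-auth-token",          # sent with every request, as before
    api_key="my-api-key",                # new in 0.1.14
    verify="/path/to/ca-bundle.pem",     # new in 0.1.14, passed to the HTTP layer
)

print(backend.flush_cache())      # True if the server answered /flush_cache with 200
print(backend.get_server_args())  # launch arguments reported by /get_server_args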
sglang/lang/chat_template.py
CHANGED
@@ -12,42 +12,43 @@ class ChatTemplateStyle(Enum):
 class ChatTemplate:
     name: str
     default_system_prompt: str
-    role_prefix_and_suffix: Dict[str, Tuple[str]]
+    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
     stop_str: List[str] = ()
     image_token: str = "<image>"
     style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
 
-    def get_prefix_and_suffix(
-
-
-
-
-
-
-
-
+    def get_prefix_and_suffix(
+        self, role: str, hist_messages: List[Dict]
+    ) -> Tuple[str, str]:
+        prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
+
+        if self.style == ChatTemplateStyle.LLAMA2:
+            if role == "system" and not hist_messages:
+                user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
+                system_prefix, system_suffix = self.role_prefix_and_suffix.get(
+                    "system", ("", "")
                 )
+                return (user_prefix + system_prefix, system_suffix)
             elif (
-
-
+                role == "user"
+                and len(hist_messages) == 1
                 and hist_messages[0]["content"] is not None
             ):
-                return ("",
-
-
-        raise ValueError(f"Invalid style: {self.style}")
+                return ("", suffix)
+
+        return prefix, suffix
 
-    def get_prompt(self, messages):
+    def get_prompt(self, messages: List[Dict]) -> str:
         prompt = ""
-        for i in range(len(messages)):
-            role, content = messages[i]["role"], messages[i]["content"]
+        for i, message in enumerate(messages):
+            role, content = message["role"], message["content"]
             if role == "system" and content is None:
                 content = self.default_system_prompt
             if content is None:
                 continue
 
             prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
-            prompt += prefix + content + suffix
+            prompt += f"{prefix}{content}{suffix}"
         return prompt
 
 
@@ -106,9 +107,9 @@ register_chat_template(
         name="chatml",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": ("<|im_start|>system\n", "<|im_end|>"),
-            "user": ("<|im_start|>user\n", "<|im_end|>"),
-            "assistant": ("<|im_start|>assistant\n", "<|im_end|>"),
+            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
@@ -121,9 +122,9 @@ register_chat_template(
         name="chatml-llava",
         default_system_prompt="Answer the questions.",
         role_prefix_and_suffix={
-            "system": ("<|im_start|>system\n", "<|im_end|>"),
-            "user": ("<|im_start|>user\n", "<|im_end|>"),
-            "assistant": ("<|im_start|>assistant\n", "<|im_end|>"),
+            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
@@ -178,6 +179,19 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="gemma-it",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
+            "assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+    )
+)
+
 
 @register_chat_template_matching_function
 def match_vicuna(model_path: str):
@@ -218,6 +232,13 @@ def match_chat_yi(model_path: str):
     return get_chat_template("yi")
 
 
+@register_chat_template_matching_function
+def match_gemma_it(model_path: str):
+    model_path = model_path.lower()
+    if "gemma" in model_path and "it" in model_path:
+        return get_chat_template("gemma-it")
+
+
 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
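To make the new template concrete, a short sketch of the prompt it renders; the expected output in the comment follows directly from the prefix/suffix table registered above:

from sglang.lang.chat_template import get_chat_template

tmpl = get_chat_template("gemma-it")
prompt = tmpl.get_prompt([
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
])
# Expected rendering:
# <start_of_turn>user
# Hello<end_of_turn>
# <start_of_turn>model
# Hi there<end_of_turn>
print(prompt)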
sglang/lang/interpreter.py
CHANGED
@@ -245,6 +245,9 @@ class StreamExecutor:
             self.variable_event[name].wait()
         return self.variables[name]
 
+    def set_var(self, name, value):
+        self.variables[name] = value
+
     def get_meta_info(self, name):
         if name in self.variable_event:
             self.variable_event[name].wait()
@@ -583,6 +586,10 @@ class StreamExecutor:
         if self.chat_template.stop_str:
             if not clone:
                 clone = self.default_sampling_para.clone()
+            if clone.stop == ():
+                clone.stop = []
+            elif isinstance(clone.stop, str):
+                clone.stop = [clone.stop]
             clone.stop += self.chat_template.stop_str
 
         return clone or self.default_sampling_para
@@ -679,7 +686,7 @@ class ProgramState:
         if var_name is None:
             yield self.text()
         else:
-            yield self.get_var(name)
+            yield self.get_var(var_name)
 
     async def text_async_iter(
         self, var_name: Optional[str] = None, return_meta_data: bool = False
@@ -717,11 +724,14 @@ class ProgramState:
         if var_name is None:
             yield self.text()
         else:
-            yield self.get_var(name)
+            yield self.get_var(var_name)
 
     def get_var(self, name):
         return self.stream_executor.get_var(name)
 
+    def set_var(self, name, value):
+        return self.stream_executor.set_var(name, value)
+
     def get_meta_info(self, name):
         return self.stream_executor.get_meta_info(name)
 
@@ -732,6 +742,9 @@ class ProgramState:
     def __getitem__(self, name):
         return self.get_var(name)
 
+    def __setitem__(self, name, value):
+        self.set_var(name, value)
+
     def __del__(self):
         self.stream_executor.end()
 
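A hedged sketch of what the new setter path (set_var / __setitem__) allows on a finished program state; the program itself is illustrative and assumes a default backend has already been configured:

import sglang as sgl

@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

state = qa.run(question="What is 2 + 2?")

print(state["answer"])                     # reads via __getitem__, as before
state["answer"] = state["answer"].strip()  # new: writes back via __setitem__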
sglang/lang/ir.py
CHANGED
@@ -73,7 +73,7 @@ class SglSamplingParams:
                 "Regular expression is not supported in the Anthropic backend."
             )
         return {
-            "max_tokens_to_sample": self.max_new_tokens,
+            "max_tokens": self.max_new_tokens,
             "stop_sequences": (
                 self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
            ),
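For orientation, a hedged sketch of the dict this method now returns for illustrative values; only the keys visible in this hunk are shown, the remaining fields follow in the full method:

# SglSamplingParams(max_new_tokens=128, stop="</answer>").to_anthropic_kwargs()
# would now contain, among other fields:
{
    "max_tokens": 128,               # the key the Messages API expects
    "stop_sequences": ["</answer>"],
}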
sglang/srt/constrained/__init__.py
CHANGED
@@ -1,9 +1,31 @@
+import json
+from typing import Dict, Optional, Union
+
 from outlines.caching import cache as disk_cache
 from outlines.caching import disable_cache
 from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.json_schema import build_regex_from_object
 from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
 from outlines.models.transformers import TransformerTokenizer
+from pydantic import BaseModel
+
+try:
+    from outlines.fsm.json_schema import build_regex_from_object
+except ImportError:
+    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
+    # which only accepts string schema as input.
+    from outlines.fsm.json_schema import build_regex_from_schema
+
+    def build_regex_from_object(
+        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+    ):
+        if isinstance(object, type(BaseModel)):
+            schema = json.dumps(object.model_json_schema())
+        elif isinstance(object, Dict):
+            schema = json.dumps(object)
+        else:
+            schema = object
+        return build_regex_from_schema(schema, whitespace_pattern)
+
 
 __all__ = [
     "RegexFSM",
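A hedged sketch of the shim in use; it behaves the same whether outlines still ships build_regex_from_object or only the newer build_regex_from_schema (the Pydantic model is illustrative):

from pydantic import BaseModel
from sglang.srt.constrained import build_regex_from_object

class Answer(BaseModel):  # illustrative schema
    city: str
    population: int

# Accepts a Pydantic model class, a dict schema, or a JSON string; on newer
# outlines the wrapper serializes it and calls build_regex_from_schema.
regex = build_regex_from_object(Answer)
print(regex[:80])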
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -5,9 +5,20 @@ from sglang.srt.constrained.base_cache import BaseCache
 class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
-        self.outlines_tokenizer = TransformerTokenizer(
-            tokenizer_path, **tokenizer_args_dict
-        )
+
+        from importlib.metadata import version
+        if version("outlines") >= "0.0.35":
+            from transformers import AutoTokenizer
+
+            tokenizer_args_dict.setdefault("padding_side", "left")
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_path, **tokenizer_args_dict
+            )
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        else:
+            self.outlines_tokenizer = TransformerTokenizer(
+                tokenizer_path, **tokenizer_args_dict
+            )
 
     def init_value(self, regex):
         return RegexFSM(regex, self.outlines_tokenizer)
sglang/srt/layers/context_flashattention_nopad.py
CHANGED
@@ -129,7 +129,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
 
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
-    assert Lk in {16, 32, 64, 128}
+    assert Lk in {16, 32, 64, 128, 256}
 
     sm_scale = 1.0 / (Lq**0.5)
     batch, head = b_seq_len.shape[0], q.shape[1]
sglang/srt/layers/extend_attention.py
CHANGED
@@ -181,19 +181,20 @@ def extend_attention_fwd(
 
     k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
-    if CUDA_CAPABILITY[0] >= 8:
-        BLOCK_M, BLOCK_N = 128, 128
-    else:
-        BLOCK_M, BLOCK_N = 64, 64
-
     Lq, Lk, Lv, Lo = (
         q_extend.shape[-1],
         k_extend.shape[-1],
         v_extend.shape[-1],
         o_extend.shape[-1],
     )
+
     assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128}
+    assert Lq in {16, 32, 64, 128, 256}
+
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
+    else:
+        BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
 
     sm_scale = 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
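The widened assert and the Lq-dependent tile sizes presumably make room for 256-wide attention heads, matching the Gemma support added in this release. A hedged restatement of the selection rule in isolation:

def pick_blocks(head_dim: int, cuda_major: int) -> tuple:
    # Illustrative restatement of the BLOCK_M/BLOCK_N choice above.
    if cuda_major >= 8:  # Ampere or newer
        return (128, 128) if head_dim <= 128 else (64, 64)
    return (64, 64) if head_dim <= 128 else (32, 32)

assert pick_blocks(128, 8) == (128, 128)
assert pick_blocks(256, 8) == (64, 64)   # larger heads get smaller tiles
assert pick_blocks(256, 7) == (32, 32)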
sglang/srt/layers/radix_attention.py
CHANGED
@@ -1,15 +1,9 @@
-from typing import List
-
 import torch
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 
 
 class RadixAttention(nn.Module):
@@ -21,11 +15,9 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id
 
-        from sglang.srt.managers.router.model_runner import global_model_mode
-
-        self.use_flashinfer = "flashinfer" in global_model_mode
+        from sglang.srt.managers.router.model_runner import global_server_args_dict
 
-        if self.use_flashinfer:
+        if global_server_args_dict.get("enable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
|