sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +976 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -2
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +39 -24
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
- sglang/srt/managers/controller/infer_batch.py +90 -63
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +41 -26
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +136 -149
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +32 -11
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +81 -23
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +132 -84
- sglang/srt/server_args.py +35 -21
- sglang/srt/utils.py +65 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
- sglang-0.1.24.dist-info/RECORD +105 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/lang/backend/runtime_endpoint.py
ADDED
@@ -0,0 +1,283 @@
+import json
+from typing import List, Optional
+
+import numpy as np
+
+from sglang.global_config import global_config
+from sglang.lang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import http_request
+
+
+class RuntimeEndpoint(BaseBackend):
+    def __init__(
+        self,
+        base_url: str,
+        auth_token: Optional[str] = None,
+        api_key: Optional[str] = None,
+        verify: Optional[str] = None,
+    ):
+        super().__init__()
+        self.support_concate_and_append = True
+
+        self.base_url = base_url
+        self.auth_token = auth_token
+        self.api_key = api_key
+        self.verify = verify
+
+        res = http_request(
+            self.base_url + "/get_model_info",
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        self.model_info = res.json()
+
+        self.chat_template = get_chat_template_by_model_path(
+            self.model_info["model_path"]
+        )
+
+    def get_model_name(self):
+        return self.model_info["model_path"]
+
+    def flush_cache(self):
+        res = http_request(
+            self.base_url + "/flush_cache",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+
+    def get_server_args(self):
+        res = http_request(
+            self.base_url + "/get_server_args",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        return res.json()
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def cache_prefix(self, prefix_str: str):
+        res = http_request(
+            self.base_url + "/generate",
+            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+
+    def commit_lazy_operations(self, s: StreamExecutor):
+        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+
+    def fill_image(self, s: StreamExecutor):
+        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if sampling_params.dtype is None:
+            data = {
+                "text": s.text_,
+                "sampling_params": {
+                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+                    **sampling_params.to_srt_kwargs(),
+                },
+            }
+        elif sampling_params.dtype in [int, "int"]:
+            data = {
+                "text": s.text_,
+                "sampling_params": {
+                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+                    "dtype": "int",
+                    **sampling_params.to_srt_kwargs(),
+                },
+            }
+        else:
+            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
+        self._add_images(s, data)
+
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+
+        obj = res.json()
+        comp = obj["text"]
+        return comp, obj["meta_info"]
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if sampling_params.dtype is None:
+            data = {
+                "text": s.text_,
+                "sampling_params": {
+                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+                    **sampling_params.to_srt_kwargs(),
+                },
+            }
+        elif sampling_params.dtype in [int, "int"]:
+            data = {
+                "text": s.text_,
+                "sampling_params": {
+                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+                    "dtype": "int",
+                    **sampling_params.to_srt_kwargs(),
+                },
+            }
+        else:
+            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
+        data["stream"] = True
+        self._add_images(s, data)
+
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            stream=True,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        pos = 0
+
+        for chunk in res.iter_lines(decode_unicode=False):
+            chunk = chunk.decode("utf-8")
+            if chunk and chunk.startswith("data:"):
+                if chunk == "data: [DONE]":
+                    break
+                data = json.loads(chunk[5:].strip("\n"))
+                chunk_text = data["text"][pos:]
+                meta_info = data["meta_info"]
+                pos += len(chunk_text)
+                yield chunk_text, meta_info
+
+    def select(
+        self,
+        s: StreamExecutor,
+        choices: List[str],
+        temperature: float,
+    ):
+        assert temperature <= 1e-5
+
+        # Cache common prefix
+        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        prompt_len = res.json()["meta_info"]["prompt_tokens"]
+
+        # Compute logprob
+        data = {
+            "text": [s.text_ + c for c in choices],
+            "sampling_params": {"max_new_tokens": 0},
+            "return_logprob": True,
+            "logprob_start_len": max(prompt_len - 2, 0),
+        }
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        obj = res.json()
+        normalized_prompt_logprobs = [
+            r["meta_info"]["normalized_prompt_logprob"] for r in obj
+        ]
+        decision = choices[np.argmax(normalized_prompt_logprobs)]
+        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
+        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
+
+        return (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        )
+
+    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
+        res = http_request(
+            self.base_url + "/concate_and_append_request",
+            json={"src_rids": src_rids, "dst_rid": dst_rid},
+            auth_token=self.auth_token,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+
+    def _add_images(self, s: StreamExecutor, data):
+        if s.images_:
+            assert len(s.images_) == 1, "Only support one image."
+            data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
sglang/lang/backend/vertexai.py
ADDED
@@ -0,0 +1,149 @@
+import os
+import warnings
+from typing import Optional
+
+from sglang.lang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+
+try:
+    import vertexai
+    from vertexai.preview.generative_models import (
+        GenerationConfig,
+        GenerativeModel,
+        Image,
+    )
+except ImportError as e:
+    GenerativeModel = e
+
+
+class VertexAI(BaseBackend):
+    def __init__(self, model_name, safety_settings=None):
+        super().__init__()
+
+        if isinstance(GenerativeModel, Exception):
+            raise GenerativeModel
+
+        project_id = os.environ["GCP_PROJECT_ID"]
+        location = os.environ.get("GCP_LOCATION")
+        vertexai.init(project=project_id, location=location)
+
+        self.model_name = model_name
+        self.chat_template = get_chat_template("default")
+        self.safety_settings = safety_settings
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            prompt = self.messages_to_vertexai_input(s.messages_)
+        else:
+            # single-turn
+            prompt = (
+                self.text_to_vertexai_input(s.text_, s.cur_images)
+                if s.cur_images
+                else s.text_
+            )
+        ret = GenerativeModel(self.model_name).generate_content(
+            prompt,
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+            safety_settings=self.safety_settings,
+        )
+
+        comp = ret.text
+
+        return comp, {}
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            prompt = self.messages_to_vertexai_input(s.messages_)
+        else:
+            # single-turn
+            prompt = (
+                self.text_to_vertexai_input(s.text_, s.cur_images)
+                if s.cur_images
+                else s.text_
+            )
+        generator = GenerativeModel(self.model_name).generate_content(
+            prompt,
+            stream=True,
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+            safety_settings=self.safety_settings,
+        )
+        for ret in generator:
+            yield ret.text, {}
+
+    def text_to_vertexai_input(self, text, images):
+        input = []
+        # split with image token
+        text_segs = text.split(self.chat_template.image_token)
+        for image_path, image_base64_data in images:
+            text_seg = text_segs.pop(0)
+            if text_seg != "":
+                input.append(text_seg)
+            input.append(Image.from_bytes(image_base64_data))
+        text_seg = text_segs.pop(0)
+        if text_seg != "":
+            input.append(text_seg)
+        return input
+
+    def messages_to_vertexai_input(self, messages):
+        vertexai_message = []
+        # from openai message format to vertexai message format
+        for msg in messages:
+            if isinstance(msg["content"], str):
+                text = msg["content"]
+            else:
+                text = msg["content"][0]["text"]
+
+            if msg["role"] == "system":
+                warnings.warn("Warning: system prompt is not supported in VertexAI.")
+                vertexai_message.append(
+                    {
+                        "role": "user",
+                        "parts": [{"text": "System prompt: " + text}],
+                    }
+                )
+                vertexai_message.append(
+                    {
+                        "role": "model",
+                        "parts": [{"text": "Understood."}],
+                    }
+                )
+                continue
+            if msg["role"] == "user":
+                vertexai_msg = {
+                    "role": "user",
+                    "parts": [{"text": text}],
+                }
+            elif msg["role"] == "assistant":
+                vertexai_msg = {
+                    "role": "model",
+                    "parts": [{"text": text}],
+                }
+
+            # images
+            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
+                for image in msg["content"][1:]:
+                    assert image["type"] == "image_url"
+                    vertexai_msg["parts"].append(
+                        {
+                            "inline_data": {
+                                "data": image["image_url"]["url"].split(",")[1],
+                                "mime_type": "image/jpeg",
+                            }
+                        }
+                    )
+
+            vertexai_message.append(vertexai_msg)
+        return vertexai_message
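Similarly, a minimal sketch of selecting this relocated VertexAI backend, assuming GCP credentials plus GCP_PROJECT_ID (and optionally GCP_LOCATION) are configured and the Vertex AI SDK is installed; the model name and prompt are illustrative:

import sglang as sgl
from sglang.lang.backend.vertexai import VertexAI

# "gemini-pro" is only an example model name; any Vertex AI generative model works.
sgl.set_default_backend(VertexAI("gemini-pro"))

@sgl.function
def greet(s, name):
    # Single-turn text generation through the Vertex AI backend.
    s += "Write a one-sentence greeting for " + name + ".\n"
    s += sgl.gen("greeting", max_tokens=32)

state = greet.run(name="Ada")
print(state["greeting"])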
sglang/lang/interpreter.py
CHANGED
@@ -288,6 +288,7 @@ class StreamExecutor:
             exes[i].text_ = str(self.text_)
             exes[i].messages_ = list(self.messages_)
             exes[i].cur_role = self.cur_role
+            exes[i].cur_role_begin_pos = self.cur_role_begin_pos
             exes[i].fork_start_text_pos = len(self.text_)
             exes[i].images_ = list(self.images_)

sglang/lang/tracer.py
CHANGED
@@ -3,8 +3,8 @@
 import uuid
 from typing import Any, Callable, Dict, List, Optional, Union

-from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
+from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.interpreter import ProgramState, ProgramStateGroup
 from sglang.lang.ir import (
     SglArgument,
sglang/launch_server.py
CHANGED
sglang/launch_server_llavavid.py
CHANGED
@@ -1,7 +1,6 @@
 """Launch the inference server for Llava-video model."""

 import argparse
-import multiprocessing as mp

 from sglang.srt.server import ServerArgs, launch_server

@@ -27,6 +26,4 @@ if __name__ == "__main__":

     server_args = ServerArgs.from_cli_args(args)

-
-
-    launch_server(server_args, pipe_writer, model_overide_args)
+    launch_server(server_args, model_overide_args, None)
sglang/srt/conversation.py
CHANGED
@@ -6,7 +6,7 @@ import dataclasses
 from enum import IntEnum, auto
 from typing import Dict, List, Optional, Tuple, Union

-from sglang.srt.
+from sglang.srt.openai_api.protocol import ChatCompletionRequest


 class SeparatorStyle(IntEnum):
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -4,19 +4,26 @@ import functools
 import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Literal, Optional, Union
+from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union

 from huggingface_hub import snapshot_download
 from transformers import (
     AutoConfig,
     AutoProcessor,
     AutoTokenizer,
+    PretrainedConfig,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
+from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

 from sglang.srt.utils import is_multimodal_model

+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+    ChatGLMConfig.model_type: ChatGLMConfig,
+    DbrxConfig.model_type: DbrxConfig,
+}
+

 def download_from_hf(model_path: str):
     if os.path.exists(model_path):
@@ -40,6 +47,9 @@ def get_config(
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if config.model_type in _CONFIG_REGISTRY:
+        config_class = _CONFIG_REGISTRY[config.model_type]
+        config = config_class.from_pretrained(model, revision=revision)
     if model_overide_args:
         config.update(model_overide_args)
     return config
@@ -63,6 +73,8 @@ def get_context_length(config):
     rope_scaling = getattr(config, "rope_scaling", None)
     if rope_scaling:
         rope_scaling_factor = config.rope_scaling["factor"]
+        if config.rope_scaling["rope_type"] == "llama3":
+            rope_scaling_factor = 1
     else:
         rope_scaling_factor = 1

sglang/srt/layers/context_flashattention_nopad.py
CHANGED
@@ -4,8 +4,6 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.utils import wrap_kernel_launcher
-
 CUDA_CAPABILITY = torch.cuda.get_device_capability()


@@ -119,9 +117,6 @@ def _fwd_kernel(
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)


-cached_kernel = None
-
-
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     if CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
@@ -139,29 +134,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
     num_warps = 4 if Lk <= 64 else 8

-    global cached_kernel
-    if cached_kernel:
-        cached_kernel(
-            grid,
-            num_warps,
-            q,
-            k,
-            v,
-            sm_scale,
-            b_start_loc,
-            b_seq_len,
-            o,
-            q.stride(0),
-            q.stride(1),
-            k.stride(0),
-            k.stride(1),
-            v.stride(0),
-            v.stride(1),
-            o.stride(0),
-            o.stride(1),
-        )
-        return
-
     _fwd_kernel[grid](
         q,
         k,
@@ -185,4 +157,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
sglang/srt/layers/extend_attention.py
CHANGED
@@ -3,7 +3,6 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
-from sglang.srt.utils import wrap_kernel_launcher

 CUDA_CAPABILITY = torch.cuda.get_device_capability()

@@ -172,9 +171,6 @@ def _fwd_kernel(
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])


-cached_kernel = None
-
-
 def extend_attention_fwd(
     q_extend,
     k_extend,
@@ -222,40 +218,6 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1

-    global cached_kernel
-    if cached_kernel:
-        cached_kernel(
-            grid,
-            num_warps,
-            q_extend,
-            k_extend,
-            v_extend,
-            o_extend,
-            k_buffer,
-            v_buffer,
-            req_to_tokens,
-            b_req_idx,
-            b_seq_len,
-            b_start_loc_extend,
-            b_seq_len_extend,
-            sm_scale,
-            kv_group_num,
-            q_extend.stride(0),
-            q_extend.stride(1),
-            k_extend.stride(0),
-            k_extend.stride(1),
-            v_extend.stride(0),
-            v_extend.stride(1),
-            o_extend.stride(0),
-            o_extend.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -290,7 +252,6 @@ def extend_attention_fwd(
         num_stages=num_stages,
         logit_cap=logit_cap,
     )
-    cached_kernel = wrap_kernel_launcher(_fwd_kernel)


 def redundant_attention(