sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/test/test_programs.py
CHANGED
@@ -1,6 +1,4 @@
-"""
-This file contains the SGL programs used for unit testing.
-"""
+"""This file contains the SGL programs used for unit testing."""
 
 import json
 import re
@@ -226,7 +224,7 @@ Action 3: Finish [United States].\n
 
 def test_parallel_decoding():
     max_tokens = 64
-
+    fork_size = 5
 
     @sgl.function
     def parallel_decoding(s, topic):
@@ -234,17 +232,17 @@ def test_parallel_decoding():
         s += "USER: Give some tips for " + topic + ".\n"
         s += (
             "ASSISTANT: Okay. Here are "
-            + str(
+            + str(fork_size)
             + " concise tips, each under 8 words:\n"
         )
 
         # Generate skeleton
-        for i in range(1, 1 +
+        for i in range(1, 1 + fork_size):
             s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"
 
         # Generate detailed tips
-        forks = s.fork(
-        for i in range(
+        forks = s.fork(fork_size)
+        for i in range(fork_size):
             forks[
                 i
             ] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:"
@@ -253,7 +251,7 @@ def test_parallel_decoding():
 
         # Concatenate tips and summarize
         s += "Here are these tips with detailed explanation:\n"
-        for i in range(
+        for i in range(fork_size):
             s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n"
 
         s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
@@ -296,7 +294,7 @@ def test_parallel_encoding(check_answer=True):
 def test_image_qa():
     @sgl.function
     def image_qa(s, question):
-        s += sgl.user(sgl.image("
+        s += sgl.user(sgl.image("example_image.png") + question)
         s += sgl.assistant(sgl.gen("answer"))
 
     state = image_qa.run(
@@ -304,6 +302,7 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
+
     assert (
         "taxi" in state.messages()[-1]["content"]
         or "car" in state.messages()[-1]["content"]
@@ -313,6 +312,7 @@ def test_image_qa():
 def test_stream():
     @sgl.function
     def qa(s, question):
+        s += sgl.system("You are a helpful assistant.")
         s += sgl.user(question)
         s += sgl.assistant(sgl.gen("answer"))
 
@@ -348,3 +348,66 @@ def test_regex():
     state = regex_gen.run()
     answer = state["answer"]
     assert re.match(regex, answer)
+
+
+def test_completion_speculative():
+    @sgl.function(num_api_spec_tokens=64)
+    def gen_character_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    @sgl.function
+    def gen_character_no_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    token_usage = sgl.global_config.default_backend.token_usage
+
+    token_usage.reset()
+    gen_character_spec().sync()
+    usage_with_spec = token_usage.prompt_tokens
+
+    token_usage.reset()
+    gen_character_no_spec().sync()
+    usage_with_no_spec = token_usage.prompt_tokens
+
+    assert (
+        usage_with_spec < usage_with_no_spec
+    ), f"{usage_with_spec} vs {usage_with_no_spec}"
+
+
+def test_chat_completion_speculative():
+    @sgl.function(num_api_spec_tokens=256)
+    def gen_character_spec(s):
+        s += sgl.system("You are a helpful assistant.")
+        s += sgl.user("Construct a character within the following format:")
+        s += sgl.assistant(
+            "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        )
+        s += sgl.user("Please generate new Name, Birthday and Job.\n")
+        s += sgl.assistant(
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+            + "\nJob:"
+            + sgl.gen("job", stop="\n")
+        )
+
+    gen_character_spec().sync()
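As a rough illustration of what the new speculative tests exercise: `num_api_spec_tokens` lets an API backend speculatively extend a completion so that consecutive `sgl.gen` calls can be served from fewer API requests, which is what the prompt-token comparison above checks. The sketch below is not part of the diff; the backend choice and model name are assumptions, and the actual savings depend on the backend.

import sglang as sgl

# Assumption: an OpenAI completion backend is configured as the default
# backend, mirroring the setup the speculative tests expect.
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))

# With num_api_spec_tokens set, the program may reuse speculatively generated
# text for the later gen() calls instead of issuing one request per gen().
@sgl.function(num_api_spec_tokens=64)
def gen_character(s):
    s += "Name:" + sgl.gen("name", stop="\n")
    s += "\nBirthday:" + sgl.gen("birthday", stop="\n")
    s += "\nJob:" + sgl.gen("job", stop="\n")

state = gen_character.run()
print(state["name"], state["birthday"], state["job"])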
sglang/test/test_utils.py
CHANGED
@@ -1,13 +1,20 @@
 """Common utilities for testing and benchmarking"""
 
+import asyncio
+from functools import partial
+
 import numpy as np
 import requests
+
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
+from sglang.utils import get_exception_traceback
 
 
-def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
+def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
+    assert url is not None
+
     data = {
         "inputs": prompt,
         "parameters": {
@@ -22,7 +29,9 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
     return pred
 
 
-def call_generate_vllm(prompt, temperature, max_tokens, stop,
+def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
+    assert url is not None
+
     data = {
         "prompt": prompt,
         "temperature": temperature,
@@ -40,8 +49,10 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
 
 
 def call_generate_outlines(
-    prompt, temperature, max_tokens,
+    prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
 ):
+    assert url is not None
+
     data = {
         "prompt": prompt,
         "temperature": temperature,
@@ -59,7 +70,9 @@ def call_generate_outlines(
     return pred
 
 
-def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
+def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
+    assert url is not None
+
     data = {
         "text": prompt,
         "sampling_params": {
@@ -75,7 +88,98 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
     return pred
 
 
-def
+def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
+    import grpc
+    from ginfer import sampler_pb2, sampler_pb2_grpc
+
+    sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
+    sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
+
+    if stop is None:
+        stop_strings = None
+    else:
+        stop_strings = [stop]
+
+    sample_request = sampler_pb2.SampleTextRequest(
+        prompt=prompt,
+        settings=sampler_pb2.SampleSettings(
+            max_len=max_tokens,
+            rng_seed=0,
+            temperature=max(temperature, 1e-7),
+            nucleus_p=1,
+            stop_strings=stop_strings,
+        ),
+    )
+    stream = sampler.SampleText(sample_request)
+    response = "".join([x.text for x in stream])
+    return response
+
+
+def call_generate_guidance(
+    prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
+):
+    assert model is not None
+    from guidance import gen
+
+    rets = []
+    for _ in range(n):
+        out = (
+            model
+            + prompt
+            + gen(
+                name="answer",
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop,
+                regex=regex,
+            )
+        )
+        rets.append(out["answer"])
+    return rets if n > 1 else rets[0]
+
+
+async def call_generate_lmql(
+    prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
+):
+    assert model is not None
+    import lmql
+
+    if stop != None:
+
+        @lmql.query(model=model)
+        async def program(question, max_tokens, stop):
+            '''lmql
+            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
+            return ANSWER
+            '''
+
+    else:
+
+        @lmql.query(model=model)
+        async def program(question, max_tokens):
+            '''lmql
+            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
+            return ANSWER
+            '''
+
+    tasks = [
+        program(
+            question=prompt,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            stop=stop,
+            max_len=max_len,
+            **kwargs,
+        )
+        for _ in range(n)
+    ]
+    rets = await asyncio.gather(*tasks)
+    return rets if n > 1 else rets[0]
+
+
+def call_select_lightllm(context, choices, url=None):
+    assert url is not None
+
     scores = []
     for i in range(len(choices)):
         data = {
@@ -90,7 +194,9 @@ def call_select_lightllm(context, choices, url):
     return np.argmax(scores)
 
 
-def call_select_vllm(context, choices, url):
+def call_select_vllm(context, choices, url=None):
+    assert url is not None
+
     scores = []
     for i in range(len(choices)):
         data = {
@@ -112,6 +218,31 @@ def call_select_vllm(context, choices, url):
     """
 
 
+def call_select_guidance(context, choices, model=None):
+    assert model is not None
+    from guidance import select
+
+    out = model + context + select(choices, name="answer")
+    return choices.index(out["answer"])
+
+
+async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
+    assert model is not None
+    import lmql
+
+    @lmql.query(model=model)
+    async def program(ctx, choices):
+        '''lmql
+        """{ctx}[ANSWER]""" where ANSWER in set(choices)
+        return ANSWER
+        '''
+
+    answer = await program(
+        ctx=context, choices=choices, temperature=temperature, max_len=max_len
+    )
+    return choices.index(answer)
+
+
 def add_common_other_args_and_parse(parser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
@@ -120,8 +251,18 @@ def add_common_other_args_and_parse(parser):
         "--backend",
         type=str,
         required=True,
-        choices=[
+        choices=[
+            "vllm",
+            "outlines",
+            "lightllm",
+            "ginfer",
+            "guidance",
+            "lmql",
+            "srt-raw",
+            "llama.cpp",
+        ],
     )
+    parser.add_argument("--n-ctx", type=int, default=4096)
     parser.add_argument(
         "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
     )
@@ -131,9 +272,11 @@ def add_common_other_args_and_parse(parser):
     if args.port is None:
         default_port = {
            "vllm": 21000,
+            "outlines": 21000,
            "lightllm": 22000,
            "lmql": 23000,
            "srt-raw": 30000,
+            "ginfer": 9988,
        }
        args.port = default_port.get(args.backend, None)
    return args
@@ -160,3 +303,79 @@ def select_sglang_backend(args):
     else:
         raise ValueError(f"Invalid backend: {args.backend}")
     return backend
+
+
+def _get_call_generate(args):
+    if args.backend == "lightllm":
+        return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "vllm":
+        return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "srt-raw":
+        return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "ginfer":
+        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
+    elif args.backend == "outlines":
+        return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "guidance":
+        from guidance import models
+
+        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
+        call_generate = partial(call_generate_guidance, model=model)
+        call_generate("Hello,", 1.0, 8, ".")
+        return call_generate
+    elif args.backend == "lmql":
+        import lmql
+
+        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
+        return partial(call_generate_lmql, model=model)
+    else:
+        raise ValueError(f"Invalid backend: {args.backend}")
+
+
+def _get_call_select(args):
+    if args.backend == "lightllm":
+        return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "vllm":
+        return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "guidance":
+        from guidance import models
+
+        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
+        call_select = partial(call_select_guidance, model=model)
+
+        call_select("Hello,", ["world", "earth"])
+        return call_select
+
+    elif args.backend == "lmql":
+        import lmql
+
+        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
+        return partial(call_select_lmql, model=model)
+    else:
+        raise ValueError(f"Invalid backend: {args.backend}")
+
+
+def get_call_generate(args):
+    call_generate = _get_call_generate(args)
+
+    def func(*args, **kwargs):
+        try:
+            return call_generate(*args, **kwargs)
+        except Exception:
+            print("Exception in call_generate:\n" + get_exception_traceback())
+            raise
+
+    return func
+
+
+def get_call_select(args):
+    call_select = _get_call_select(args)
+
+    def func(*args, **kwargs):
+        try:
+            return call_select(*args, **kwargs)
+        except Exception:
+            print("Exception in call_select:\n" + get_exception_traceback())
+            raise
+
+    return func
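A minimal sketch of how the new wrappers might be used from a benchmark script, assuming the standard argparse flow added above (the script would be launched with --backend and related flags); the prompt string is illustrative:

import argparse
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate

# add_common_other_args_and_parse adds --backend (required), --host, --port,
# --model-path and the new --n-ctx flag, parses them, and returns the args.
parser = argparse.ArgumentParser()
args = add_common_other_args_and_parse(parser)

# get_call_generate binds the backend-specific caller to its endpoint and wraps
# it so exceptions are printed with a full traceback before being re-raised.
call_generate = get_call_generate(args)
answer = call_generate("Q: What is the capital of France?\nA:", temperature=0, max_tokens=16)
print(answer)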
sglang/utils.py
CHANGED
@@ -2,40 +2,26 @@
 
 import base64
 import json
+import logging
+import signal
+import sys
 import threading
+import traceback
 import urllib.request
+from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
 
+import numpy as np
 import requests
 
+logger = logging.getLogger(__name__)
 
-def get_available_gpu_memory(gpu_id, distributed=True):
-    """
-    Get available memory for cuda:gpu_id device.
-    When distributed is True, the available memory is the minimum available memory of all GPUs.
-    """
-    import torch
 
-
-
-
-
-        print(
-            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
-            "which may cause useless memory allocation for torch CUDA context.",
-        )
-
-    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
-
-    if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device("cuda", gpu_id)
-        )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
-        free_gpu_memory = tensor.item()
-
-    return free_gpu_memory / (1 << 30)
+def get_exception_traceback():
+    etype, value, tb = sys.exc_info()
+    err_str = "".join(traceback.format_exception(etype, value, tb))
+    return err_str
 
 
 def is_same_type(values):
@@ -110,8 +96,12 @@ def http_request(
         data = None
     else:
        data = bytes(dumps(json), encoding="utf-8")
-
-
+
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)
 
 
 def encode_image_base64(image_path):
@@ -130,6 +120,75 @@ def encode_image_base64(image_path):
     return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
+def encode_frame(frame):
+    import cv2  # pip install opencv-python-headless
+    from PIL import Image
+
+    # Convert the frame to RGB (OpenCV uses BGR by default)
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+    # Convert the frame to PIL Image to easily convert to bytes
+    im_pil = Image.fromarray(frame)
+
+    # Convert to bytes
+    buffered = BytesIO()
+
+    # frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    im_pil.save(buffered, format="PNG")
+
+    frame_bytes = buffered.getvalue()
+
+    # Return the bytes of the frame
+    return frame_bytes
+
+
+def encode_video_base64(video_path, num_frames=16):
+    import cv2  # pip install opencv-python-headless
+
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        raise IOError(f"Could not open video file:{video_path}")
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    print(f"target_frames: {num_frames}")
+
+    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+
+    frames = []
+    for i in range(total_frames):
+        ret, frame = cap.read()
+        if ret:
+            frames.append(frame)
+        else:
+            # Handle the case where the frame could not be read
+            # print(f"Warning: Could not read frame at index {i}.")
+            pass
+
+    cap.release()
+
+    # Safely select frames based on frame_indices, avoiding IndexError
+    frames = [frames[i] for i in frame_indices if i < len(frames)]
+
+    # If there are not enough frames, duplicate the last frame until we reach the target
+    while len(frames) < num_frames:
+        frames.append(frames[-1])
+
+    # Use ThreadPoolExecutor to process and encode frames in parallel
+    with ThreadPoolExecutor() as executor:
+        encoded_frames = list(executor.map(encode_frame, frames))
+
+    # encoded_frames = list(map(encode_frame, frames))
+
+    # Concatenate all frames bytes
+    video_bytes = b"".join(encoded_frames)
+
+    # Encode the concatenated bytes to base64
+    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
+
+    return video_base64
+
+
 def _is_chinese_char(cp):
     """Checks whether CP is the codepoint of a CJK character."""
     # This defines a "chinese character" as anything in the CJK Unicode block:
@@ -191,3 +250,14 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
         raise RuntimeError()
 
     return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(
+            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
+        )
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} recive sigterm")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)
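A short usage sketch for the helpers added to sglang/utils.py; the video path and module name below are placeholders, and opencv-python-headless must be installed for the video helpers:

from sglang.utils import encode_video_base64, graceful_registry

# Sample 16 evenly spaced frames, PNG-encode them in a thread pool, and return
# a single "video:<base64>" payload for the video-capable (llava-video) path.
video_payload = encode_video_base64("example_video.mp4", num_frames=16)
print(video_payload[:32], "...")

# Register a SIGTERM handler that logs a graceful-shutdown message for this module.
graceful_registry("example_worker")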