sglang 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff covers publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two versions.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +145 -36
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -5
- sglang/srt/managers/schedule_batch.py +95 -324
- sglang/srt/managers/tokenizer_manager.py +6 -3
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +9 -14
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +6 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +34 -12
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/server.py +24 -6
- sglang/srt/server_args.py +4 -0
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/METADATA +34 -24
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/RECORD +52 -50
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -22,6 +22,11 @@ from sglang.api import (
     user_end,
     video,
 )
+from sglang.lang.choices import (
+    greedy_token_selection,
+    token_length_normalized,
+    unconditional_likelihood_normalized,
+)
 
 # SGLang DSL APIs
 __all__ = [
@@ -45,6 +50,9 @@ __all__ = [
     "user_begin",
     "user_end",
     "video",
+    "greedy_token_selection",
+    "token_length_normalized",
+    "unconditional_likelihood_normalized",
 ]
 
 # Global Configurations
sglang/api.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
 
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
+from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
 from sglang.lang.ir import (
     SglExpr,
     SglExprList,
@@ -73,12 +74,18 @@ def gen(
     return_text_in_logprobs: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
+    choices_method: Optional[ChoicesSamplingMethod] = None,
     regex: Optional[str] = None,
 ):
     """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
 
     if choices:
-        return SglSelect(name, choices, 0.0 if temperature is None else temperature)
+        return SglSelect(
+            name,
+            choices,
+            0.0 if temperature is None else temperature,
+            token_length_normalized if choices_method is None else choices_method,
+        )
 
     # check regex is valid
     if regex is not None:
@@ -186,9 +193,10 @@ def select(
     name: Optional[str] = None,
     choices: Optional[List[str]] = None,
     temperature: float = 0.0,
+    choices_method: ChoicesSamplingMethod = token_length_normalized,
 ):
     assert choices is not None
-    return SglSelect(name, choices, temperature)
+    return SglSelect(name, choices, temperature, choices_method)
 
 
 def _role_common(name: str, expr: Optional[SglExpr] = None):
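Taken together, the `__init__.py` and `api.py` changes expose the new choices-sampling methods at the top level and let `gen(..., choices=...)` and `select(...)` accept a `choices_method`. A minimal usage sketch against these new signatures (the endpoint URL and prompt are placeholders, not part of the diff):

import sglang as sgl
from sglang import greedy_token_selection

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def pick_tool(s, question):
    s += "Question: " + question + "\nAnswer: "
    # choices_method defaults to token_length_normalized when omitted
    s += sgl.select(
        "tool",
        choices=["calculator", "web search"],
        choices_method=greedy_token_selection,
    )

state = pick_tool.run(question="What is 2 + 2?")
print(state["tool"])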
sglang/bench_latency.py
CHANGED
@@ -1,13 +1,21 @@
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
 
-# Usage (latency test)
+# Usage (latency test)
+## with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+## sweep through multiple data points and store (append) the results in a jsonl file:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
+## do some changes, and store the results under a different run_name:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
+## plot the results in series of lines:
+python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
+
 
 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
-
+## Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
         [-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
         [ -9.1875, -10.2500,   2.7109,  ...,  -4.3359,  -4.0664,  -4.1328]],
@@ -28,19 +36,23 @@ I'm going to the park
 
 import argparse
 import dataclasses
+import itertools
 import logging
 import multiprocessing
+import os
+import sqlite3
 import time
 from typing import Tuple
 
-import jsonlines
 import numpy as np
+import pandas as pd
 import torch
 import torch.distributed as dist
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -49,26 +61,42 @@ from sglang.srt.utils import suppress_other_loggers
 
 @dataclasses.dataclass
 class BenchArgs:
+    run_name: str = "before"
     batch_size: Tuple[int] = (1,)
-    input_len: int = 1024
-    output_len: int = 4
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (4,)
     result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    # Plotting args
+    graph_sql: str = (
+        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
+    )
+    graph_filename: str = "out.png"
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
         parser.add_argument(
             "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
         )
-        parser.add_argument(
-
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
        )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        # graphing
+        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
+        parser.add_argument(
+            "--graph-filename", type=str, default=BenchArgs.graph_filename
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -161,7 +189,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
 
 
 def extend(reqs, model_runner):
-    batch =
+    batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -222,15 +250,21 @@ def correctness_test(
 
 @torch.inference_mode()
 def latency_test_run_once(
-    model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
 ):
+    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+    if batch_size > max_batch_size:
+        rank_print(
+            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+        )
+        return
 
     # Clear the pools.
     model_runner.req_to_token_pool.clear()
     model_runner.token_to_kv_pool.clear()
 
     measurement_results = {
-        "run_name":
+        "run_name": run_name,
         "batch_size": batch_size,
         "input_len": input_len,
         "output_len": output_len,
@@ -291,49 +325,119 @@ def latency_test(
 
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    rank_print(
-        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
-    )
 
-    #
-    bench_args.batch_size = bench_args.batch_size[0]
-
-    # Prepare inputs
+    # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
-        bench_args.batch_size, bench_args.input_len
+        bench_args.batch_size[0], bench_args.input_len[0]
     )
 
     # Warm up
     latency_test_run_once(
-
+        bench_args.run_name,
+        model_runner,
+        rank_print,
+        reqs,
+        bench_args.batch_size[0],
+        bench_args.input_len[0],
+        4,  # shorter decoding to speed up the warmup
     )
 
-    # Run
+    # Run the sweep
     result_list = []
-
-
-
-
-
-        bench_args.
-        bench_args.input_len,
-        bench_args.output_len,
+    for bs, il, ol in itertools.product(
+        bench_args.batch_size, bench_args.input_len, bench_args.output_len
+    ):
+        req = prepare_synthetic_inputs_for_latency_test(bs, il)
+        ret = latency_test_run_once(
+            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
     )
-
+        if ret is not None:
+            result_list.append(ret)
+
+    # Write results in jsonlines format on rank 0.
+    if tp_rank == 0 and bench_args.result_filename:
+        import jsonlines
 
-    # Write results in jsonlines format.
-    if bench_args.result_filename:
         with jsonlines.open(bench_args.result_filename, "a") as f:
             f.write_all(result_list)
 
 
+def plot_latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    assert tp_rank == 0
+
+    # read the jsonl file and put in sqlite
+    df = pd.read_json(bench_args.result_filename, lines=True)
+    conn = sqlite3.connect(":memory:")
+    cur = conn.cursor()
+
+    # get the columns and their types
+    column_names = list(df.iloc[0].keys())
+    type_dict = {
+        str: "TEXT",
+        np.int64: "INTEGER",
+        np.float64: "FLOAT",
+    }
+    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
+
+    # create the table
+    cur.execute(
+        f"""
+        CREATE TABLE IF NOT EXISTS results (
+            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
+        )
+        """
+    )
+    conn.commit()
+
+    # write the results to DB
+    df.to_sql("results", conn, if_exists="replace", index=False)
+    conn.commit()
+
+    # read it back using sql
+    df = pd.read_sql_query(bench_args.graph_sql, conn)
+    conn.close()
+
+    # plot it and save to a file
+    import matplotlib.pyplot as plt
+
+    assert (
+        len(df.columns) == 3
+    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
+    for label in df[df.columns[0]].unique():
+        q = f"{df.columns[0]}=='{label}'"
+        series = df.query(q)
+        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
+    plt.xlabel(df.columns[1])
+    plt.ylabel(df.columns[2])
+    plt.legend()
+    plt.savefig(bench_args.graph_filename, dpi=300)
+
+    # if in kitty, just dump it to the terminal
+    if os.environ["TERM"] == "xterm-kitty":
+        os.system(
+            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
+        )
+
+
 def main(server_args, bench_args):
-    print(bench_args)
 
-    if
-
+    if server_args.model_path:
+        if bench_args.correctness_test:
+            work_func = correctness_test
+        else:
+            work_func = latency_test
+    elif os.path.isfile(bench_args.result_filename):
+        assert bench_args.graph_filename, "please provide a filename for the graph"
+        work_func = plot_latency_test
     else:
-
+        raise ValueError(
+            "Provide --model-path for running the tests or "
+            "provide --result-filename for plotting the results"
+        )
 
     if server_args.tp_size == 1:
         work_func(server_args, bench_args, 0)
@@ -361,6 +465,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code need to be updated"
+    parser._actions[1].required = False
     args = parser.parse_args()
 
     server_args = ServerArgs.from_cli_args(args)
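Each sweep point appends one JSON record per line to the result file, keyed by at least run_name, batch_size, input_len, and output_len (prefill_throughput is the column the default --graph-sql queries). A hedged sketch of inspecting such a file outside the built-in plotting path, assuming out.jsonl was produced by the commands in the docstring above:

import pandas as pd

# Load the jsonl results the same way plot_latency_test does.
df = pd.read_json("out.jsonl", lines=True)

# Compare the "before" and "after" runs side by side for each configuration.
pivot = df.pivot_table(
    index=["batch_size", "input_len", "output_len"],
    columns="run_name",
    values="prefill_throughput",
)
print(pivot)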
sglang/check_env.py
CHANGED
@@ -14,6 +14,7 @@ PACKAGE_LIST = [
     "sglang",
     "flashinfer",
     "triton",
+    "transformers",
     "requests",
     "tqdm",
     "numpy",
@@ -73,10 +74,26 @@ def _get_gpu_info():
     Get information about available GPUs.
     """
     devices = defaultdict(list)
+    capabilities = defaultdict(list)
     for k in range(torch.cuda.device_count()):
         devices[torch.cuda.get_device_name(k)].append(str(k))
+        capability = torch.cuda.get_device_capability(k)
+        capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
 
-
+    gpu_info = {}
+    for name, device_ids in devices.items():
+        gpu_info[f"GPU {','.join(device_ids)}"] = name
+
+    if len(capabilities) == 1:
+        # All GPUs have the same compute capability
+        cap, gpu_ids = list(capabilities.items())[0]
+        gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+    else:
+        # GPUs have different compute capabilities
+        for cap, gpu_ids in capabilities.items():
+            gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+
+    return gpu_info
 
 
 def _get_cuda_version_info():
@@ -118,6 +135,7 @@ def _get_cuda_driver_version():
     """
     Get CUDA driver version.
     """
+    versions = set()
     try:
         output = subprocess.check_output(
             [
@@ -126,7 +144,11 @@
                 "--format=csv,noheader,nounits",
             ]
         )
-
+        versions = set(output.decode().strip().split("\n"))
+        if len(versions) == 1:
+            return {"CUDA Driver Version": versions.pop()}
+        else:
+            return {"CUDA Driver Versions": ", ".join(sorted(versions))}
     except subprocess.SubprocessError:
         return {"CUDA Driver Version": "Not Available"}
 
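With this change, _get_gpu_info() reports compute capability alongside device names, grouping GPUs that share a value. Purely illustrative shape of the returned mapping on a machine with two identical GPUs (device name and capability are made up, not captured output):

expected_shape = {
    "GPU 0,1": "NVIDIA A100-SXM4-80GB",
    "GPU 0,1 Compute Capability": "8.0",
}
print(expected_shape)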
sglang/global_config.py
CHANGED
@@ -19,7 +19,6 @@ class GlobalConfig:
|
|
19
19
|
self.init_new_token_ratio = 0.7
|
20
20
|
self.base_min_new_token_ratio = 0.1
|
21
21
|
self.new_token_ratio_decay = 0.001
|
22
|
-
self.new_token_ratio_recovery = 0.05
|
23
22
|
|
24
23
|
# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
|
25
24
|
# This can improve the speed for large batch sizes during prefill.
|
sglang/lang/backend/base_backend.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import Callable, List, Optional, Union
 
 from sglang.lang.chat_template import get_chat_template
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -64,7 +65,8 @@ class BaseBackend:
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: Optional[ChoicesSamplingMethod] = None,
+    ) -> ChoicesDecision:
         raise NotImplementedError()
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
sglang/lang/backend/openai.py
CHANGED
@@ -8,6 +8,7 @@ import numpy as np
 
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -296,7 +297,9 @@ class OpenAI(BaseBackend):
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
+        """Note: `choices_method` is not used by the OpenAI backend."""
         if self.is_chat_model:
             raise NotImplementedError(
                 "select/choices is not supported for chat models. "
@@ -354,8 +357,10 @@
 
             prompt_tokens.append(ret_token)
 
-
-
+        return ChoicesDecision(
+            decision=choices[np.argmax(scores)],
+            meta_info={"scores": scores},
+        )
 
 
 def openai_completion(
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -1,17 +1,21 @@
 import json
 from typing import List, Optional
 
-import numpy as np
-
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.choices import (
+    ChoicesDecision,
+    ChoicesSamplingMethod,
+    token_length_normalized,
+)
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
+
     def __init__(
         self,
         base_url: str,
@@ -43,7 +47,7 @@ class RuntimeEndpoint(BaseBackend):
     def flush_cache(self):
         res = http_request(
             self.base_url + "/flush_cache",
-
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -51,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
     def get_server_args(self):
         res = http_request(
             self.base_url + "/get_server_args",
-
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -208,20 +212,14 @@ class RuntimeEndpoint(BaseBackend):
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
         assert temperature <= 1e-5
 
         # Cache common prefix
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        prompt_len = res.json()["meta_info"]["prompt_tokens"]
+        obj = self._generate_http_request(s, data)
+        prompt_len = obj["meta_info"]["prompt_tokens"]
 
         # Compute logprob
         data = {
@@ -230,27 +228,35 @@ class RuntimeEndpoint(BaseBackend):
             "return_logprob": True,
             "logprob_start_len": max(prompt_len - 2, 0),
         }
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        obj = res.json()
+        obj = self._generate_http_request(s, data)
+
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-        decision = choices[np.argmax(normalized_prompt_logprobs)]
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
 
-
-
-
-
-
+        # Compute unconditional logprobs if required
+        if choices_method.requires_unconditional_logprobs:
+            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
+            data = {
+                "input_ids": input_ids,
+                "sampling_params": {"max_new_tokens": 0},
+                "return_logprob": True,
+            }
+            obj = self._generate_http_request(s, data)
+            unconditional_token_logprobs = [
+                r["meta_info"]["input_token_logprobs"] for r in obj
+            ]
+        else:
+            unconditional_token_logprobs = None
+
+        return choices_method(
+            choices=choices,
+            normalized_prompt_logprobs=normalized_prompt_logprobs,
+            input_token_logprobs=input_token_logprobs,
+            output_token_logprobs=output_token_logprobs,
+            unconditional_token_logprobs=unconditional_token_logprobs,
         )
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
@@ -262,6 +268,17 @@ class RuntimeEndpoint(BaseBackend):
         )
         self._assert_success(res)
 
+    def _generate_http_request(self, s: StreamExecutor, data):
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        return res.json()
+
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."