sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +151 -40
- sglang/bench_serving.py +46 -22
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +14 -5
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +6 -1
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +4 -7
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +174 -380
- sglang/srt/managers/tokenizer_manager.py +197 -112
- sglang/srt/managers/tp_worker.py +299 -364
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +10 -15
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +27 -12
- sglang/srt/model_executor/forward_batch_info.py +319 -0
- sglang/srt/model_executor/model_runner.py +30 -47
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -2
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +3 -8
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -12
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +189 -39
- sglang/srt/openai_api/protocol.py +43 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +93 -21
- sglang/srt/server_args.py +30 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +21 -3
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.10.dist-info/RECORD +0 -100
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -22,6 +22,11 @@ from sglang.api import (
     user_end,
     video,
 )
+from sglang.lang.choices import (
+    greedy_token_selection,
+    token_length_normalized,
+    unconditional_likelihood_normalized,
+)
 
 # SGLang DSL APIs
 __all__ = [
@@ -45,6 +50,9 @@ __all__ = [
     "user_begin",
     "user_end",
     "video",
+    "greedy_token_selection",
+    "token_length_normalized",
+    "unconditional_likelihood_normalized",
 ]
 
 # Global Configurations
sglang/api.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
 
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
+from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
 from sglang.lang.ir import (
     SglExpr,
     SglExprList,
@@ -73,12 +74,18 @@ def gen(
     return_text_in_logprobs: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
+    choices_method: Optional[ChoicesSamplingMethod] = None,
     regex: Optional[str] = None,
 ):
     """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
 
     if choices:
-        return SglSelect(
+        return SglSelect(
+            name,
+            choices,
+            0.0 if temperature is None else temperature,
+            token_length_normalized if choices_method is None else choices_method,
+        )
 
     # check regex is valid
     if regex is not None:
@@ -186,9 +193,10 @@ def select(
     name: Optional[str] = None,
     choices: Optional[List[str]] = None,
     temperature: float = 0.0,
+    choices_method: ChoicesSamplingMethod = token_length_normalized,
 ):
     assert choices is not None
-    return SglSelect(name, choices, temperature)
+    return SglSelect(name, choices, temperature, choices_method)
 
 
 def _role_common(name: str, expr: Optional[SglExpr] = None):
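
The `choices_method` argument added above lets callers pick one of the three strategies now exported from `sglang.lang.choices`. A minimal usage sketch (the function body and prompt are illustrative, and a default backend such as a `RuntimeEndpoint` is assumed to be set):

import sglang as sgl
from sglang import greedy_token_selection

@sgl.function
def route(s, question):
    s += "Question: " + question + "\n"
    s += "Route: " + sgl.gen(
        "route", choices=["calculator", "web_search"], choices_method=greedy_token_selection
    )

When `choices_method` is omitted, `gen` falls back to `token_length_normalized`, the default shown in the diff.
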
sglang/bench_latency.py
CHANGED
@@ -1,13 +1,21 @@
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
 
-# Usage (latency test)
+# Usage (latency test)
+## with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+## sweep through multiple data points and store (append) the results in a jsonl file:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
+## do some changes, and store the results under a different run_name:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
+## plot the results in series of lines:
+python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
+
 
 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
-
+## Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
 [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
 [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -28,19 +36,23 @@ I'm going to the park
 
 import argparse
 import dataclasses
+import itertools
 import logging
 import multiprocessing
+import os
+import sqlite3
 import time
 from typing import Tuple
 
-import jsonlines
 import numpy as np
+import pandas as pd
 import torch
 import torch.distributed as dist
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -49,26 +61,42 @@ from sglang.srt.utils import suppress_other_loggers
 
 @dataclasses.dataclass
 class BenchArgs:
+    run_name: str = "before"
     batch_size: Tuple[int] = (1,)
-    input_len: int = 1024
-    output_len: int = 4
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (4,)
     result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    # Plotting args
+    graph_sql: str = (
+        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
+    )
+    graph_filename: str = "out.png"
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
         parser.add_argument(
             "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
         )
-        parser.add_argument(
-
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        # graphing
+        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
+        parser.add_argument(
+            "--graph-filename", type=str, default=BenchArgs.graph_filename
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -124,7 +152,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
         req.prefix_indices = []
         req.sampling_params = sampling_params
-        req.
+        req.fill_ids = req.origin_input_ids
         reqs.append(req)
 
     return input_ids, reqs
@@ -135,7 +163,7 @@ def prepare_extend_inputs_for_correctness_test(
 ):
     for i in range(len(reqs)):
         req = reqs[i]
-        req.
+        req.fill_ids += input_ids[i][bench_args.cut_len :]
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
             i, : bench_args.cut_len
         ]
@@ -154,14 +182,14 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
         req.prefix_indices = []
         req.sampling_params = sampling_params
-        req.
+        req.fill_ids = req.origin_input_ids
         reqs.append(req)
 
     return reqs
 
 
 def extend(reqs, model_runner):
-    batch =
+    batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -210,7 +238,7 @@ def correctness_test(
 
     # Decode
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
-    for _ in range(bench_args.output_len):
+    for _ in range(bench_args.output_len[0]):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         for i in range(len(reqs)):
             output_ids[i].append(next_token_ids[i])
@@ -222,15 +250,21 @@
 
 @torch.inference_mode()
 def latency_test_run_once(
-    model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
 ):
+    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+    if batch_size > max_batch_size:
+        rank_print(
+            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+        )
+        return
 
     # Clear the pools.
     model_runner.req_to_token_pool.clear()
     model_runner.token_to_kv_pool.clear()
 
     measurement_results = {
-        "run_name":
+        "run_name": run_name,
         "batch_size": batch_size,
         "input_len": input_len,
         "output_len": output_len,
@@ -291,49 +325,121 @@ def latency_test(
 
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    rank_print(
-        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
-    )
 
-    #
-    bench_args.batch_size = bench_args.batch_size[0]
-
-    # Prepare inputs
+    # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
-        bench_args.batch_size, bench_args.input_len
+        bench_args.batch_size[0], bench_args.input_len[0]
    )
 
     # Warm up
+    rank_print("Warmup ...")
     latency_test_run_once(
-
+        bench_args.run_name,
+        model_runner,
+        rank_print,
+        reqs,
+        bench_args.batch_size[0],
+        bench_args.input_len[0],
+        4,  # shorter decoding to speed up the warmup
     )
+    rank_print("Benchmark ...")
 
-    # Run
+    # Run the sweep
     result_list = []
-
-
-
-
-
-        bench_args.
-        bench_args.input_len,
-        bench_args.output_len,
+    for bs, il, ol in itertools.product(
+        bench_args.batch_size, bench_args.input_len, bench_args.output_len
+    ):
+        req = prepare_synthetic_inputs_for_latency_test(bs, il)
+        ret = latency_test_run_once(
+            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
        )
-
+        if ret is not None:
+            result_list.append(ret)
+
+    # Write results in jsonlines format on rank 0.
+    if tp_rank == 0 and bench_args.result_filename:
+        import jsonlines
 
-    # Write results in jsonlines format.
-    if bench_args.result_filename:
         with jsonlines.open(bench_args.result_filename, "a") as f:
             f.write_all(result_list)
 
 
+def plot_latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    assert tp_rank == 0
+
+    # read the jsonl file and put in sqlite
+    df = pd.read_json(bench_args.result_filename, lines=True)
+    conn = sqlite3.connect(":memory:")
+    cur = conn.cursor()
+
+    # get the columns and their types
+    column_names = list(df.iloc[0].keys())
+    type_dict = {
+        str: "TEXT",
+        np.int64: "INTEGER",
+        np.float64: "FLOAT",
+    }
+    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
+
+    # create the table
+    cur.execute(
+        f"""
+        CREATE TABLE IF NOT EXISTS results (
+            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
+        )
+        """
+    )
+    conn.commit()
+
+    # write the results to DB
+    df.to_sql("results", conn, if_exists="replace", index=False)
+    conn.commit()
+
+    # read it back using sql
+    df = pd.read_sql_query(bench_args.graph_sql, conn)
+    conn.close()
+
+    # plot it and save to a file
+    import matplotlib.pyplot as plt
+
+    assert (
+        len(df.columns) == 3
+    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
+    for label in df[df.columns[0]].unique():
+        q = f"{df.columns[0]}=='{label}'"
+        series = df.query(q)
+        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
+    plt.xlabel(df.columns[1])
+    plt.ylabel(df.columns[2])
+    plt.legend()
+    plt.savefig(bench_args.graph_filename, dpi=300)
+
+    # if in kitty, just dump it to the terminal
+    if os.environ["TERM"] == "xterm-kitty":
+        os.system(
+            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
+        )
+
+
 def main(server_args, bench_args):
-    print(bench_args)
 
-    if
-
+    if server_args.model_path:
+        if bench_args.correctness_test:
+            work_func = correctness_test
+        else:
+            work_func = latency_test
+    elif os.path.isfile(bench_args.result_filename):
+        assert bench_args.graph_filename, "please provide a filename for the graph"
+        work_func = plot_latency_test
     else:
-
+        raise ValueError(
+            "Provide --model-path for running the tests or "
+            "provide --result-filename for plotting the results"
+        )
 
     if server_args.tp_size == 1:
         work_func(server_args, bench_args, 0)
@@ -361,6 +467,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code need to be updated"
+    parser._actions[1].required = False
     args = parser.parse_args()
 
     server_args = ServerArgs.from_cli_args(args)
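
The new plotting path appends one JSON record per run to the jsonl file, loads those records into an in-memory SQLite table named `results`, and lets `--graph-sql` pick the <series, x, y> columns to draw. A condensed sketch of that round trip (column names assume the default schema written by `latency_test_run_once`):

import sqlite3

import pandas as pd

df = pd.read_json("out.jsonl", lines=True)  # one record per (run_name, batch_size, input_len, output_len)
conn = sqlite3.connect(":memory:")
df.to_sql("results", conn, if_exists="replace", index=False)
plot_df = pd.read_sql_query(
    "select run_name, batch_size, prefill_throughput from results", conn
)
conn.close()
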
sglang/bench_serving.py
CHANGED
@@ -24,7 +24,7 @@ import warnings
 from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import AsyncGenerator, List, Optional, Tuple, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
 
 import aiohttp
 import numpy as np
@@ -39,6 +39,8 @@ from transformers import (
 
 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
 
+global args
+
 
 @dataclass
 class RequestFuncInput:
@@ -47,6 +49,7 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    extra_request_body: Dict[str, Any]
 
 
 @dataclass
@@ -84,6 +87,7 @@ async def async_request_trt_llm(
             "stream": True,
             "min_length": request_func_input.output_len,
             "end_id": 1048576,
+            **request_func_input.extra_request_body,
         }
         if args.disable_ignore_eos:
             del payload["min_length"]
@@ -154,6 +158,7 @@ async def async_request_openai_completions(
             "max_tokens": request_func_input.output_len,
             "stream": not args.disable_stream,
             "ignore_eos": not args.disable_ignore_eos,
+            **request_func_input.extra_request_body,
         }
         headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
 
@@ -192,7 +197,8 @@ async def async_request_openai_completions(
                             output.ttft = ttft
 
                         # Decoding phase
-
+                        else:
+                            output.itl.append(timestamp - most_recent_timestamp)
 
                         most_recent_timestamp = timestamp
                         generated_text += data["choices"][0]["text"]
@@ -542,6 +548,7 @@ async def benchmark(
     request_rate: float,
     disable_tqdm: bool,
     enable_multi: bool,
+    extra_request_body: Dict[str, Any],
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -556,6 +563,7 @@ async def benchmark(
         api_url=api_url,
         prompt_len=test_prompt_len,
         output_len=test_output_len,
+        extra_request_body=extra_request_body,
     )
     test_output = await request_func(request_func_input=test_input)
     if not test_output.success:
@@ -578,6 +586,7 @@
             api_url=api_url,
             prompt_len=prompt_len,
            output_len=output_len,
+            extra_request_body=extra_request_body,
         )
         tasks.append(
             asyncio.create_task(
@@ -660,19 +669,20 @@ async def benchmark(
             "backend": args.backend,
             "dataset_name": args.dataset_name,
             "request_rate": request_rate,
-            "
-            "
-            "
-            "
-            "
-            "
-            "
-            "
+            "total_input_tokens": metrics.total_input,
+            "total_output_tokens": metrics.total_output,
+            "total_output_tokens_retokenized": metrics.total_output_retokenized,
+            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
+            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+            "median_ttft_ms": metrics.median_ttft_ms,
+            "median_itl_ms": metrics.median_itl_ms,
+            "output_throughput": metrics.output_throughput,
             "sharegpt_output_len": args.sharegpt_output_len,
             "random_input_len": args.random_input_len,
             "random_output_len": args.random_output_len,
             "random_range_ratio": args.random_range_ratio,
-            "
+            "duration": benchmark_duration,
+            "completed": metrics.completed,
         }
     else:
         print(f"Error running benchmark for request rate: {request_rate}")
@@ -742,10 +752,18 @@ def check_chat_template(model_path):
         return False
 
 
-def
+def run_benchmark(args_: argparse.Namespace):
+    global args
+    args = args_
+
+    set_ulimit()
     random.seed(args.seed)
     np.random.seed(args.seed)
 
+    extra_request_body = {}
+    if args.extra_request_body:
+        extra_request_body = json.loads(args.extra_request_body)
+
     if args.port is None:
         args.port = {
             "sglang": 30000,
@@ -838,10 +856,11 @@ def fire(args: argparse.Namespace):
                     request_rate=rate,
                     disable_tqdm=args.disable_tqdm,
                     enable_multi=args.multi,
+                    extra_request_body=extra_request_body,
                 )
             )
     else:
-        asyncio.run(
+        return asyncio.run(
            benchmark(
                backend=backend,
                api_url=api_url,
@@ -851,6 +870,7 @@ def fire(args: argparse.Namespace):
                request_rate=args.request_rate,
                disable_tqdm=args.disable_tqdm,
                enable_multi=args.multi,
+                extra_request_body=extra_request_body,
            )
        )
 
@@ -949,11 +969,6 @@ if __name__ == "__main__":
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
     )
     parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
-    parser.add_argument(
-        "--disable-tqdm",
-        action="store_true",
-        help="Specify to disable tqdm progress bar.",
-    )
     parser.add_argument(
         "--multi",
         action="store_true",
@@ -966,6 +981,11 @@
         help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
     )
     parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+    parser.add_argument(
+        "--disable-tqdm",
+        action="store_true",
+        help="Specify to disable tqdm progress bar.",
+    )
     parser.add_argument(
         "--disable-stream",
         action="store_true",
@@ -976,8 +996,12 @@
         action="store_true",
         help="Disable ignoring EOS.",
     )
-
-
-
+    parser.add_argument(
+        "--extra-request-body",
+        metavar='{"key1": "value1", "key2": "value2"}',
+        type=str,
+        help="Append given JSON object to the request payload. You can use this to specify"
+        "additional generate params like sampling params.",
+    )
     args = parser.parse_args()
-
+    run_benchmark(args)
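
The new `--extra-request-body` flag merges a JSON object into every request payload, so backend-specific sampling parameters can be benchmarked without editing the script; the parameter values below are only an example:

python -m sglang.bench_serving --backend sglang --extra-request-body '{"temperature": 0.0, "top_p": 1.0}'

Because the entry point is now the importable `run_benchmark(args)` (and, judging from the `return asyncio.run(...)` change, the non-multi path returns the benchmark result), the benchmark can also be driven programmatically from another script.
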
sglang/check_env.py
CHANGED
@@ -14,6 +14,7 @@ PACKAGE_LIST = [
     "sglang",
     "flashinfer",
     "triton",
+    "transformers",
     "requests",
     "tqdm",
     "numpy",
@@ -73,10 +74,26 @@ def _get_gpu_info():
     Get information about available GPUs.
     """
     devices = defaultdict(list)
+    capabilities = defaultdict(list)
     for k in range(torch.cuda.device_count()):
         devices[torch.cuda.get_device_name(k)].append(str(k))
+        capability = torch.cuda.get_device_capability(k)
+        capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
 
-
+    gpu_info = {}
+    for name, device_ids in devices.items():
+        gpu_info[f"GPU {','.join(device_ids)}"] = name
+
+    if len(capabilities) == 1:
+        # All GPUs have the same compute capability
+        cap, gpu_ids = list(capabilities.items())[0]
+        gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+    else:
+        # GPUs have different compute capabilities
+        for cap, gpu_ids in capabilities.items():
+            gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+
+    return gpu_info
 
 
 def _get_cuda_version_info():
@@ -118,6 +135,7 @@ def _get_cuda_driver_version():
     """
     Get CUDA driver version.
     """
+    versions = set()
     try:
         output = subprocess.check_output(
             [
@@ -126,7 +144,11 @@ def _get_cuda_driver_version():
                 "--format=csv,noheader,nounits",
             ]
         )
-
+        versions = set(output.decode().strip().split("\n"))
+        if len(versions) == 1:
+            return {"CUDA Driver Version": versions.pop()}
+        else:
+            return {"CUDA Driver Versions": ", ".join(sorted(versions))}
     except subprocess.SubprocessError:
         return {"CUDA Driver Version": "Not Available"}
 
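
For reference, the capability grouping above is built on `torch.cuda.get_device_capability`, which returns a `(major, minor)` tuple per device; a quick illustration (device index 0 is just an example):

import torch

major, minor = torch.cuda.get_device_capability(0)  # e.g. (8, 0) on an A100
print(f"{major}.{minor}")  # formatted the same way as in _get_gpu_info
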
sglang/global_config.py
CHANGED
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
-        self.new_token_ratio_recovery = 0.05
 
         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
         # This can improve the speed for large batch sizes during prefill.
sglang/lang/backend/base_backend.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import Callable, List, Optional, Union
 
 from sglang.lang.chat_template import get_chat_template
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -64,7 +65,8 @@ class BaseBackend:
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: Optional[ChoicesSamplingMethod] = None,
+    ) -> ChoicesDecision:
         raise NotImplementedError()
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
sglang/lang/backend/openai.py
CHANGED
@@ -8,6 +8,7 @@ import numpy as np
 
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -296,7 +297,9 @@ class OpenAI(BaseBackend):
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
+        """Note: `choices_method` is not used by the OpenAI backend."""
         if self.is_chat_model:
             raise NotImplementedError(
                 "select/choices is not supported for chat models. "
@@ -354,8 +357,10 @@ class OpenAI(BaseBackend):
 
             prompt_tokens.append(ret_token)
 
-
-
+        return ChoicesDecision(
+            decision=choices[np.argmax(scores)],
+            meta_info={"scores": scores},
+        )
 
 
 def openai_completion(
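
Taken together, the base-backend and OpenAI changes mean a backend's `select` now receives a `choices_method` and returns a `ChoicesDecision`. A hypothetical custom backend illustrating the new contract (the length-based scoring is a placeholder, not how any real backend ranks choices):

from typing import List, Optional

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod


class DummyBackend(BaseBackend):  # hypothetical example backend
    def select(
        self,
        s,
        choices: List[str],
        temperature: float,
        choices_method: Optional[ChoicesSamplingMethod] = None,
    ) -> ChoicesDecision:
        scores = [float(len(c)) for c in choices]  # placeholder scoring
        best = max(range(len(choices)), key=lambda i: scores[i])
        return ChoicesDecision(decision=choices[best], meta_info={"scores": scores})
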
|