sglang 0.2.8__py3-none-any.whl → 0.2.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +3 -5
- sglang/check_env.py +1 -0
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/model_executor/model_runner.py +6 -4
- sglang/srt/openai_api/adapter.py +7 -6
- sglang/srt/server.py +5 -13
- sglang/srt/server_args.py +11 -0
- sglang/srt/utils.py +20 -0
- sglang/test/run_eval.py +104 -0
- sglang/test/simple_eval_common.py +467 -0
- sglang/test/simple_eval_humaneval.py +139 -0
- sglang/test/simple_eval_mmlu.py +120 -0
- sglang/test/test_programs.py +4 -4
- sglang/test/test_utils.py +32 -0
- sglang/version.py +1 -1
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/RECORD +21 -19
- sglang/test/test_conversation.py +0 -46
- sglang/test/test_openai_protocol.py +0 -51
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -21,7 +21,7 @@ import sys
 import time
 import traceback
 import warnings
-from argparse import ArgumentParser
+from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535):


 if __name__ == "__main__":
-    parser =
-        description="Benchmark the online serving throughput."
-    )
+    parser = ArgumentParser(description="Benchmark the online serving throughput.")
     parser.add_argument(
         "--backend",
         type=str,
-        required=True,
         choices=list(ASYNC_REQUEST_FUNCS.keys()),
+        default="sglang",
         help="Must specify a backend, depending on the LLM Inference Engine.",
     )
     parser.add_argument(
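Not part of the diff: with required=True dropped and default="sglang" added above, the benchmark no longer needs an explicit --backend flag when pointed at an sglang server. A minimal invocation, assuming a server is already running on the script's default host and port:

python3 -m sglang.bench_serving

Other backends can still be selected explicitly (e.g. --backend vllm), as long as the name is one of the ASYNC_REQUEST_FUNCS choices referenced in the hunk.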
sglang/srt/layers/logits_processor.py
CHANGED
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
         all_logits = all_logits[:, : self.config.vocab_size].float()

         all_logprobs = all_logits
-        del all_logits
+        del all_logits, hidden_states
         all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

         # Get the logprob of top-k tokens
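The extra del matters because all_logprobs is only an alias of all_logits; dropping the extra names lets the large logits and hidden-state buffers be released earlier (assuming nothing else holds a reference) before the in-place log_softmax. A minimal sketch of the same pattern, illustrative only and not sglang code, with made-up shapes:

import torch

logits = torch.randn(8, 32000)   # stand-in for the vocab-sized logits tensor
logprobs = logits                # alias, no copy is made
del logits                       # drop the extra reference before the in-place update
logprobs[:] = torch.nn.functional.log_softmax(logprobs, dim=-1)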
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -79,6 +79,7 @@ class TokenizerManager:
         self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

         self.model_path = server_args.model_path
+        self.served_model_name = server_args.served_model_name
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -312,10 +312,12 @@ class ModelRunner:
             self.cuda_graph_runner.capture(batch_size_list)
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed: {e}
-
-
-
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable torch compile by not using --enable-torch-compile\n"
+                "2. disable cuda graph by --disable-cuda-graph\n"
+                "3. set --mem-fraction-static to a smaller value\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )

     @torch.inference_mode()
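The reworded error spells out the usual remedies, which map onto server launch flags. An illustrative command, not part of the diff, with a placeholder model path:

python3 -m sglang.launch_server --model-path <model> --disable-cuda-graph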
sglang/srt/openai_api/adapter.py
CHANGED
@@ -594,7 +594,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

 def v1_chat_generate_request(all_requests, tokenizer_manager):

-
+    input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
@@ -608,8 +608,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             if chat_template_name is None:
-
-                    request.messages, tokenize=
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=True, add_generation_prompt=True
                 )
                 stop = request.stop
                 image_data = None
@@ -623,12 +623,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
                         stop.append(request.stop)
                     else:
                         stop.extend(request.stop)
+                prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
             prompt = request.messages
             stop = request.stop
             image_data = None
-
+        input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
@@ -645,13 +646,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         )
         image_data_list.append(image_data)
     if len(all_requests) == 1:
-
+        input_ids = input_ids[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
-
+        input_ids=input_ids,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
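The net effect of these hunks is that chat requests are now tokenized inside the adapter: apply_chat_template is called with tokenize=True (or the templated prompt is passed through tokenizer.encode), and the resulting ids are forwarded to GenerateReqInput as input_ids rather than raw text. A rough sketch of the tokenizer call being relied on, using the Hugging Face API directly; illustrative only, and the model name is just a placeholder:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]

# tokenize=True applies the chat template and returns token ids rather than a
# string, which is what the adapter now collects into input_ids per request.
prompt_ids = tok.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)
print(prompt_ids[:8])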
sglang/srt/server.py
CHANGED
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    kill_child_process,
     maybe_set_triton_cache_manager,
     set_ulimit,
 )
@@ -189,10 +190,10 @@ async def retrieve_file_content(file_id: str):
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
-
+    served_model_names = [tokenizer_manager.served_model_name]
     model_cards = []
-    for
-        model_cards.append(ModelCard(id=
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
     return ModelList(data=model_cards)


@@ -467,16 +468,7 @@ class Runtime:

     def shutdown(self):
         if self.pid is not None:
-            try:
-                parent = psutil.Process(self.pid)
-            except psutil.NoSuchProcess:
-                return
-            children = parent.children(recursive=True)
-            for child in children:
-                child.kill()
-            psutil.wait_procs(children, timeout=5)
-            parent.kill()
-            parent.wait(timeout=5)
+            kill_child_process(self.pid)
             self.pid = None

     def cache_prefix(self, prefix: str):
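With served_model_name threaded through TokenizerManager and used here, /v1/models now reports whatever name was configured instead of always echoing the model path. A quick check against a running server; illustrative only, and it assumes the server is listening on localhost:30000, sglang's usual default port:

import requests

resp = requests.get("http://127.0.0.1:30000/v1/models")
# ModelList/ModelCard follow the OpenAI schema: a "data" list of cards with "id" fields.
print([card["id"] for card in resp.json()["data"]])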
sglang/srt/server_args.py
CHANGED
@@ -32,6 +32,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
+    served_model_name: Optional[str] = None
     chat_template: Optional[str] = None

     # Port
@@ -90,6 +91,10 @@ class ServerArgs:
     def __post_init__(self):
         if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
+
+        if self.served_model_name is None:
+            self.served_model_name = self.model_path
+
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -202,6 +207,12 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--served-model-name",
+            type=str,
+            default=ServerArgs.served_model_name,
+            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
+        )
         parser.add_argument(
             "--chat-template",
             type=str,
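Taken together with the __post_init__ change, --served-model-name falls back to model_path when unset, and as far as these hunks show it only changes the name reported through the OpenAI-compatible API. A small sketch of the fallback behavior; illustrative only, constructing ServerArgs directly with a placeholder model path:

from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
print(args.served_model_name)   # falls back to the model path

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",
    served_model_name="my-llama",
)
print(args.served_model_name)   # "my-llama" is what /v1/models will report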
sglang/srt/utils.py
CHANGED
@@ -366,6 +366,26 @@ def kill_parent_process():
     os.kill(parent_process.pid, 9)


+def kill_child_process(pid, including_parent=True):
+    try:
+        parent = psutil.Process(pid)
+    except psutil.NoSuchProcess:
+        return
+
+    children = parent.children(recursive=True)
+    for child in children:
+        try:
+            child.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+    if including_parent:
+        try:
+            parent.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
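kill_child_process generalizes the teardown logic that Runtime.shutdown previously inlined: it walks the psutil process tree, kills the children, and optionally the parent, swallowing NoSuchProcess races along the way. A small usage sketch, not from the package, that requires psutil just like the helper itself:

import multiprocessing
import time

from sglang.srt.utils import kill_child_process


def worker():
    time.sleep(60)


if __name__ == "__main__":
    proc = multiprocessing.Process(target=worker)
    proc.start()
    # Tear down the worker (and anything it spawned) by pid.
    kill_child_process(proc.pid, including_parent=True)
    proc.join()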
sglang/test/run_eval.py
ADDED
@@ -0,0 +1,104 @@
+"""
+Usage:
+python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
+"""
+
+import argparse
+import json
+import os
+import time
+
+from sglang.test.simple_eval_common import (
+    ChatCompletionSampler,
+    download_dataset,
+    make_report,
+    set_ulimit,
+)
+
+
+def run_eval(args):
+    if "OPENAI_API_KEY" not in os.environ:
+        os.environ["OPENAI_API_KEY"] = "EMPTY"
+
+    base_url = (
+        f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+    )
+
+    if args.eval_name == "mmlu":
+        from sglang.test.simple_eval_mmlu import MMLUEval
+
+        dataset_path = "mmlu.csv"
+
+        if not os.path.exists(dataset_path):
+            download_dataset(
+                dataset_path,
+                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+            )
+        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+    elif args.eval_name == "humaneval":
+        from sglang.test.simple_eval_humaneval import HumanEval
+
+        eval_obj = HumanEval(args.num_examples, args.num_threads)
+    else:
+        raise ValueError(f"Invalid eval name: {args.eval_name}")
+
+    sampler = ChatCompletionSampler(
+        model=args.model,
+        max_tokens=2048,
+        base_url=base_url,
+    )
+
+    # Run eval
+    tic = time.time()
+    result = eval_obj(sampler)
+    latency = time.time() - tic
+
+    # Dump reports
+    metrics = result.metrics | {"score": result.score}
+    file_stem = f"{args.eval_name}_{sampler.model.replace('/', '_')}"
+    report_filename = f"/tmp/{file_stem}.html"
+    print(f"Writing report to {report_filename}")
+    with open(report_filename, "w") as fh:
+        fh.write(make_report(result))
+    metrics = result.metrics | {"score": result.score}
+    print(metrics)
+    result_filename = f"/tmp/{file_stem}.json"
+    with open(result_filename, "w") as f:
+        f.write(json.dumps(metrics, indent=2))
+    print(f"Writing results to {result_filename}")
+
+    # Print results
+    print(f"Total latency: {latency:.3f} s")
+    print(f"Score: {metrics['score']:.3f}")
+
+    return metrics
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+    )
+    parser.add_argument("--eval-name", type=str, default="mmlu")
+    parser.add_argument("--num-examples", type=int)
+    parser.add_argument("--num-threads", type=int, default=64)
+    set_ulimit()
+    args = parser.parse_args()
+
+    run_eval(args)
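Besides the CLI shown in the module docstring, run_eval can be driven programmatically by handing it an argparse.Namespace with the same fields. A sketch, illustrative only: it assumes a server is already up on localhost:30000 and leaves model unset so that, per the --model help text, the sampler falls back to /v1/models:

import argparse

from sglang.test.run_eval import run_eval

args = argparse.Namespace(
    base_url=None,
    host="127.0.0.1",
    port=30000,
    model=None,          # resolved via /v1/models per the --model help text
    eval_name="mmlu",
    num_examples=10,
    num_threads=8,
)
metrics = run_eval(args)
print(metrics["score"])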