sglang 0.2.8__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +3 -5
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/openai_api/adapter.py +7 -6
- sglang/srt/server.py +5 -13
- sglang/srt/server_args.py +11 -0
- sglang/srt/utils.py +20 -0
- sglang/test/run_eval.py +104 -0
- sglang/test/simple_eval_common.py +467 -0
- sglang/test/simple_eval_humaneval.py +139 -0
- sglang/test/simple_eval_mmlu.py +120 -0
- sglang/test/test_programs.py +4 -4
- sglang/test/test_utils.py +32 -0
- sglang/version.py +1 -1
- {sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/METADATA +3 -3
- {sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/RECORD +19 -17
- sglang/test/test_conversation.py +0 -46
- sglang/test/test_openai_protocol.py +0 -51
- {sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/LICENSE +0 -0
- {sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/WHEEL +0 -0
- {sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -21,7 +21,7 @@ import sys
 import time
 import traceback
 import warnings
-from argparse import ArgumentParser
+from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535):
 
 
 if __name__ == "__main__":
-    parser =
-        description="Benchmark the online serving throughput."
-    )
+    parser = ArgumentParser(description="Benchmark the online serving throughput.")
     parser.add_argument(
         "--backend",
         type=str,
-        required=True,
         choices=list(ASYNC_REQUEST_FUNCS.keys()),
+        default="sglang",
         help="Must specify a backend, depending on the LLM Inference Engine.",
     )
     parser.add_argument(
sglang/srt/layers/logits_processor.py
CHANGED
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
             all_logits = all_logits[:, : self.config.vocab_size].float()
 
             all_logprobs = all_logits
-            del all_logits
+            del all_logits, hidden_states
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
             # Get the logprob of top-k tokens
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -79,6 +79,7 @@ class TokenizerManager:
         self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
 
         self.model_path = server_args.model_path
+        self.served_model_name = server_args.served_model_name
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
sglang/srt/openai_api/adapter.py
CHANGED
@@ -594,7 +594,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
 
 def v1_chat_generate_request(all_requests, tokenizer_manager):
-
+    input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
@@ -608,8 +608,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             if chat_template_name is None:
-
-                    request.messages, tokenize=
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=True, add_generation_prompt=True
                 )
                 stop = request.stop
                 image_data = None
@@ -623,12 +623,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
                     stop.append(request.stop)
                 else:
                     stop.extend(request.stop)
+            prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
             prompt = request.messages
             stop = request.stop
             image_data = None
-
+        input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
@@ -645,13 +646,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         )
         image_data_list.append(image_data)
     if len(all_requests) == 1:
-
+        input_ids = input_ids[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
         adapted_request = GenerateReqInput(
-
+            input_ids=input_ids,
             image_data=image_data,
             sampling_params=sampling_params_list,
             return_logprob=return_logprobs,
sglang/srt/server.py
CHANGED
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    kill_child_process,
     maybe_set_triton_cache_manager,
     set_ulimit,
 )
@@ -189,10 +190,10 @@ async def retrieve_file_content(file_id: str):
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
-
+    served_model_names = [tokenizer_manager.served_model_name]
     model_cards = []
-    for
-        model_cards.append(ModelCard(id=
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
     return ModelList(data=model_cards)
 
 
@@ -467,16 +468,7 @@ class Runtime:
 
     def shutdown(self):
         if self.pid is not None:
-            try:
-                parent = psutil.Process(self.pid)
-            except psutil.NoSuchProcess:
-                return
-            children = parent.children(recursive=True)
-            for child in children:
-                child.kill()
-            psutil.wait_procs(children, timeout=5)
-            parent.kill()
-            parent.wait(timeout=5)
+            kill_child_process(self.pid)
             self.pid = None
 
     def cache_prefix(self, prefix: str):
sglang/srt/server_args.py
CHANGED
@@ -32,6 +32,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
+    served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
 
     # Port
@@ -90,6 +91,10 @@ class ServerArgs:
     def __post_init__(self):
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
+
+        if self.served_model_name is None:
+            self.served_model_name = self.model_path
+
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
@@ -202,6 +207,12 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--served-model-name",
+            type=str,
+            default=ServerArgs.served_model_name,
+            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
+        )
         parser.add_argument(
             "--chat-template",
             type=str,
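The new `--served-model-name` flag only changes the model id reported by the OpenAI-compatible `/v1/models` endpoint; generation still loads `--model-path`, and the name falls back to the model path when the flag is omitted. A minimal sketch of checking the override from a client, assuming a locally launched server on port 30000 with the alias "my-model" (both placeholders, not taken from this diff):

```python
import requests

# Assumed launch (placeholders): python3 -m sglang.launch_server \
#   --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
#   --served-model-name my-model --port 30000
resp = requests.get("http://localhost:30000/v1/models")
resp.raise_for_status()
# Expect the alias ("my-model") rather than the raw model path.
print([card["id"] for card in resp.json()["data"]])
```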
sglang/srt/utils.py
CHANGED
@@ -366,6 +366,26 @@ def kill_parent_process():
     os.kill(parent_process.pid, 9)
 
 
+def kill_child_process(pid, including_parent=True):
+    try:
+        parent = psutil.Process(pid)
+    except psutil.NoSuchProcess:
+        return
+
+    children = parent.children(recursive=True)
+    for child in children:
+        try:
+            child.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+    if including_parent:
+        try:
+            parent.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
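`kill_child_process` centralizes the psutil-based teardown that `Runtime.shutdown` previously inlined: it looks up the process, kills every descendant, optionally the process itself, and swallows `NoSuchProcess` races. A small usage sketch (the child command below is an arbitrary placeholder):

```python
import subprocess
import time

from sglang.srt.utils import kill_child_process

# Placeholder child process; any long-running command works for this sketch.
proc = subprocess.Popen(["python3", "-c", "import time; time.sleep(600)"])
time.sleep(1)

# Kills proc's descendants (none here) and, with including_parent=True, proc itself.
kill_child_process(proc.pid, including_parent=True)
```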
sglang/test/run_eval.py
ADDED
@@ -0,0 +1,104 @@
+"""
+Usage:
+python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
+"""
+
+import argparse
+import json
+import os
+import time
+
+from sglang.test.simple_eval_common import (
+    ChatCompletionSampler,
+    download_dataset,
+    make_report,
+    set_ulimit,
+)
+
+
+def run_eval(args):
+    if "OPENAI_API_KEY" not in os.environ:
+        os.environ["OPENAI_API_KEY"] = "EMPTY"
+
+    base_url = (
+        f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+    )
+
+    if args.eval_name == "mmlu":
+        from sglang.test.simple_eval_mmlu import MMLUEval
+
+        dataset_path = "mmlu.csv"
+
+        if not os.path.exists(dataset_path):
+            download_dataset(
+                dataset_path,
+                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+            )
+        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+    elif args.eval_name == "humaneval":
+        from sglang.test.simple_eval_humaneval import HumanEval
+
+        eval_obj = HumanEval(args.num_examples, args.num_threads)
+    else:
+        raise ValueError(f"Invalid eval name: {args.eval_name}")
+
+    sampler = ChatCompletionSampler(
+        model=args.model,
+        max_tokens=2048,
+        base_url=base_url,
+    )
+
+    # Run eval
+    tic = time.time()
+    result = eval_obj(sampler)
+    latency = time.time() - tic
+
+    # Dump reports
+    metrics = result.metrics | {"score": result.score}
+    file_stem = f"{args.eval_name}_{sampler.model.replace('/', '_')}"
+    report_filename = f"/tmp/{file_stem}.html"
+    print(f"Writing report to {report_filename}")
+    with open(report_filename, "w") as fh:
+        fh.write(make_report(result))
+    metrics = result.metrics | {"score": result.score}
+    print(metrics)
+    result_filename = f"/tmp/{file_stem}.json"
+    with open(result_filename, "w") as f:
+        f.write(json.dumps(metrics, indent=2))
+    print(f"Writing results to {result_filename}")
+
+    # Print results
+    print(f"Total latency: {latency:.3f} s")
+    print(f"Score: {metrics['score']:.3f}")
+
+    return metrics
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+    )
+    parser.add_argument("--eval-name", type=str, default="mmlu")
+    parser.add_argument("--num-examples", type=int)
+    parser.add_argument("--num-threads", type=int, default=64)
+    set_ulimit()
+    args = parser.parse_args()
+
+    run_eval(args)
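Besides the CLI shown in the module docstring, `run_eval` only needs an object carrying the parsed-argument attributes, so it can also be driven from Python; a hedged sketch using `argparse.Namespace` (the port, example count, and thread count are assumptions, and a server must already be running):

```python
from argparse import Namespace

from sglang.test.run_eval import run_eval

# Assumes an sglang server is already listening on http://127.0.0.1:30000.
args = Namespace(
    base_url=None,
    host="127.0.0.1",
    port=30000,
    model=None,        # None lets ChatCompletionSampler pick the first /v1/models entry
    eval_name="mmlu",
    num_examples=10,
    num_threads=16,
)
metrics = run_eval(args)
print(metrics["score"])
```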
sglang/test/simple_eval_common.py
ADDED
@@ -0,0 +1,467 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+import base64
+import os
+import resource
+import time
+from collections import defaultdict
+from dataclasses import dataclass, field
+from multiprocessing.pool import ThreadPool
+from typing import Any
+
+import httpx
+import jinja2
+import numpy as np
+import openai
+import requests
+from openai import OpenAI
+from tqdm import tqdm
+
+OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
+OPENAI_SYSTEM_MESSAGE_CHATGPT = (
+    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
+    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
+)
+
+
+Message = dict[str, Any]  # keys role, content
+MessageList = list[Message]
+
+
+class SamplerBase:
+    """
+    Base class for defining a sampling model, which can be evaluated,
+    or used as part of the grading process.
+    """
+
+    def __call__(self, message_list: MessageList) -> str:
+        raise NotImplementedError()
+
+
+@dataclass
+class EvalResult:
+    """
+    Result of running an evaluation (usually consisting of many samples)
+    """
+
+    score: float | None  # top-line metric
+    metrics: dict[str, float] | None  # other metrics
+    htmls: list[str]  # strings of valid HTML
+    convos: list[MessageList]  # sampled conversations
+
+
+@dataclass
+class SingleEvalResult:
+    """
+    Result of evaluating a single sample
+    """
+
+    score: float | None
+    metrics: dict[str, float] = field(default_factory=dict)
+    html: str | None = None
+    convo: MessageList | None = None  # sampled conversation
+
+
+class Eval:
+    """
+    Base class for defining an evaluation.
+    """
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        raise NotImplementedError()
+
+
+class LargerHttpxClient(httpx.Client):
+    def __init__(self):
+        timeout_config = httpx.Timeout(3600)
+        limits = httpx.Limits(
+            max_keepalive_connections=3600,
+            max_connections=3600,
+        )
+        super().__init__(timeout=timeout_config, limits=limits)
+
+
+class ChatCompletionSampler(SamplerBase):
+    """
+    Sample from OpenAI's chat completion API
+    """
+
+    def __init__(
+        self,
+        base_url: str = None,
+        model: str | None = None,
+        system_message: str | None = None,
+        temperature: float = 0.0,
+        max_tokens: int = 2048,
+    ):
+        self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
+
+        if model is None:
+            model = self.client.models.list().data[0].id
+
+        self.model = model
+        self.system_message = system_message
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.image_format = "url"
+
+    def _handle_image(
+        self,
+        image: str,
+        encoding: str = "base64",
+        format: str = "png",
+        fovea: int = 768,
+    ):
+        new_image = {
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:image/{format};{encoding},{image}",
+            },
+        }
+        return new_image
+
+    def _handle_text(self, text: str):
+        return {"type": "text", "text": text}
+
+    def _pack_message(self, role: str, content: Any):
+        return {"role": str(role), "content": content}
+
+    def __call__(self, message_list: MessageList) -> str:
+        if self.system_message:
+            message_list = [
+                self._pack_message("system", self.system_message)
+            ] + message_list
+        trial = 0
+        while True:
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=message_list,
+                    temperature=self.temperature,
+                    max_tokens=self.max_tokens,
+                )
+                return response.choices[0].message.content
+            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU
+            except openai.BadRequestError as e:
+                print("Bad Request Error", e)
+                return ""
+            except Exception as e:
+                exception_backoff = 2**trial  # expontial back off
+                print(
+                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
+                    e,
+                )
+                time.sleep(exception_backoff)
+                trial += 1
+            # unknown error shall throw exception
+
+
+QUERY_TEMPLATE_MULTICHOICE = """
+Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
+
+{Question}
+
+A) {A}
+B) {B}
+C) {C}
+D) {D}
+""".strip()
+
+ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
+ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
+
+
+EQUALITY_TEMPLATE = r"""
+Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
+
+Examples:
+
+Expression 1: $2x+3$
+Expression 2: $3+2x$
+
+Yes
+
+Expression 1: 3/2
+Expression 2: 1.5
+
+Yes
+
+Expression 1: $x^2+2x+1$
+Expression 2: $y^2+2y+1$
+
+No
+
+Expression 1: $x^2+2x+1$
+Expression 2: $(x+1)^2$
+
+Yes
+
+Expression 1: 3245/5
+Expression 2: 649
+
+No
+(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
+
+Expression 1: 2/(-3)
+Expression 2: -2/3
+
+Yes
+(trivial simplifications are allowed)
+
+Expression 1: 72 degrees
+Expression 2: 72
+
+Yes
+(give benefit of the doubt to units)
+
+Expression 1: 64
+Expression 2: 64 square feet
+
+Yes
+(give benefit of the doubt to units)
+
+---
+
+YOUR TASK
+
+
+Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
+
+Expression 1: %(expression1)s
+Expression 2: %(expression2)s
+""".strip()
+
+
+HTML_JINJA = """
+<h3>Prompt conversation</h3>
+{% for message in prompt_messages %}
+{{ message_to_html(message) | safe }}
+{% endfor %}
+<h3>Sampled message</h3>
+{{ message_to_html(next_message) | safe }}
+<h3>Results</h3>
+<p>Correct Answer: {{ correct_answer }}</p>
+<p>Extracted Answer: {{ extracted_answer }}</p>
+<p>Score: {{ score }}</p>
+"""
+
+
+def format_multichoice_question(row):
+    return QUERY_TEMPLATE_MULTICHOICE.format(**row)
+
+
+def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
+    prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
+    response = sampler([dict(content=prompt, role="user")])
+    return response.lower().strip() == "yes"
+
+
+def _compute_stat(values: list, stat: str):
+    if stat == "mean":
+        return np.mean(values)
+    elif stat == "std":
+        return np.std(values)
+    elif stat == "min":
+        return np.min(values)
+    elif stat == "max":
+        return np.max(values)
+    else:
+        raise ValueError(f"Unknown {stat =}")
+
+
+def aggregate_results(
+    single_eval_results: list[SingleEvalResult],
+    default_stats: tuple[str] = ("mean", "std"),
+    name2stats: dict[str, tuple[str]] | None = None,
+) -> EvalResult:
+    """
+    Aggregate results from multiple evaluations into a single EvalResult.
+    """
+    name2stats = name2stats or {}
+    name2values = defaultdict(list)
+    htmls = []
+    convos = []
+    for single_eval_result in single_eval_results:
+        for name, value in single_eval_result.metrics.items():
+            name2values[name].append(value)
+        if single_eval_result.score is not None:
+            name2values["score"].append(single_eval_result.score)
+        htmls.append(single_eval_result.html)
+        convos.append(single_eval_result.convo)
+    final_metrics = {}
+    for name, values in name2values.items():
+        stats = name2stats.get(name, default_stats)
+        for stat in stats:
+            key = name if stat == "mean" else f"{name}:{stat}"
+            final_metrics[key] = _compute_stat(values, stat)
+    return EvalResult(
+        score=final_metrics.pop("score", None),
+        metrics=final_metrics,
+        htmls=htmls,
+        convos=convos,
+    )
+
+
+def map_with_progress(f: callable, xs: list[Any], num_threads: int):
+    """
+    Apply f to each element of xs, using a ThreadPool, and show progress.
+    """
+    if os.getenv("debug"):
+        return list(map(f, tqdm(xs, total=len(xs))))
+    else:
+        with ThreadPool(min(num_threads, len(xs))) as pool:
+            return list(tqdm(pool.imap(f, xs), total=len(xs)))
+
+
+jinja_env = jinja2.Environment(
+    loader=jinja2.BaseLoader(),
+    undefined=jinja2.StrictUndefined,
+    autoescape=jinja2.select_autoescape(["html", "xml"]),
+)
+_message_template = """
+<div class="message {{ role }}">
+    <div class="role">
+        {{ role }}
+        {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
+    </div>
+    <div class="content">
+        <pre>{{ content }}</pre>
+    </div>
+</div>
+"""
+
+
+def message_to_html(message: Message) -> str:
+    """
+    Generate HTML snippet (inside a <div>) for a message.
+    """
+    return jinja_env.from_string(_message_template).render(
+        role=message["role"],
+        content=message["content"],
+        variant=message.get("variant", None),
+    )
+
+
+jinja_env.globals["message_to_html"] = message_to_html
+
+
+_report_template = """<!DOCTYPE html>
+<html>
+    <head>
+        <style>
+            .message {
+                padding: 8px 16px;
+                margin-bottom: 8px;
+                border-radius: 4px;
+            }
+            .message.user {
+                background-color: #B2DFDB;
+                color: #00695C;
+            }
+            .message.assistant {
+                background-color: #B39DDB;
+                color: #4527A0;
+            }
+            .message.system {
+                background-color: #EEEEEE;
+                color: #212121;
+            }
+            .role {
+                font-weight: bold;
+                margin-bottom: 4px;
+            }
+            .variant {
+                color: #795548;
+            }
+            table, th, td {
+                border: 1px solid black;
+            }
+            pre {
+                white-space: pre-wrap;
+            }
+        </style>
+    </head>
+    <body>
+    {% if metrics %}
+    <h1>Metrics</h1>
+    <table>
+    <tr>
+        <th>Metric</th>
+        <th>Value</th>
+    </tr>
+    <tr>
+        <td><b>Score</b></td>
+        <td>{{ score | float | round(3) }}</td>
+    </tr>
+    {% for name, value in metrics.items() %}
+    <tr>
+        <td>{{ name }}</td>
+        <td>{{ value }}</td>
+    </tr>
+    {% endfor %}
+    </table>
+    {% endif %}
+    <h1>Examples</h1>
+    {% for html in htmls %}
+    {{ html | safe }}
+    <hr>
+    {% endfor %}
+    </body>
+</html>
+"""
+
+
+def make_report(eval_result: EvalResult) -> str:
+    """
+    Create a standalone HTML report from an EvalResult.
+    """
+    return jinja_env.from_string(_report_template).render(
+        score=eval_result.score,
+        metrics=eval_result.metrics,
+        htmls=eval_result.htmls,
+    )
+
+
+def make_report_from_example_htmls(htmls: list[str]):
+    """
+    Create a standalone HTML report from a list of example htmls
+    """
+    return jinja_env.from_string(_report_template).render(
+        score=None, metrics={}, htmls=htmls
+    )
+
+
+def download_dataset(path, url):
+    print(f"Downloading dataset {path} from {url}")
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+
+        total_size = int(response.headers.get("content-length", 0))
+        block_size = 8192
+
+        with open(path, "wb") as f, tqdm(
+            desc="Downloading",
+            total=total_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as progress_bar:
+            for data in response.iter_content(block_size):
+                size = f.write(data)
+                progress_bar.update(size)
+
+        print(f"Dataset downloaded and saved to {path}")
+    except requests.RequestException as e:
+        raise Exception(f"Failed to download dataset: {e}")
+
+
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            print(f"Fail to set RLIMIT_NOFILE: {e}")
sglang/test/simple_eval_humaneval.py
ADDED
@@ -0,0 +1,139 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+HumanEval: Evaluating Large Language Models Trained on Code
+Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba
+https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
+"""
+
+import json
+import logging
+import multiprocessing
+import random
+import re
+from collections import Counter, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from io import BytesIO
+from typing import Any, Tuple
+
+import blobfile as bf
+import tqdm
+
+try:
+    from human_eval.data import HUMAN_EVAL, read_problems
+    from human_eval.evaluation import estimate_pass_at_k
+    from human_eval.execution import check_correctness  # , unsafe_execute
+except (ImportError, ModuleNotFoundError):
+    print("\nPlease install human-eval at https://github.com/openai/human-eval.\n")
+    raise
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+)
+
+
+def evaluate_functional_correctness(
+    sample: dict[str, str],
+    completions: list[str],
+    n_workers: int = 4,
+    timeout: float = 3.0,
+):
+    """
+    Evaluates the functional correctness of generated samples, and writes
+    results to f"{sample_file}_results.jsonl.gz"
+    """
+    import copy
+
+    # Check the generated samples against test suites.
+    with ThreadPoolExecutor(max_workers=n_workers) as executor:
+        futures = []
+        for i, completion in enumerate(completions):
+            args = (sample, completion, timeout, i)
+            future = executor.submit(check_correctness, *args)
+            futures.append(future)
+        results = []
+        for future in as_completed(futures):
+            result = future.result()
+            results.append(result)
+    passed = [int(r["passed"]) for r in results]
+    return passed
+
+
+class HumanEval(Eval):
+    def __init__(
+        self,
+        num_examples: int | None,
+        num_threads: int,
+        num_samples_per_task: int = 5,
+        ks_passes: list[int] = [1, 2, 5],
+        timeout: int = 120,
+    ):
+        self.seed = 0
+        self.examples = read_problems()
+        self.examples = list(self.examples.values())
+
+        self._num_examples = num_examples
+        if self._num_examples:
+            self.examples = random.Random(self.seed).sample(self.examples, num_examples)
+        self._num_samples_per_task = num_samples_per_task
+        self._ks_passes = ks_passes
+        self._timeout = timeout
+        self._num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n"
+
+        def find_code(completion):
+            pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
+            matches = pattern.findall(completion)
+            extracted_answer = matches[0] if len(matches) >= 1 else completion
+            extracted_answer = extracted_answer[
+                extracted_answer.find(":\n ") + 2 :
+            ]  # remove signature
+            return extracted_answer
+
+        def fn(sample: dict[str, str]):
+            prompt_messages = [
+                sampler._pack_message(
+                    role="user", content=instruction + sample["prompt"]
+                )
+            ]
+            completions = [
+                find_code(sampler(prompt_messages))
+                for _ in range(self._num_samples_per_task)
+            ]
+            results = evaluate_functional_correctness(sample, completions)
+            total = len(results)
+            correct = sum(results)
+            score = sum(results) / len(results)
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=completions[0], role="assistant"),
+                score=score,
+                correct_answer=[1] * len(results),
+                extracted_answer=results,
+            )
+            convo = prompt_messages + [
+                dict(content=completion, role="assistant") for completion in completions
+            ]
+            return SingleEvalResult(
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={
+                    f"pass@{k}": estimate_pass_at_k([total], [correct], k)
+                    # this will be aggrated so no need of .mean()
+                    for k in self._ks_passes
+                    if total >= k
+                },
+            )
+
+        results = common.map_with_progress(
+            fn, self.examples, num_threads=self._num_threads
+        )
+        return common.aggregate_results(results)
sglang/test/simple_eval_mmlu.py
ADDED
@@ -0,0 +1,120 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Massive Multitask Language Understanding
+Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2009.03300
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN_MULTICHOICE,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    format_multichoice_question,
+)
+
+subject2category = {
+    "abstract_algebra": "stem",
+    "anatomy": "other",
+    "astronomy": "stem",
+    "business_ethics": "other",
+    "clinical_knowledge": "other",
+    "college_biology": "stem",
+    "college_chemistry": "stem",
+    "college_computer_science": "stem",
+    "college_mathematics": "stem",
+    "college_medicine": "other",
+    "college_physics": "stem",
+    "computer_security": "stem",
+    "conceptual_physics": "stem",
+    "econometrics": "social_sciences",
+    "electrical_engineering": "stem",
+    "elementary_mathematics": "stem",
+    "formal_logic": "humanities",
+    "global_facts": "other",
+    "high_school_biology": "stem",
+    "high_school_chemistry": "stem",
+    "high_school_computer_science": "stem",
+    "high_school_european_history": "humanities",
+    "high_school_geography": "social_sciences",
+    "high_school_government_and_politics": "social_sciences",
+    "high_school_macroeconomics": "social_sciences",
+    "high_school_mathematics": "stem",
+    "high_school_microeconomics": "social_sciences",
+    "high_school_physics": "stem",
+    "high_school_psychology": "social_sciences",
+    "high_school_statistics": "stem",
+    "high_school_us_history": "humanities",
+    "high_school_world_history": "humanities",
+    "human_aging": "other",
+    "human_sexuality": "social_sciences",
+    "international_law": "humanities",
+    "jurisprudence": "humanities",
+    "logical_fallacies": "humanities",
+    "machine_learning": "stem",
+    "management": "other",
+    "marketing": "other",
+    "medical_genetics": "other",
+    "miscellaneous": "other",
+    "moral_disputes": "humanities",
+    "moral_scenarios": "humanities",
+    "nutrition": "other",
+    "philosophy": "humanities",
+    "prehistory": "humanities",
+    "professional_accounting": "other",
+    "professional_law": "humanities",
+    "professional_medicine": "other",
+    "professional_psychology": "social_sciences",
+    "public_relations": "social_sciences",
+    "security_studies": "social_sciences",
+    "sociology": "social_sciences",
+    "us_foreign_policy": "social_sciences",
+    "virology": "other",
+    "world_religions": "humanities",
+}
+
+
+class MMLUEval(Eval):
+    def __init__(self, filename: str, num_examples: int | None, num_threads: int):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        if num_examples:
+            examples = random.Random(0).sample(examples, num_examples)
+        self.examples = examples
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            prompt_messages = [
+                sampler._pack_message(
+                    content=format_multichoice_question(row), role="user"
+                )
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = 1.0 if extracted_answer == row["Answer"] else 0.0
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=row["Answer"],
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            category = subject2category.get(row["Subject"], "other")
+            return SingleEvalResult(
+                html=html, score=score, metrics={category: score}, convo=convo
+            )
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
sglang/test/test_programs.py
CHANGED
@@ -105,15 +105,14 @@ def test_decode_json_regex():
     def decode_json(s):
         from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
 
-        s += "Generate a JSON object to describe the basic information of
+        s += "Generate a JSON object to describe the basic city information of Paris.\n"
 
         with s.var_scope("json_output"):
             s += "{\n"
             s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
             s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
             s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
-            s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT
-            s += ' "country": ' + sgl.gen(regex=REGEX_STRING) + "\n"
+            s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
             s += "}"
 
     ret = decode_json.run(temperature=0.0)
@@ -129,7 +128,7 @@ def test_decode_json_regex():
 def test_decode_json():
     @sgl.function
     def decode_json(s):
-        s += "Generate a JSON object to describe the basic information of
+        s += "Generate a JSON object to describe the basic city information of Paris.\n"
 
         with s.var_scope("json_output"):
             s += "{\n"
@@ -264,6 +263,7 @@ def test_parallel_decoding():
         s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
 
     ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
+    assert isinstance(ret["summary"], str)
 
 
 def test_parallel_encoding(check_answer=True):
sglang/test/test_utils.py
CHANGED
@@ -1,6 +1,8 @@
 """Common utilities for testing and benchmarking"""
 
 import asyncio
+import subprocess
+import time
 from functools import partial
 
 import numpy as np
@@ -11,6 +13,8 @@ from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import get_exception_traceback
 
+MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
 
 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
     assert url is not None
@@ -379,3 +383,31 @@ def get_call_select(args):
         raise
 
     return func
+
+
+def popen_launch_server(model, port, timeout, *args):
+    command = [
+        "python3",
+        "-m",
+        "sglang.launch_server",
+        "--model-path",
+        model,
+        "--host",
+        "localhost",
+        "--port",
+        str(port),
+        *args,
+    ]
+    process = subprocess.Popen(command, stdout=None, stderr=None)
+    base_url = f"http://localhost:{port}/v1"
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            response = requests.get(f"{base_url}/models")
+            if response.status_code == 200:
+                return process
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError("Server failed to start within the timeout period.")
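`popen_launch_server` pairs naturally with `kill_child_process` from `sglang.srt.utils`: the former blocks until `/v1/models` answers, the latter tears the spawned server down. A rough sketch of that pattern in a test (the port and timeout values are arbitrary):

```python
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server

port = 30000  # arbitrary free port
process = popen_launch_server(MODEL_NAME_FOR_TEST, port, 300)  # wait up to 300 s
try:
    pass  # run requests or evals against http://localhost:30000/v1 here
finally:
    kill_child_process(process.pid)
```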
sglang/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.2.8"
+__version__ = "0.2.9"
{sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.2.8
+Version: 0.2.9
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
                                  Version 2.0, January 2004
@@ -299,8 +299,8 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
 
 ### Method 2: From source
 ```
-# Use the stable v0.2.8 branch
-git clone -b v0.2.8 https://github.com/sgl-project/sglang.git
+# Use the stable v0.2.9 branch
+git clone -b v0.2.9 https://github.com/sgl-project/sglang.git
 cd sglang
 
 pip install --upgrade pip
{sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 sglang/__init__.py,sha256=ECjvAWlxIwKtUIXGchfkoCIbF-iqLjH-Q0o8xHTlVNY,1352
 sglang/api.py,sha256=s_P8BvGDCQ0PiqOapr2TLFge1NA7QmKqUx6bFQ8Q5GQ,5676
 sglang/bench_latency.py,sha256=JPatRvstM3nXb-ViVgtR-TaRrFHpcHzqoDG7BQmRYK8,10539
-sglang/bench_serving.py,sha256=
+sglang/bench_serving.py,sha256=M0YQT6xElpkx-FtmyUe6lhX1DZfVLGh54qd6qfFYquc,34801
 sglang/check_env.py,sha256=Eeb_20VetnlEFYSRcHFlNqt85lYUQN60NEtkoX7ahPA,4121
 sglang/global_config.py,sha256=CyhGL7PE-KlMcg7IHWykzImU1y4NQlpeIlh9lHA77uo,1749
 sglang/launch_server.py,sha256=Gg8CwNlTCCfg1dF65ZT9ePLxOT9LKtY79GhIPG6PCrU,358
 sglang/launch_server_llavavid.py,sha256=40uaazMsavKuk6YXFa5v37kdUpFGuealgJJeph1g8gU,1025
 sglang/utils.py,sha256=r0Z7hY_bFFk-b6WeQJir9br-hCW2-p7n5E7Et2WziaQ,8776
-sglang/version.py,sha256=
+sglang/version.py,sha256=F8OVhAhMXSkvvXYgZtbPn2SG1AQC3joK4yu-FrHt81Y,22
 sglang/lang/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sglang/lang/chat_template.py,sha256=psIlhaDo70twgLrx5Lgln03metLEA3-FZuixeI0Y7Ao,13309
 sglang/lang/compiler.py,sha256=UiXUmPR9wBAPtnORrLcyQX8Uh0ZL0nKeV8ZgBozAJPw,7531
@@ -26,9 +26,9 @@ sglang/srt/hf_transformers_utils.py,sha256=Fg-3panb6lsqOhHmAYA0ivkXyBjdnvY5mqvil
 sglang/srt/mm_utils.py,sha256=n7_GmbOM_0IWVXovpM34rKIBw0Py9yb_NXSQw27u4OA,9454
 sglang/srt/model_config.py,sha256=DO7m84WiT3dzPWmyKz_UXDAHEdqEjq8Lq5wCjzjYMME,6023
 sglang/srt/sampling_params.py,sha256=uZFDlTUPnNR5_3IDH-INDeN-tm6LlRkC2KT-B3njxJs,3687
-sglang/srt/server.py,sha256=
-sglang/srt/server_args.py,sha256=
-sglang/srt/utils.py,sha256=
+sglang/srt/server.py,sha256=cDHUmLqj7MjF-3L9WcfA-4z9dRl55cwF5ygXuncMl-Q,15852
+sglang/srt/server_args.py,sha256=wdRlxR-509RfNYuMQoxUAefMwoc5eme6sYwEMyRBHmk,16034
+sglang/srt/utils.py,sha256=5wgGe6kI59JAmf8kxLsItulJ4xQaOJHHYaWWd6_WWmo,23384
 sglang/srt/constrained/__init__.py,sha256=NLpZGj9RIx83ejDrM_pfaRtqGgaPq_ggJszPQENUJ2E,2037
 sglang/srt/constrained/base_tool_cache.py,sha256=1_m-AivPtWRwUgGiEZBafCrSFUGahK4UM4vgAd8TkMg,2004
 sglang/srt/constrained/fsm_cache.py,sha256=GoPBr_9ZdJizF2PKbYoQw2I4ckfrUYwCeMZxB9sY3TM,2639
@@ -37,7 +37,7 @@ sglang/srt/layers/context_flashattention_nopad.py,sha256=r_TpHuYAVgq1pN81PiWe1be
 sglang/srt/layers/extend_attention.py,sha256=zuNnAdL_wF6BX0Mwn1dgDJvh3YJjYwqa5Fbzp8muOVc,12573
 sglang/srt/layers/fused_moe.py,sha256=KmyXwau2OOZpQimGIQrHptzGNs1trIud5AKEEKXdzPU,20823
 sglang/srt/layers/linear.py,sha256=3Se2FRXyqXcd-uvNx2b7s-jolsUTEVeYBMYHmV82wPw,34518
-sglang/srt/layers/logits_processor.py,sha256=
+sglang/srt/layers/logits_processor.py,sha256=5Cg3h5b4H0EUeOJRst3IOMWL5dniP63A5s15BRkAMmk,11091
 sglang/srt/layers/radix_attention.py,sha256=tdA-kdd9LQY1wbw3iYuy-9cikVJYmy3EctwAlUfN-Uo,6945
 sglang/srt/layers/token_attention.py,sha256=ylUqUnozJCCohxTGAiiP3sxgUrcXfEVic8-qgcHYDj4,7968
 sglang/srt/layers/quantization/__init__.py,sha256=JMlgE-FWS759lfQ9Uc6mGFqBbTFLlvKeVEFpZLATe14,2536
@@ -48,7 +48,7 @@ sglang/srt/managers/detokenizer_manager.py,sha256=GXWdW4n2N-otL3zcgdr0t1PcEe2EmQ
 sglang/srt/managers/io_struct.py,sha256=Rz7Ur9Yw6prDGdy6XjsSiUmVBccS6cef-G_9TW7HA_4,7105
 sglang/srt/managers/policy_scheduler.py,sha256=ajSB-gCC6VJkXvnKU8FYU3Kgcigozp2pMTwF84Wp14o,3138
 sglang/srt/managers/schedule_batch.py,sha256=LIoVCPNivh0u1dOrrWRgFD6a4ywq3nrG_4dNgCK0kIw,37697
-sglang/srt/managers/tokenizer_manager.py,sha256=
+sglang/srt/managers/tokenizer_manager.py,sha256=rtZ44aiZOMHLHkXDhMgj0HDR3gExpeGjWfoCD0PfG_o,20574
 sglang/srt/managers/tp_worker.py,sha256=JPLneFwcPlmPXZX1QxZHWgcdau8FC8wNuVqfCqsgOkU,35234
 sglang/srt/mem_cache/base_cache.py,sha256=czyN8IumXcMQskYOZDV3DzjfD4kdR-qwLVxceDqnOmE,788
 sglang/srt/mem_cache/chunk_cache.py,sha256=u1mkGoTI7_31H0i0mhKT7S57StYSsdmsSPqyGubE7lY,1560
@@ -82,14 +82,16 @@ sglang/srt/models/qwen2.py,sha256=mXlVd6UTCXY3VdgodFpQnlaY-NYLIbA-SknxdA9R13w,12
 sglang/srt/models/qwen2_moe.py,sha256=YYdJEezic7GyW-_bXlNIaqBa0C4IHQpz_vuRBLxms4k,18141
 sglang/srt/models/stablelm.py,sha256=b3d-ZwLQoLjZ6CupnkIq7d-z9tzGSxAyIcgSmZiZxZw,11362
 sglang/srt/models/yivl.py,sha256=p4s_D_m4H2exP4b91Y-CTkq8T-eIG3DJsFy9pB0e7TM,4932
-sglang/srt/openai_api/adapter.py,sha256=
+sglang/srt/openai_api/adapter.py,sha256=h6TIU0Fu3jU361pye4J12vcDug7UJJRPiBAY_HfFUuE,32599
 sglang/srt/openai_api/protocol.py,sha256=JXLnnQ63I-bJv93ICPfP0cBpyomQA5IYE_mkUg5X4Es,8177
-sglang/test/
-sglang/test/
-sglang/test/
-sglang/test/
-sglang
-sglang
-sglang-0.2.
-sglang-0.2.
-sglang-0.2.
+sglang/test/run_eval.py,sha256=WvMLSi70G9fhruP8cPLOfDJ9XEKL7yNn2pylx-7tNsQ,3054
+sglang/test/simple_eval_common.py,sha256=Qh1-iEXJCKfJmgpAzNSp28fcP1TUJzt3s9i1FjvemHY,12340
+sglang/test/simple_eval_humaneval.py,sha256=IW0ZC6D4SXu06IJiMoAY9DK9SMsTOlDPAwu4cfbJco0,5826
+sglang/test/simple_eval_mmlu.py,sha256=KqSSdSu2qfoKQ870ttxev1NJ7c90xv2mvKOQsSODtAw,4326
+sglang/test/test_programs.py,sha256=e9_ifoIvuI1Ctkbkz3wfdZLBBSRikby8ywcodBIkf9M,13826
+sglang/test/test_utils.py,sha256=PndOL1zdseMrpHTHGmgsHHepxqYBn__eNLrlsSXLy6k,11905
+sglang-0.2.9.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sglang-0.2.9.dist-info/METADATA,sha256=8vhH67MeR6EdJepUSvmqKSneJTQ8l_9LD9L6FfzyrHk,33214
+sglang-0.2.9.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+sglang-0.2.9.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
+sglang-0.2.9.dist-info/RECORD,,
sglang/test/test_conversation.py
DELETED
@@ -1,46 +0,0 @@
-from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.managers.openai_api.protocol import (
-    ChatCompletionMessageContentImagePart,
-    ChatCompletionMessageContentImageURL,
-    ChatCompletionMessageContentTextPart,
-    ChatCompletionMessageGenericParam,
-    ChatCompletionMessageUserParam,
-    ChatCompletionRequest,
-)
-
-
-def test_chat_completion_to_conv_image():
-    """Test that we can convert a chat image request to a convo"""
-    request = ChatCompletionRequest(
-        model="default",
-        messages=[
-            ChatCompletionMessageGenericParam(
-                role="system", content="You are a helpful AI assistant"
-            ),
-            ChatCompletionMessageUserParam(
-                role="user",
-                content=[
-                    ChatCompletionMessageContentTextPart(
-                        type="text", text="Describe this image"
-                    ),
-                    ChatCompletionMessageContentImagePart(
-                        type="image_url",
-                        image_url=ChatCompletionMessageContentImageURL(
-                            url="https://someurl.com"
-                        ),
-                    ),
-                ],
-            ),
-        ],
-    )
-    conv = generate_chat_conv(request, "vicuna_v1.1")
-    assert conv.messages == [
-        ["USER", "Describe this image<image>"],
-        ["ASSISTANT", None],
-    ]
-    assert conv.system_message == "You are a helpful AI assistant"
-    assert conv.image_data == ["https://someurl.com"]
-
-
-if __name__ == "__main__":
-    test_chat_completion_to_conv_image()
sglang/test/test_openai_protocol.py
DELETED
@@ -1,51 +0,0 @@
-from sglang.srt.managers.openai_api.protocol import (
-    ChatCompletionMessageContentImagePart,
-    ChatCompletionMessageContentImageURL,
-    ChatCompletionMessageContentTextPart,
-    ChatCompletionMessageGenericParam,
-    ChatCompletionMessageUserParam,
-    ChatCompletionRequest,
-)
-
-
-def test_chat_completion_request_image():
-    """Test that Chat Completion Requests with images can be converted."""
-
-    image_request = {
-        "model": "default",
-        "messages": [
-            {"role": "system", "content": "You are a helpful AI assistant"},
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": "Describe this image"},
-                    {"type": "image_url", "image_url": {"url": "https://someurl.com"}},
-                ],
-            },
-        ],
-        "temperature": 0,
-        "max_tokens": 64,
-    }
-    request = ChatCompletionRequest(**image_request)
-    assert len(request.messages) == 2
-    assert request.messages[0] == ChatCompletionMessageGenericParam(
-        role="system", content="You are a helpful AI assistant"
-    )
-    assert request.messages[1] == ChatCompletionMessageUserParam(
-        role="user",
-        content=[
-            ChatCompletionMessageContentTextPart(
-                type="text", text="Describe this image"
-            ),
-            ChatCompletionMessageContentImagePart(
-                type="image_url",
-                image_url=ChatCompletionMessageContentImageURL(
-                    url="https://someurl.com"
-                ),
-            ),
-        ],
-    )
-
-
-if __name__ == "__main__":
-    test_chat_completion_request_image()
{sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/LICENSE
File without changes
{sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/WHEEL
File without changes
{sglang-0.2.8.dist-info → sglang-0.2.9.dist-info}/top_level.txt
File without changes