sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +234 -74
- sglang/check_env.py +25 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -40
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +24 -14
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +98 -323
- sglang/srt/managers/tokenizer_manager.py +34 -16
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +74 -38
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +51 -26
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +199 -17
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +151 -29
- sglang/srt/openai_api/protocol.py +7 -1
- sglang/srt/server.py +111 -84
- sglang/srt/server_args.py +12 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +95 -14
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
- sglang-0.2.11.dist-info/RECORD +102 -0
- sglang-0.2.9.post1.dist-info/RECORD +0 -97
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/test/runners.py
ADDED
@@ -0,0 +1,237 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+import multiprocessing
+from dataclasses import dataclass
+from typing import List, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from sglang.srt.server import Runtime
+
+DEFAULT_PROMPTS = [
+    "The capital of France is",
+    "The capital of the United Kindom is",
+    "Today is a sunny day and I like",
+]
+
+NUM_TOP_LOGPROBS = 5
+
+
+def is_embedding_model(model_path):
+    # FIXME incomplete list
+    if "e5-mistral-7b-instruct" in model_path.lower():
+        return True
+    return False
+
+
+def get_dtype_str(torch_dtype):
+    if torch_dtype is torch.float16:
+        return "float16"
+    else:
+        raise NotImplementedError()
+
+
+@dataclass
+class ModelOutput:
+    output_strs: str = None
+    top_input_logprobs: torch.Tensor = None
+    top_output_logprobs: torch.Tensor = None
+    embed_logits: torch.Tensor = None
+
+
+class HFRunner:
+    def __init__(
+        self,
+        model_path,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.in_queue = multiprocessing.Queue()
+        self.out_queue = multiprocessing.Queue()
+
+        self.model_proc = multiprocessing.Process(
+            target=self.start_model_process,
+            args=(
+                self.in_queue,
+                self.out_queue,
+                model_path,
+                torch_dtype,
+                is_embedding_model,
+            ),
+        )
+        self.model_proc.start()
+
+    def start_model_process(
+        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+    ):
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+        )
+
+        self.is_embedding_model = (
+            is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if not self.is_embedding_model:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+            ).cuda()
+        else:
+            from sentence_transformers import SentenceTransformer
+
+            self.model = SentenceTransformer(
+                model_path,
+                device="cpu",
+            ).to(dtype=torch_dtype)
+
+        while True:
+            prompts, max_new_tokens = in_queue.get()
+            if prompts is not None:
+                if not self.is_embedding_model:
+                    output_strs = []
+                    prefill_logprobs = []
+                    for p in prompts:
+                        if isinstance(p, str):
+                            input_ids = self.tokenizer.encode(
+                                p, return_tensors="pt"
+                            ).cuda()
+                        else:
+                            input_ids = torch.tensor([p], device="cuda")
+
+                        output_ids = self.model.generate(
+                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        )
+                        output_strs.append(self.tokenizer.decode(output_ids[0]))
+
+                        logits = self.model.forward(input_ids).logits[0]
+                        logprobs = F.log_softmax(
+                            logits, dim=-1, dtype=torch.float32
+                        ).tolist()
+                        # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
+                        # print("index", index_of_max)
+                        logprobs = [
+                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
+                            for token_logprobs in logprobs
+                        ]
+                        prefill_logprobs.append(logprobs)
+
+                    out_queue.put(
+                        ModelOutput(
+                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                        )
+                    )
+
+                else:
+                    assert isinstance(prompts, List[str])
+                    logits = self.model.encode(prompts).tolist()
+
+                    out_queue.put(ModelOutput(embed_logits=logits))
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        self.in_queue.put((prompts, max_new_tokens))
+        return self.out_queue.get()
+
+    def terminate(self):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+
+class SRTRunner:
+    def __init__(
+        self,
+        model_path,
+        tp_size=1,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.is_embedding_model = (
+            is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if self.is_embedding_model:
+            raise NotImplementedError()
+
+        self.runtime = Runtime(
+            model_path=model_path,
+            tp_size=tp_size,
+            dtype=get_dtype_str(torch_dtype),
+        )
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        # the return value contains logprobs from prefill
+        output_strs = []
+        top_input_logprobs = []
+        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+        for prompt in prompts:
+            response = self.runtime.generate(
+                prompt,
+                sampling_params=sampling_params,
+                return_logprob=True,
+                top_logprobs_num=NUM_TOP_LOGPROBS,
+            )
+            response = json.loads(response)
+            output_strs.append(response["text"])
+            top_input_logprobs.append(
+                [
+                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                    for x in response["meta_info"]["input_top_logprobs"][1:]
+                ]
+                + [
+                    [
+                        tup[0]
+                        for tup in response["meta_info"]["output_top_logprobs"][0][
+                            :NUM_TOP_LOGPROBS
+                        ]
+                    ]
+                ]
+            )
+            # print(response["meta_info"]["output_top_logprobs"][0])
+
+        return ModelOutput(
+            output_strs=output_strs, top_input_logprobs=top_input_logprobs
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.runtime.shutdown()
+        del self.runtime
sglang/test/simple_eval_common.py
CHANGED
@@ -7,7 +7,7 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from multiprocessing.pool import ThreadPool
-from typing import Any
+from typing import Any, Dict, List, Tuple
 
 import httpx
 import jinja2

@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
 )
 
 
-Message = dict[str, Any]  # keys role, content
-MessageList = list[Message]
+Message = Dict[str, Any]  # keys role, content
+MessageList = List[Message]
 
 
 class SamplerBase:

@@ -45,9 +45,9 @@ class EvalResult:
     """
 
     score: float | None  # top-line metric
-    metrics: dict[str, float] | None  # other metrics
-    htmls: list[str]  # strings of valid HTML
-    convos: list[MessageList]  # sampled conversations
+    metrics: Dict[str, float] | None  # other metrics
+    htmls: List[str]  # strings of valid HTML
+    convos: List[MessageList]  # sampled conversations
 
 
 @dataclass

@@ -57,7 +57,7 @@ class SingleEvalResult:
     """
 
     score: float | None
-    metrics: dict[str, float] = field(default_factory=dict)
+    metrics: Dict[str, float] = field(default_factory=dict)
     html: str | None = None
     convo: MessageList | None = None  # sampled conversation
 

@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
 
 
 def aggregate_results(
-    single_eval_results: list[SingleEvalResult],
-    default_stats: tuple[str] = ("mean", "std"),
-    name2stats: dict[str, tuple[str]] | None = None,
+    single_eval_results: List[SingleEvalResult],
+    default_stats: Tuple[str] = ("mean", "std"),
+    name2stats: Dict[str, Tuple[str]] | None = None,
 ) -> EvalResult:
     """
     Aggregate results from multiple evaluations into a single EvalResult.

@@ -302,7 +302,7 @@ def aggregate_results(
     )
 
 
-def map_with_progress(f: callable, xs: list[Any], num_threads: int):
+def map_with_progress(f: callable, xs: List[Any], num_threads: int):
     """
     Apply f to each element of xs, using a ThreadPool, and show progress.
     """

@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
     )
 
 
-def make_report_from_example_htmls(htmls: list[str]):
+def make_report_from_example_htmls(htmls: List[str]):
     """
     Create a standalone HTML report from a list of example htmls
     """
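For context only (not part of the diff): a small sketch of what the newly annotated types describe, assuming the SingleEvalResult and aggregate_results interfaces shown in the hunks above.

# Hypothetical illustration of the annotated types: build per-example results
# and aggregate them, as the eval classes elsewhere in this release do.
from sglang.test.simple_eval_common import (
    MessageList,
    SingleEvalResult,
    aggregate_results,
)

convo: MessageList = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "Answer: 4"},
]
single_results = [
    SingleEvalResult(score=1.0, metrics={"chars": 9}, convo=convo),
    SingleEvalResult(score=0.0, metrics={"chars": 3}, convo=convo),
]

eval_result = aggregate_results(single_results)  # default_stats=("mean", "std")
print(eval_result.score, eval_result.metrics)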
sglang/test/simple_eval_gpqa.py
ADDED
@@ -0,0 +1,92 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+GPQA: A Graduate-Level Google-Proof Q&A Benchmark
+David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
+https://arxiv.org/abs/2311.12022
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN_MULTICHOICE,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    MessageList,
+    SamplerBase,
+    SingleEvalResult,
+    format_multichoice_question,
+)
+
+
+class GPQAEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        num_examples: int | None,
+        num_threads: int,
+        n_repeats: int = 1,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        rng = random.Random(0)
+        if num_examples:
+            assert n_repeats == 1, "n_repeats only supported for num_examples"
+            examples = rng.sample(examples, num_examples)
+        examples = examples * n_repeats
+        examples = [
+            example | {"permutation": rng.sample(range(4), 4)} for example in examples
+        ]
+        self.examples = examples
+        self.n_repeats = n_repeats
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            choices = [
+                row["Correct Answer"],
+                row["Incorrect Answer 1"],
+                row["Incorrect Answer 2"],
+                row["Incorrect Answer 3"],
+            ]
+            choices = [choices[i] for i in row["permutation"]]
+            correct_index = choices.index(row["Correct Answer"])
+            correct_answer = "ABCD"[correct_index]
+            choices_dict = dict(
+                A=choices[0],
+                B=choices[1],
+                C=choices[2],
+                D=choices[3],
+                Question=row["Question"],
+            )
+            prompt_messages = [
+                sampler._pack_message(
+                    content=format_multichoice_question(choices_dict), role="user"
+                )
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = 1.0 if extracted_answer == correct_answer else 0.0
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=correct_answer,
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={"chars": len(response_text)},
+            )
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
sglang/test/simple_eval_humaneval.py
CHANGED
@@ -14,7 +14,7 @@ import re
 from collections import Counter, defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from io import BytesIO
-from typing import Any, Tuple
+from typing import Any, Dict, List, Tuple
 
 import blobfile as bf
 import tqdm

@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
 
 
 def evaluate_functional_correctness(
-    sample: dict[str, str],
-    completions: list[str],
+    sample: Dict[str, str],
+    completions: List[str],
     n_workers: int = 4,
     timeout: float = 3.0,
 ):

@@ -70,7 +70,7 @@ class HumanEval(Eval):
         num_examples: int | None,
         num_threads: int,
         num_samples_per_task: int = 5,
-        ks_passes: list[int] = [1, 2, 5],
+        ks_passes: List[int] = [1, 2, 5],
         timeout: int = 120,
     ):
         self.seed = 0

@@ -97,7 +97,7 @@ class HumanEval(Eval):
            ]  # remove signature
            return extracted_answer
 
-        def fn(sample: dict[str, str]):
+        def fn(sample: Dict[str, str]):
            prompt_messages = [
                sampler._pack_message(
                    role="user", content=instruction + sample["prompt"]
sglang/test/simple_eval_math.py
ADDED
@@ -0,0 +1,72 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Mathematical Problem Solving With the MATH Dataset
+Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2103.03874
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    check_equality,
+)
+
+QUERY_TEMPLATE = """
+Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
+
+{Question}
+
+Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
+""".strip()
+
+
+class MathEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        equality_checker: SamplerBase,
+        num_examples: int | None,
+        num_threads: int,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        if num_examples:
+            examples = random.Random(0).sample(examples, num_examples)
+        self.examples = examples
+        self.equality_checker = equality_checker
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            prompt_messages = [
+                sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = float(
+                check_equality(self.equality_checker, row["Answer"], extracted_answer)
+            )
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=row["Answer"],
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(html=html, score=score, convo=convo)
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
sglang/test/test_utils.py
CHANGED
@@ -1,9 +1,14 @@
 """Common utilities for testing and benchmarking"""
 
+import argparse
 import asyncio
+import multiprocessing
 import subprocess
+import threading
 import time
+import unittest
 from functools import partial
+from typing import Callable, List, Optional
 
 import numpy as np
 import requests

@@ -13,7 +18,7 @@ from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import get_exception_traceback
 
-
+DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
 
 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):

@@ -247,7 +252,7 @@ async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=
     return choices.index(answer)
 
 
-def add_common_other_args_and_parse(parser):
+def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=None)

@@ -286,7 +291,7 @@ def add_common_other_args_and_parse(parser):
     return args
 
 
-def add_common_sglang_args_and_parse(parser):
+def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)

@@ -296,7 +301,7 @@ def add_common_sglang_args_and_parse(parser):
     return args
 
 
-def select_sglang_backend(args):
+def select_sglang_backend(args: argparse.Namespace):
     if args.backend.startswith("srt"):
         if args.backend == "srt-no-parallel":
             global_config.enable_parallel_decoding = False

@@ -309,7 +314,7 @@ def select_sglang_backend(args):
     return backend
 
 
-def _get_call_generate(args):
+def _get_call_generate(args: argparse.Namespace):
     if args.backend == "lightllm":
         return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "vllm":

@@ -336,7 +341,7 @@ def _get_call_generate(args):
         raise ValueError(f"Invalid backend: {args.backend}")
 
 
-def _get_call_select(args):
+def _get_call_select(args: argparse.Namespace):
     if args.backend == "lightllm":
         return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "vllm":

@@ -359,7 +364,7 @@ def _get_call_select(args):
         raise ValueError(f"Invalid backend: {args.backend}")
 
 
-def get_call_generate(args):
+def get_call_generate(args: argparse.Namespace):
     call_generate = _get_call_generate(args)
 
     def func(*args, **kwargs):

@@ -372,7 +377,7 @@ def get_call_generate(args):
     return func
 
 
-def get_call_select(args):
+def get_call_select(args: argparse.Namespace):
     call_select = _get_call_select(args)
 
     def func(*args, **kwargs):

@@ -385,7 +390,16 @@ def get_call_select(args):
     return func
 
 
-def popen_launch_server(model, port, timeout, *args):
+def popen_launch_server(
+    model: str,
+    base_url: str,
+    timeout: float,
+    api_key: Optional[str] = None,
+    other_args: tuple = (),
+):
+    _, host, port = base_url.split(":")
+    host = host[2:]
+
     command = [
         "python3",
         "-m",

@@ -393,21 +407,88 @@ def popen_launch_server(model, port, timeout, *args):
         "--model-path",
         model,
         "--host",
-
+        host,
         "--port",
-
-        *args,
+        port,
+        *other_args,
     ]
+    if api_key:
+        command += ["--api-key", api_key]
+
     process = subprocess.Popen(command, stdout=None, stderr=None)
-    base_url = f"http://localhost:{port}/v1"
 
     start_time = time.time()
     while time.time() - start_time < timeout:
         try:
-
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {api_key}",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
             if response.status_code == 200:
                 return process
         except requests.RequestException:
             pass
         time.sleep(10)
     raise TimeoutError("Server failed to start within the timeout period.")
+
+
+def run_with_timeout(
+    func: Callable,
+    args: tuple = (),
+    kwargs: Optional[dict] = None,
+    timeout: float = None,
+):
+    """Run a function with timeout."""
+    ret_value = []
+
+    def _target_func():
+        ret_value.append(func(*args, **(kwargs or {})))
+
+    t = threading.Thread(target=_target_func)
+    t.start()
+    t.join(timeout=timeout)
+    if t.is_alive():
+        raise TimeoutError()
+
+    if not ret_value:
+        raise RuntimeError()
+
+    return ret_value[0]
+
+
+def run_unittest_files(files: List[str], timeout_per_file: float):
+    tic = time.time()
+    success = True
+
+    for filename in files:
+
+        def func():
+            print(f"\n\nRun {filename}\n\n")
+            ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
+
+        p = multiprocessing.Process(target=func)
+
+        def run_one_file():
+            p.start()
+            p.join()
+
+        try:
+            run_with_timeout(run_one_file, timeout=timeout_per_file)
+            if p.exitcode != 0:
+                success = False
+                break
+        except TimeoutError:
+            p.terminate()
+            time.sleep(5)
+            print(
+                "\nTimeout after {timeout_per_file} seconds when running {filename}\n"
+            )
+            return False
+
+    if success:
+        print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
+    else:
+        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
+
+    return 0 if success else -1