sglang 0.2.9.post1__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +114 -63
- sglang/check_env.py +1 -0
- sglang/lang/backend/runtime_endpoint.py +0 -11
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/radix_attention.py +22 -9
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +15 -11
- sglang/srt/managers/tokenizer_manager.py +28 -13
- sglang/srt/mem_cache/memory_pool.py +65 -24
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/model_runner.py +46 -17
- sglang/srt/models/deepseek_v2.py +198 -16
- sglang/srt/openai_api/adapter.py +120 -20
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/server.py +87 -78
- sglang/srt/server_args.py +8 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +94 -13
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/METADATA +29 -28
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/RECORD +33 -30
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/top_level.txt +0 -0
sglang/test/run_eval.py
CHANGED
@@ -10,7 +10,6 @@ import time
 
 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
-    download_dataset,
     make_report,
     set_ulimit,
 )
@@ -27,14 +26,26 @@ def run_eval(args):
     if args.eval_name == "mmlu":
         from sglang.test.simple_eval_mmlu import MMLUEval
 
-
-
-
-
-
-
-
-
+        filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
+        eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
+    elif args.eval_name == "math":
+        from sglang.test.simple_eval_math import MathEval
+
+        equality_checker = ChatCompletionSampler(model="gpt-4-turbo")
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
+        )
+        eval_obj = MathEval(
+            filename, equality_checker, args.num_examples, args.num_threads
+        )
+    elif args.eval_name == "gpqa":
+        from sglang.test.simple_eval_gpqa import GPQAEval
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
+        )
+        eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
     elif args.eval_name == "humaneval":
         from sglang.test.simple_eval_humaneval import HumanEval
 
@@ -97,7 +108,7 @@ if __name__ == "__main__":
     )
     parser.add_argument("--eval-name", type=str, default="mmlu")
    parser.add_argument("--num-examples", type=int)
-    parser.add_argument("--num-threads", type=int, default=
+    parser.add_argument("--num-threads", type=int, default=512)
     set_ulimit()
     args = parser.parse_args()
 
sglang/test/runners.py
ADDED
@@ -0,0 +1,237 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+import multiprocessing
+from dataclasses import dataclass
+from typing import List, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from sglang.srt.server import Runtime
+
+DEFAULT_PROMPTS = [
+    "The capital of France is",
+    "The capital of the United Kindom is",
+    "Today is a sunny day and I like",
+]
+
+NUM_TOP_LOGPROBS = 5
+
+
+def is_embedding_model(model_path):
+    # FIXME incomplete list
+    if "e5-mistral-7b-instruct" in model_path.lower():
+        return True
+    return False
+
+
+def get_dtype_str(torch_dtype):
+    if torch_dtype is torch.float16:
+        return "float16"
+    else:
+        raise NotImplementedError()
+
+
+@dataclass
+class ModelOutput:
+    output_strs: str = None
+    top_input_logprobs: torch.Tensor = None
+    top_output_logprobs: torch.Tensor = None
+    embed_logits: torch.Tensor = None
+
+
+class HFRunner:
+    def __init__(
+        self,
+        model_path,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.in_queue = multiprocessing.Queue()
+        self.out_queue = multiprocessing.Queue()
+
+        self.model_proc = multiprocessing.Process(
+            target=self.start_model_process,
+            args=(
+                self.in_queue,
+                self.out_queue,
+                model_path,
+                torch_dtype,
+                is_embedding_model,
+            ),
+        )
+        self.model_proc.start()
+
+    def start_model_process(
+        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+    ):
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+        )
+
+        self.is_embedding_model = (
+            is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if not self.is_embedding_model:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+            ).cuda()
+        else:
+            from sentence_transformers import SentenceTransformer
+
+            self.model = SentenceTransformer(
+                model_path,
+                device="cpu",
+            ).to(dtype=torch_dtype)
+
+        while True:
+            prompts, max_new_tokens = in_queue.get()
+            if prompts is not None:
+                if not self.is_embedding_model:
+                    output_strs = []
+                    prefill_logprobs = []
+                    for p in prompts:
+                        if isinstance(p, str):
+                            input_ids = self.tokenizer.encode(
+                                p, return_tensors="pt"
+                            ).cuda()
+                        else:
+                            input_ids = torch.tensor([p], device="cuda")
+
+                        output_ids = self.model.generate(
+                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        )
+                        output_strs.append(self.tokenizer.decode(output_ids[0]))
+
+                        logits = self.model.forward(input_ids).logits[0]
+                        logprobs = F.log_softmax(
+                            logits, dim=-1, dtype=torch.float32
+                        ).tolist()
+                        # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
+                        # print("index", index_of_max)
+                        logprobs = [
+                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
+                            for token_logprobs in logprobs
+                        ]
+                        prefill_logprobs.append(logprobs)
+
+                    out_queue.put(
+                        ModelOutput(
+                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                        )
+                    )
+
+                else:
+                    assert isinstance(prompts, List[str])
+                    logits = self.model.encode(prompts).tolist()
+
+                    out_queue.put(ModelOutput(embed_logits=logits))
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        self.in_queue.put((prompts, max_new_tokens))
+        return self.out_queue.get()
+
+    def terminate(self):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+
+class SRTRunner:
+    def __init__(
+        self,
+        model_path,
+        tp_size=1,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.is_embedding_model = (
+            is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if self.is_embedding_model:
+            raise NotImplementedError()
+
+        self.runtime = Runtime(
+            model_path=model_path,
+            tp_size=tp_size,
+            dtype=get_dtype_str(torch_dtype),
+        )
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        # the return value contains logprobs from prefill
+        output_strs = []
+        top_input_logprobs = []
+        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+        for prompt in prompts:
+            response = self.runtime.generate(
+                prompt,
+                sampling_params=sampling_params,
+                return_logprob=True,
+                top_logprobs_num=NUM_TOP_LOGPROBS,
+            )
+            response = json.loads(response)
+            output_strs.append(response["text"])
+            top_input_logprobs.append(
+                [
+                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                    for x in response["meta_info"]["input_top_logprobs"][1:]
+                ]
+                + [
+                    [
+                        tup[0]
+                        for tup in response["meta_info"]["output_top_logprobs"][0][
+                            :NUM_TOP_LOGPROBS
+                        ]
+                    ]
+                ]
+            )
+            # print(response["meta_info"]["output_top_logprobs"][0])
+
+        return ModelOutput(
+            output_strs=output_strs, top_input_logprobs=top_input_logprobs
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.runtime.shutdown()
+        del self.runtime
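The two runners above expose the same process-isolated `forward()` interface, which is what makes HuggingFace-vs-SRT parity checks easy to write. A minimal sketch of such a check; the model path, token budget, and the idea of printing completions side by side are illustrative choices, not part of this release:

```python
import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

MODEL_PATH = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model, not from this diff

# Each runner loads the model in its own process; run them one after the other
# to avoid keeping two copies of the weights in GPU memory at once.
with HFRunner(MODEL_PATH, torch_dtype=torch.float16) as hf_runner:
    hf_out = hf_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)

with SRTRunner(MODEL_PATH, tp_size=1, torch_dtype=torch.float16) as srt_runner:
    srt_out = srt_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)

# Rough parity check on the greedy completions; top_input_logprobs can be
# compared position by position in the same way (both hold top-5 logprobs).
for prompt, hf_text, srt_text in zip(
    DEFAULT_PROMPTS, hf_out.output_strs, srt_out.output_strs
):
    print(prompt, "\n  hf :", hf_text, "\n  srt:", srt_text)
```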
sglang/test/simple_eval_common.py
CHANGED
@@ -7,7 +7,7 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from multiprocessing.pool import ThreadPool
-from typing import Any
+from typing import Any, Dict, List, Tuple
 
 import httpx
 import jinja2
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
 )
 
 
-Message =
-MessageList =
+Message = Dict[str, Any]  # keys role, content
+MessageList = List[Message]
 
 
 class SamplerBase:
@@ -45,9 +45,9 @@ class EvalResult:
     """
 
     score: float | None  # top-line metric
-    metrics:
-    htmls:
-    convos:
+    metrics: Dict[str, float] | None  # other metrics
+    htmls: List[str]  # strings of valid HTML
+    convos: List[MessageList]  # sampled conversations
 
 
 @dataclass
@@ -57,7 +57,7 @@ class SingleEvalResult:
     """
 
     score: float | None
-    metrics:
+    metrics: Dict[str, float] = field(default_factory=dict)
     html: str | None = None
     convo: MessageList | None = None  # sampled conversation
 
@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
 
 
 def aggregate_results(
-    single_eval_results:
-    default_stats:
-    name2stats:
+    single_eval_results: List[SingleEvalResult],
+    default_stats: Tuple[str] = ("mean", "std"),
+    name2stats: Dict[str, Tuple[str]] | None = None,
 ) -> EvalResult:
     """
     Aggregate results from multiple evaluations into a single EvalResult.
@@ -302,7 +302,7 @@ def aggregate_results(
     )
 
 
-def map_with_progress(f: callable, xs:
+def map_with_progress(f: callable, xs: List[Any], num_threads: int):
     """
     Apply f to each element of xs, using a ThreadPool, and show progress.
     """
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
     )
 
 
-def make_report_from_example_htmls(htmls:
+def make_report_from_example_htmls(htmls: List[str]):
     """
     Create a standalone HTML report from a list of example htmls
     """
sglang/test/simple_eval_gpqa.py
ADDED
@@ -0,0 +1,92 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+GPQA: A Graduate-Level Google-Proof Q&A Benchmark
+David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
+https://arxiv.org/abs/2311.12022
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN_MULTICHOICE,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    MessageList,
+    SamplerBase,
+    SingleEvalResult,
+    format_multichoice_question,
+)
+
+
+class GPQAEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        num_examples: int | None,
+        num_threads: int,
+        n_repeats: int = 1,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        rng = random.Random(0)
+        if num_examples:
+            assert n_repeats == 1, "n_repeats only supported for num_examples"
+            examples = rng.sample(examples, num_examples)
+        examples = examples * n_repeats
+        examples = [
+            example | {"permutation": rng.sample(range(4), 4)} for example in examples
+        ]
+        self.examples = examples
+        self.n_repeats = n_repeats
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            choices = [
+                row["Correct Answer"],
+                row["Incorrect Answer 1"],
+                row["Incorrect Answer 2"],
+                row["Incorrect Answer 3"],
+            ]
+            choices = [choices[i] for i in row["permutation"]]
+            correct_index = choices.index(row["Correct Answer"])
+            correct_answer = "ABCD"[correct_index]
+            choices_dict = dict(
+                A=choices[0],
+                B=choices[1],
+                C=choices[2],
+                D=choices[3],
+                Question=row["Question"],
+            )
+            prompt_messages = [
+                sampler._pack_message(
+                    content=format_multichoice_question(choices_dict), role="user"
+                )
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = 1.0 if extracted_answer == correct_answer else 0.0
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=correct_answer,
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={"chars": len(response_text)},
+            )
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
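GPQAEval follows the same pattern as the other simple-evals ports in this release: construct it with a CSV path and call it with any SamplerBase. A rough usage sketch; the example counts, output filename, and the ChatCompletionSampler model name are placeholders, and only the `model` constructor argument is confirmed by this diff:

```python
from sglang.test.simple_eval_common import ChatCompletionSampler, make_report
from sglang.test.simple_eval_gpqa import GPQAEval

filename = "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
eval_obj = GPQAEval(filename, num_examples=32, num_threads=16)

# The sampler is whatever chat-completions endpoint is being graded.
sampler = ChatCompletionSampler(model="gpt-4-turbo")

result = eval_obj(sampler)  # EvalResult with score, metrics, htmls, convos
print(result.score, result.metrics)

with open("gpqa_report.html", "w") as f:
    f.write(make_report(result))  # standalone HTML report
```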
sglang/test/simple_eval_humaneval.py
CHANGED
@@ -14,7 +14,7 @@ import re
 from collections import Counter, defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from io import BytesIO
-from typing import Any, Tuple
+from typing import Any, Dict, List, Tuple
 
 import blobfile as bf
 import tqdm
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
 
 
 def evaluate_functional_correctness(
-    sample:
-    completions:
+    sample: Dict[str, str],
+    completions: List[str],
     n_workers: int = 4,
     timeout: float = 3.0,
 ):
@@ -70,7 +70,7 @@ class HumanEval(Eval):
         num_examples: int | None,
         num_threads: int,
         num_samples_per_task: int = 5,
-        ks_passes:
+        ks_passes: List[int] = [1, 2, 5],
         timeout: int = 120,
     ):
         self.seed = 0
@@ -97,7 +97,7 @@ class HumanEval(Eval):
            ]  # remove signature
            return extracted_answer
 
-        def fn(sample:
+        def fn(sample: Dict[str, str]):
            prompt_messages = [
                sampler._pack_message(
                    role="user", content=instruction + sample["prompt"]
sglang/test/simple_eval_math.py
ADDED
@@ -0,0 +1,72 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Mathematical Problem Solving With the MATH Dataset
+Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2103.03874
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    check_equality,
+)
+
+QUERY_TEMPLATE = """
+Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
+
+{Question}
+
+Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
+""".strip()
+
+
+class MathEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        equality_checker: SamplerBase,
+        num_examples: int | None,
+        num_threads: int,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        if num_examples:
+            examples = random.Random(0).sample(examples, num_examples)
+        self.examples = examples
+        self.equality_checker = equality_checker
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            prompt_messages = [
+                sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = float(
+                check_equality(self.equality_checker, row["Answer"], extracted_answer)
+            )
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=row["Answer"],
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(html=html, score=score, convo=convo)
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
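Unlike the other eval ports, MathEval delegates scoring to a second sampler through check_equality, mirroring how run_eval.py wires it up with ChatCompletionSampler(model="gpt-4-turbo"). A hedged sketch of standalone use; the example counts and the "my-served-model" name are placeholders, not part of this diff:

```python
from sglang.test.simple_eval_common import ChatCompletionSampler
from sglang.test.simple_eval_math import MathEval

filename = "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"

# A second, stronger sampler acts as the judge for answer equivalence,
# exactly as run_eval.py wires it up for the "math" eval name.
equality_checker = ChatCompletionSampler(model="gpt-4-turbo")

eval_obj = MathEval(filename, equality_checker, num_examples=50, num_threads=16)

# The sampler under test; the model name here is a placeholder.
result = eval_obj(ChatCompletionSampler(model="my-served-model"))
print(f"MATH score: {result.score:.3f}")
```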