sglang 0.2.8__py3-none-any.whl → 0.2.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +3 -5
- sglang/check_env.py +1 -0
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/model_executor/model_runner.py +6 -4
- sglang/srt/openai_api/adapter.py +7 -6
- sglang/srt/server.py +5 -13
- sglang/srt/server_args.py +11 -0
- sglang/srt/utils.py +20 -0
- sglang/test/run_eval.py +104 -0
- sglang/test/simple_eval_common.py +467 -0
- sglang/test/simple_eval_humaneval.py +139 -0
- sglang/test/simple_eval_mmlu.py +120 -0
- sglang/test/test_programs.py +4 -4
- sglang/test/test_utils.py +32 -0
- sglang/version.py +1 -1
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/RECORD +21 -19
- sglang/test/test_conversation.py +0 -46
- sglang/test/test_openai_protocol.py +0 -51
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.2.8.dist-info → sglang-0.2.9.post1.dist-info}/top_level.txt +0 -0
sglang/test/simple_eval_common.py (new file)
@@ -0,0 +1,467 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+import base64
+import os
+import resource
+import time
+from collections import defaultdict
+from dataclasses import dataclass, field
+from multiprocessing.pool import ThreadPool
+from typing import Any
+
+import httpx
+import jinja2
+import numpy as np
+import openai
+import requests
+from openai import OpenAI
+from tqdm import tqdm
+
+OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
+OPENAI_SYSTEM_MESSAGE_CHATGPT = (
+    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
+    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
+)
+
+
+Message = dict[str, Any]  # keys role, content
+MessageList = list[Message]
+
+
+class SamplerBase:
+    """
+    Base class for defining a sampling model, which can be evaluated,
+    or used as part of the grading process.
+    """
+
+    def __call__(self, message_list: MessageList) -> str:
+        raise NotImplementedError()
+
+
+@dataclass
+class EvalResult:
+    """
+    Result of running an evaluation (usually consisting of many samples)
+    """
+
+    score: float | None  # top-line metric
+    metrics: dict[str, float] | None  # other metrics
+    htmls: list[str]  # strings of valid HTML
+    convos: list[MessageList]  # sampled conversations
+
+
+@dataclass
+class SingleEvalResult:
+    """
+    Result of evaluating a single sample
+    """
+
+    score: float | None
+    metrics: dict[str, float] = field(default_factory=dict)
+    html: str | None = None
+    convo: MessageList | None = None  # sampled conversation
+
+
+class Eval:
+    """
+    Base class for defining an evaluation.
+    """
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        raise NotImplementedError()
+
+
+class LargerHttpxClient(httpx.Client):
+    def __init__(self):
+        timeout_config = httpx.Timeout(3600)
+        limits = httpx.Limits(
+            max_keepalive_connections=3600,
+            max_connections=3600,
+        )
+        super().__init__(timeout=timeout_config, limits=limits)
+
+
+class ChatCompletionSampler(SamplerBase):
+    """
+    Sample from OpenAI's chat completion API
+    """
+
+    def __init__(
+        self,
+        base_url: str = None,
+        model: str | None = None,
+        system_message: str | None = None,
+        temperature: float = 0.0,
+        max_tokens: int = 2048,
+    ):
+        self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
+
+        if model is None:
+            model = self.client.models.list().data[0].id
+
+        self.model = model
+        self.system_message = system_message
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.image_format = "url"
+
+    def _handle_image(
+        self,
+        image: str,
+        encoding: str = "base64",
+        format: str = "png",
+        fovea: int = 768,
+    ):
+        new_image = {
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:image/{format};{encoding},{image}",
+            },
+        }
+        return new_image
+
+    def _handle_text(self, text: str):
+        return {"type": "text", "text": text}
+
+    def _pack_message(self, role: str, content: Any):
+        return {"role": str(role), "content": content}
+
+    def __call__(self, message_list: MessageList) -> str:
+        if self.system_message:
+            message_list = [
+                self._pack_message("system", self.system_message)
+            ] + message_list
+        trial = 0
+        while True:
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=message_list,
+                    temperature=self.temperature,
+                    max_tokens=self.max_tokens,
+                )
+                return response.choices[0].message.content
+            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU
+            except openai.BadRequestError as e:
+                print("Bad Request Error", e)
+                return ""
+            except Exception as e:
+                exception_backoff = 2**trial  # expontial back off
+                print(
+                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
+                    e,
+                )
+                time.sleep(exception_backoff)
+                trial += 1
+            # unknown error shall throw exception
+
+
+QUERY_TEMPLATE_MULTICHOICE = """
+Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
+
+{Question}
+
+A) {A}
+B) {B}
+C) {C}
+D) {D}
+""".strip()
+
+ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
+ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
+
+
+EQUALITY_TEMPLATE = r"""
+Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
+
+Examples:
+
+    Expression 1: $2x+3$
+    Expression 2: $3+2x$
+
+Yes
+
+    Expression 1: 3/2
+    Expression 2: 1.5
+
+Yes
+
+    Expression 1: $x^2+2x+1$
+    Expression 2: $y^2+2y+1$
+
+No
+
+    Expression 1: $x^2+2x+1$
+    Expression 2: $(x+1)^2$
+
+Yes
+
+    Expression 1: 3245/5
+    Expression 2: 649
+
+No
+(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
+
+    Expression 1: 2/(-3)
+    Expression 2: -2/3
+
+Yes
+(trivial simplifications are allowed)
+
+    Expression 1: 72 degrees
+    Expression 2: 72
+
+Yes
+(give benefit of the doubt to units)
+
+    Expression 1: 64
+    Expression 2: 64 square feet
+
+Yes
+(give benefit of the doubt to units)
+
+---
+
+YOUR TASK
+
+
+Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
+
+    Expression 1: %(expression1)s
+    Expression 2: %(expression2)s
+""".strip()
+
+
+HTML_JINJA = """
+<h3>Prompt conversation</h3>
+{% for message in prompt_messages %}
+{{ message_to_html(message) | safe }}
+{% endfor %}
+<h3>Sampled message</h3>
+{{ message_to_html(next_message) | safe }}
+<h3>Results</h3>
+<p>Correct Answer: {{ correct_answer }}</p>
+<p>Extracted Answer: {{ extracted_answer }}</p>
+<p>Score: {{ score }}</p>
+"""
+
+
+def format_multichoice_question(row):
+    return QUERY_TEMPLATE_MULTICHOICE.format(**row)
+
+
+def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
+    prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
+    response = sampler([dict(content=prompt, role="user")])
+    return response.lower().strip() == "yes"
+
+
+def _compute_stat(values: list, stat: str):
+    if stat == "mean":
+        return np.mean(values)
+    elif stat == "std":
+        return np.std(values)
+    elif stat == "min":
+        return np.min(values)
+    elif stat == "max":
+        return np.max(values)
+    else:
+        raise ValueError(f"Unknown {stat =}")
+
+
+def aggregate_results(
+    single_eval_results: list[SingleEvalResult],
+    default_stats: tuple[str] = ("mean", "std"),
+    name2stats: dict[str, tuple[str]] | None = None,
+) -> EvalResult:
+    """
+    Aggregate results from multiple evaluations into a single EvalResult.
+    """
+    name2stats = name2stats or {}
+    name2values = defaultdict(list)
+    htmls = []
+    convos = []
+    for single_eval_result in single_eval_results:
+        for name, value in single_eval_result.metrics.items():
+            name2values[name].append(value)
+        if single_eval_result.score is not None:
+            name2values["score"].append(single_eval_result.score)
+        htmls.append(single_eval_result.html)
+        convos.append(single_eval_result.convo)
+    final_metrics = {}
+    for name, values in name2values.items():
+        stats = name2stats.get(name, default_stats)
+        for stat in stats:
+            key = name if stat == "mean" else f"{name}:{stat}"
+            final_metrics[key] = _compute_stat(values, stat)
+    return EvalResult(
+        score=final_metrics.pop("score", None),
+        metrics=final_metrics,
+        htmls=htmls,
+        convos=convos,
+    )
+
+
+def map_with_progress(f: callable, xs: list[Any], num_threads: int):
+    """
+    Apply f to each element of xs, using a ThreadPool, and show progress.
+    """
+    if os.getenv("debug"):
+        return list(map(f, tqdm(xs, total=len(xs))))
+    else:
+        with ThreadPool(min(num_threads, len(xs))) as pool:
+            return list(tqdm(pool.imap(f, xs), total=len(xs)))
+
+
+jinja_env = jinja2.Environment(
+    loader=jinja2.BaseLoader(),
+    undefined=jinja2.StrictUndefined,
+    autoescape=jinja2.select_autoescape(["html", "xml"]),
+)
+_message_template = """
+<div class="message {{ role }}">
+    <div class="role">
+        {{ role }}
+        {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
+    </div>
+    <div class="content">
+        <pre>{{ content }}</pre>
+    </div>
+</div>
+"""
+
+
+def message_to_html(message: Message) -> str:
+    """
+    Generate HTML snippet (inside a <div>) for a message.
+    """
+    return jinja_env.from_string(_message_template).render(
+        role=message["role"],
+        content=message["content"],
+        variant=message.get("variant", None),
+    )
+
+
+jinja_env.globals["message_to_html"] = message_to_html
+
+
+_report_template = """<!DOCTYPE html>
+<html>
+    <head>
+        <style>
+            .message {
+                padding: 8px 16px;
+                margin-bottom: 8px;
+                border-radius: 4px;
+            }
+            .message.user {
+                background-color: #B2DFDB;
+                color: #00695C;
+            }
+            .message.assistant {
+                background-color: #B39DDB;
+                color: #4527A0;
+            }
+            .message.system {
+                background-color: #EEEEEE;
+                color: #212121;
+            }
+            .role {
+                font-weight: bold;
+                margin-bottom: 4px;
+            }
+            .variant {
+                color: #795548;
+            }
+            table, th, td {
+                border: 1px solid black;
+            }
+            pre {
+                white-space: pre-wrap;
+            }
+        </style>
+    </head>
+    <body>
+    {% if metrics %}
+    <h1>Metrics</h1>
+    <table>
+    <tr>
+        <th>Metric</th>
+        <th>Value</th>
+    </tr>
+    <tr>
+        <td><b>Score</b></td>
+        <td>{{ score | float | round(3) }}</td>
+    </tr>
+    {% for name, value in metrics.items() %}
+    <tr>
+        <td>{{ name }}</td>
+        <td>{{ value }}</td>
+    </tr>
+    {% endfor %}
+    </table>
+    {% endif %}
+    <h1>Examples</h1>
+    {% for html in htmls %}
+    {{ html | safe }}
+    <hr>
+    {% endfor %}
+    </body>
+</html>
+"""
+
+
+def make_report(eval_result: EvalResult) -> str:
+    """
+    Create a standalone HTML report from an EvalResult.
+    """
+    return jinja_env.from_string(_report_template).render(
+        score=eval_result.score,
+        metrics=eval_result.metrics,
+        htmls=eval_result.htmls,
+    )
+
+
+def make_report_from_example_htmls(htmls: list[str]):
+    """
+    Create a standalone HTML report from a list of example htmls
+    """
+    return jinja_env.from_string(_report_template).render(
+        score=None, metrics={}, htmls=htmls
+    )
+
+
+def download_dataset(path, url):
+    print(f"Downloading dataset {path} from {url}")
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+
+        total_size = int(response.headers.get("content-length", 0))
+        block_size = 8192
+
+        with open(path, "wb") as f, tqdm(
+            desc="Downloading",
+            total=total_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as progress_bar:
+            for data in response.iter_content(block_size):
+                size = f.write(data)
+                progress_bar.update(size)
+
+        print(f"Dataset downloaded and saved to {path}")
+    except requests.RequestException as e:
+        raise Exception(f"Failed to download dataset: {e}")
+
+
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            print(f"Fail to set RLIMIT_NOFILE: {e}")
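
For orientation, the snippet below is an illustrative sketch (not part of the package diff) of how the helpers added in simple_eval_common.py fit together: a ChatCompletionSampler pointed at an OpenAI-compatible endpoint, the multiple-choice prompt template, and aggregate_results. The endpoint URL and the example question are placeholders, and the OpenAI client also expects OPENAI_API_KEY to be set in the environment (a dummy value is typically fine for local endpoints).

    import re

    from sglang.test.simple_eval_common import (
        ANSWER_PATTERN_MULTICHOICE,
        ChatCompletionSampler,
        SingleEvalResult,
        aggregate_results,
        format_multichoice_question,
    )

    # Placeholder endpoint for an OpenAI-compatible server (e.g. a local sglang instance).
    sampler = ChatCompletionSampler(base_url="http://127.0.0.1:30000/v1", temperature=0.0)

    # Hand-written row in the {Question, A, B, C, D} shape expected by
    # QUERY_TEMPLATE_MULTICHOICE; "B" is the intended answer.
    row = {"Question": "What is 2 + 2?", "A": "3", "B": "4", "C": "5", "D": "22"}
    prompt_messages = [sampler._pack_message("user", format_multichoice_question(row))]
    response = sampler(prompt_messages)

    # Grade by extracting the final "Answer: X" line from the response.
    match = re.search(ANSWER_PATTERN_MULTICHOICE, response)
    score = float(match is not None and match.group(1).upper() == "B")

    result = aggregate_results([SingleEvalResult(score=score, convo=prompt_messages)])
    print(result.score, result.metrics)
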
sglang/test/simple_eval_humaneval.py (new file)
@@ -0,0 +1,139 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+HumanEval: Evaluating Large Language Models Trained on Code
+Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba
+https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
+"""
+
+import json
+import logging
+import multiprocessing
+import random
+import re
+from collections import Counter, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from io import BytesIO
+from typing import Any, Tuple
+
+import blobfile as bf
+import tqdm
+
+try:
+    from human_eval.data import HUMAN_EVAL, read_problems
+    from human_eval.evaluation import estimate_pass_at_k
+    from human_eval.execution import check_correctness  # , unsafe_execute
+except (ImportError, ModuleNotFoundError):
+    print("\nPlease install human-eval at https://github.com/openai/human-eval.\n")
+    raise
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+)
+
+
+def evaluate_functional_correctness(
+    sample: dict[str, str],
+    completions: list[str],
+    n_workers: int = 4,
+    timeout: float = 3.0,
+):
+    """
+    Evaluates the functional correctness of generated samples, and writes
+    results to f"{sample_file}_results.jsonl.gz"
+    """
+    import copy
+
+    # Check the generated samples against test suites.
+    with ThreadPoolExecutor(max_workers=n_workers) as executor:
+        futures = []
+        for i, completion in enumerate(completions):
+            args = (sample, completion, timeout, i)
+            future = executor.submit(check_correctness, *args)
+            futures.append(future)
+        results = []
+        for future in as_completed(futures):
+            result = future.result()
+            results.append(result)
+    passed = [int(r["passed"]) for r in results]
+    return passed
+
+
+class HumanEval(Eval):
+    def __init__(
+        self,
+        num_examples: int | None,
+        num_threads: int,
+        num_samples_per_task: int = 5,
+        ks_passes: list[int] = [1, 2, 5],
+        timeout: int = 120,
+    ):
+        self.seed = 0
+        self.examples = read_problems()
+        self.examples = list(self.examples.values())
+
+        self._num_examples = num_examples
+        if self._num_examples:
+            self.examples = random.Random(self.seed).sample(self.examples, num_examples)
+        self._num_samples_per_task = num_samples_per_task
+        self._ks_passes = ks_passes
+        self._timeout = timeout
+        self._num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n"
+
+        def find_code(completion):
+            pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
+            matches = pattern.findall(completion)
+            extracted_answer = matches[0] if len(matches) >= 1 else completion
+            extracted_answer = extracted_answer[
+                extracted_answer.find(":\n    ") + 2 :
+            ]  # remove signature
+            return extracted_answer
+
+        def fn(sample: dict[str, str]):
+            prompt_messages = [
+                sampler._pack_message(
+                    role="user", content=instruction + sample["prompt"]
+                )
+            ]
+            completions = [
+                find_code(sampler(prompt_messages))
+                for _ in range(self._num_samples_per_task)
+            ]
+            results = evaluate_functional_correctness(sample, completions)
+            total = len(results)
+            correct = sum(results)
+            score = sum(results) / len(results)
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=completions[0], role="assistant"),
+                score=score,
+                correct_answer=[1] * len(results),
+                extracted_answer=results,
+            )
+            convo = prompt_messages + [
+                dict(content=completion, role="assistant") for completion in completions
+            ]
+            return SingleEvalResult(
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={
+                    f"pass@{k}": estimate_pass_at_k([total], [correct], k)
+                    # this will be aggrated so no need of .mean()
+                    for k in self._ks_passes
+                    if total >= k
+                },
+            )
+
+        results = common.map_with_progress(
+            fn, self.examples, num_threads=self._num_threads
+        )
+        return common.aggregate_results(results)
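
Similarly, a minimal sketch (not part of the diff) of driving the new HumanEval eval end to end with a ChatCompletionSampler. The endpoint URL, output file name, and example counts are placeholders; human-eval must be installed separately, as the import guard above notes. The release also adds sglang/test/run_eval.py and sglang/test/simple_eval_mmlu.py, whose contents are not shown in this view.

    from sglang.test.simple_eval_common import ChatCompletionSampler, make_report
    from sglang.test.simple_eval_humaneval import HumanEval

    # Placeholder endpoint for an OpenAI-compatible server.
    sampler = ChatCompletionSampler(base_url="http://127.0.0.1:30000/v1")

    # 10 seeded-random tasks, 5 completions per task, reported as pass@1/2/5.
    humaneval = HumanEval(num_examples=10, num_threads=4)
    result = humaneval(sampler)

    print("score:", result.score)
    print("metrics:", result.metrics)

    # Write the standalone HTML report rendered from _report_template.
    with open("humaneval_report.html", "w") as f:
        f.write(make_report(result))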