sglang 0.2.8__py3-none-any.whl → 0.2.9.post1__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
sglang/test/simple_eval_common.py
@@ -0,0 +1,467 @@
+ # Adapted from https://github.com/openai/simple-evals/
+
+ import base64
+ import os
+ import resource
+ import time
+ from collections import defaultdict
+ from dataclasses import dataclass, field
+ from multiprocessing.pool import ThreadPool
+ from typing import Any
+
+ import httpx
+ import jinja2
+ import numpy as np
+ import openai
+ import requests
+ from openai import OpenAI
+ from tqdm import tqdm
+
+ OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
+ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
+     "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
+     + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
+ )
+
+
+ Message = dict[str, Any]  # keys: role, content
+ MessageList = list[Message]
+
+
+ class SamplerBase:
+     """
+     Base class for defining a sampling model, which can be evaluated
+     or used as part of the grading process.
+     """
+
+     def __call__(self, message_list: MessageList) -> str:
+         raise NotImplementedError()
+
+
+ @dataclass
+ class EvalResult:
+     """
+     Result of running an evaluation (usually consisting of many samples)
+     """
+
+     score: float | None  # top-line metric
+     metrics: dict[str, float] | None  # other metrics
+     htmls: list[str]  # strings of valid HTML
+     convos: list[MessageList]  # sampled conversations
+
+
+ @dataclass
+ class SingleEvalResult:
+     """
+     Result of evaluating a single sample
+     """
+
+     score: float | None
+     metrics: dict[str, float] = field(default_factory=dict)
+     html: str | None = None
+     convo: MessageList | None = None  # sampled conversation
+
+
+ class Eval:
+     """
+     Base class for defining an evaluation.
+     """
+
+     def __call__(self, sampler: SamplerBase) -> EvalResult:
+         raise NotImplementedError()
+
+
+ class LargerHttpxClient(httpx.Client):
+     def __init__(self):
+         timeout_config = httpx.Timeout(3600)
+         limits = httpx.Limits(
+             max_keepalive_connections=3600,
+             max_connections=3600,
+         )
+         super().__init__(timeout=timeout_config, limits=limits)
+
+
+ class ChatCompletionSampler(SamplerBase):
+     """
+     Sample from OpenAI's chat completion API
+     """
+
+     def __init__(
+         self,
+         base_url: str | None = None,
+         model: str | None = None,
+         system_message: str | None = None,
+         temperature: float = 0.0,
+         max_tokens: int = 2048,
+     ):
+         self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
+
+         if model is None:
+             model = self.client.models.list().data[0].id
+
+         self.model = model
+         self.system_message = system_message
+         self.temperature = temperature
+         self.max_tokens = max_tokens
+         self.image_format = "url"
+
+     def _handle_image(
+         self,
+         image: str,
+         encoding: str = "base64",
+         format: str = "png",
+         fovea: int = 768,
+     ):
+         new_image = {
+             "type": "image_url",
+             "image_url": {
+                 "url": f"data:image/{format};{encoding},{image}",
+             },
+         }
+         return new_image
+
+     def _handle_text(self, text: str):
+         return {"type": "text", "text": text}
+
+     def _pack_message(self, role: str, content: Any):
+         return {"role": str(role), "content": content}
+
+     def __call__(self, message_list: MessageList) -> str:
+         if self.system_message:
+             message_list = [
+                 self._pack_message("system", self.system_message)
+             ] + message_list
+         trial = 0
+         while True:
+             try:
+                 response = self.client.chat.completions.create(
+                     model=self.model,
+                     messages=message_list,
+                     temperature=self.temperature,
+                     max_tokens=self.max_tokens,
+                 )
+                 return response.choices[0].message.content
+             # NOTE: BadRequestError is triggered once for MMMU; return an empty
+             # response instead of failing if you are rerunning MMMU.
+             except openai.BadRequestError as e:
+                 print("Bad Request Error", e)
+                 return ""
+             except Exception as e:
+                 exception_backoff = 2**trial  # exponential backoff
+                 print(
+                     f"Rate limit exception, waiting {exception_backoff} sec before retry {trial}",
+                     e,
+                 )
+                 time.sleep(exception_backoff)
+                 trial += 1
+         # unknown errors are retried with exponential backoff
+
+
+ QUERY_TEMPLATE_MULTICHOICE = """
+ Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
+
+ {Question}
+
+ A) {A}
+ B) {B}
+ C) {C}
+ D) {D}
+ """.strip()
+
+ ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
+ ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
+
+
+ EQUALITY_TEMPLATE = r"""
+ Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
+
+ Examples:
+
+ Expression 1: $2x+3$
+ Expression 2: $3+2x$
+
+ Yes
+
+ Expression 1: 3/2
+ Expression 2: 1.5
+
+ Yes
+
+ Expression 1: $x^2+2x+1$
+ Expression 2: $y^2+2y+1$
+
+ No
+
+ Expression 1: $x^2+2x+1$
+ Expression 2: $(x+1)^2$
+
+ Yes
+
+ Expression 1: 3245/5
+ Expression 2: 649
+
+ No
+ (these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
+
+ Expression 1: 2/(-3)
+ Expression 2: -2/3
+
+ Yes
+ (trivial simplifications are allowed)
+
+ Expression 1: 72 degrees
+ Expression 2: 72
+
+ Yes
+ (give benefit of the doubt to units)
+
+ Expression 1: 64
+ Expression 2: 64 square feet
+
+ Yes
+ (give benefit of the doubt to units)
+
+ ---
+
+ YOUR TASK
+
+
+ Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
+
+ Expression 1: %(expression1)s
+ Expression 2: %(expression2)s
+ """.strip()
+
+
+ HTML_JINJA = """
+ <h3>Prompt conversation</h3>
+ {% for message in prompt_messages %}
+ {{ message_to_html(message) | safe }}
+ {% endfor %}
+ <h3>Sampled message</h3>
+ {{ message_to_html(next_message) | safe }}
+ <h3>Results</h3>
+ <p>Correct Answer: {{ correct_answer }}</p>
+ <p>Extracted Answer: {{ extracted_answer }}</p>
+ <p>Score: {{ score }}</p>
+ """
+
+
+ def format_multichoice_question(row):
+     return QUERY_TEMPLATE_MULTICHOICE.format(**row)
+
+
+ def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
+     prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
+     response = sampler([dict(content=prompt, role="user")])
+     return response.lower().strip() == "yes"
+
+
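A minimal usage sketch for check_equality (not part of the package diff): the equivalence judgment is delegated to whichever grader model the sampler wraps, so any SamplerBase implementation works; the endpoint URL below is an assumption.

    grader = ChatCompletionSampler(base_url="http://127.0.0.1:30000/v1")  # assumed endpoint
    is_same = check_equality(grader, "3/2", "1.5")  # "Yes" per the template examples, so True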
+ def _compute_stat(values: list, stat: str):
+     if stat == "mean":
+         return np.mean(values)
+     elif stat == "std":
+         return np.std(values)
+     elif stat == "min":
+         return np.min(values)
+     elif stat == "max":
+         return np.max(values)
+     else:
+         raise ValueError(f"Unknown {stat=}")
+
+
+ def aggregate_results(
+     single_eval_results: list[SingleEvalResult],
+     default_stats: tuple[str, ...] = ("mean", "std"),
+     name2stats: dict[str, tuple[str, ...]] | None = None,
+ ) -> EvalResult:
+     """
+     Aggregate results from multiple evaluations into a single EvalResult.
+     """
+     name2stats = name2stats or {}
+     name2values = defaultdict(list)
+     htmls = []
+     convos = []
+     for single_eval_result in single_eval_results:
+         for name, value in single_eval_result.metrics.items():
+             name2values[name].append(value)
+         if single_eval_result.score is not None:
+             name2values["score"].append(single_eval_result.score)
+         htmls.append(single_eval_result.html)
+         convos.append(single_eval_result.convo)
+     final_metrics = {}
+     for name, values in name2values.items():
+         stats = name2stats.get(name, default_stats)
+         for stat in stats:
+             key = name if stat == "mean" else f"{name}:{stat}"
+             final_metrics[key] = _compute_stat(values, stat)
+     return EvalResult(
+         score=final_metrics.pop("score", None),
+         metrics=final_metrics,
+         htmls=htmls,
+         convos=convos,
+     )
+
+
+ def map_with_progress(f: callable, xs: list[Any], num_threads: int):
+     """
+     Apply f to each element of xs, using a ThreadPool, and show progress.
+     """
+     if os.getenv("debug"):
+         return list(map(f, tqdm(xs, total=len(xs))))
+     else:
+         with ThreadPool(min(num_threads, len(xs))) as pool:
+             return list(tqdm(pool.imap(f, xs), total=len(xs)))
+
+
+ jinja_env = jinja2.Environment(
+     loader=jinja2.BaseLoader(),
+     undefined=jinja2.StrictUndefined,
+     autoescape=jinja2.select_autoescape(["html", "xml"]),
+ )
+ _message_template = """
+ <div class="message {{ role }}">
+     <div class="role">
+         {{ role }}
+         {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
+     </div>
+     <div class="content">
+         <pre>{{ content }}</pre>
+     </div>
+ </div>
+ """
+
+
+ def message_to_html(message: Message) -> str:
+     """
+     Generate HTML snippet (inside a <div>) for a message.
+     """
+     return jinja_env.from_string(_message_template).render(
+         role=message["role"],
+         content=message["content"],
+         variant=message.get("variant", None),
+     )
+
+
+ jinja_env.globals["message_to_html"] = message_to_html
+
+
+ _report_template = """<!DOCTYPE html>
+ <html>
+     <head>
+         <style>
+             .message {
+                 padding: 8px 16px;
+                 margin-bottom: 8px;
+                 border-radius: 4px;
+             }
+             .message.user {
+                 background-color: #B2DFDB;
+                 color: #00695C;
+             }
+             .message.assistant {
+                 background-color: #B39DDB;
+                 color: #4527A0;
+             }
+             .message.system {
+                 background-color: #EEEEEE;
+                 color: #212121;
+             }
+             .role {
+                 font-weight: bold;
+                 margin-bottom: 4px;
+             }
+             .variant {
+                 color: #795548;
+             }
+             table, th, td {
+                 border: 1px solid black;
+             }
+             pre {
+                 white-space: pre-wrap;
+             }
+         </style>
+     </head>
+     <body>
+         {% if metrics %}
+         <h1>Metrics</h1>
+         <table>
+             <tr>
+                 <th>Metric</th>
+                 <th>Value</th>
+             </tr>
+             <tr>
+                 <td><b>Score</b></td>
+                 <td>{{ score | float | round(3) }}</td>
+             </tr>
+             {% for name, value in metrics.items() %}
+             <tr>
+                 <td>{{ name }}</td>
+                 <td>{{ value }}</td>
+             </tr>
+             {% endfor %}
+         </table>
+         {% endif %}
+         <h1>Examples</h1>
+         {% for html in htmls %}
+         {{ html | safe }}
+         <hr>
+         {% endfor %}
+     </body>
+ </html>
+ """
+
+
+ def make_report(eval_result: EvalResult) -> str:
+     """
+     Create a standalone HTML report from an EvalResult.
+     """
+     return jinja_env.from_string(_report_template).render(
+         score=eval_result.score,
+         metrics=eval_result.metrics,
+         htmls=eval_result.htmls,
+     )
+
+
+ def make_report_from_example_htmls(htmls: list[str]):
+     """
+     Create a standalone HTML report from a list of example HTML strings.
+     """
+     return jinja_env.from_string(_report_template).render(
+         score=None, metrics={}, htmls=htmls
+     )
+
+
+ def download_dataset(path, url):
+     print(f"Downloading dataset {path} from {url}")
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()
+
+         total_size = int(response.headers.get("content-length", 0))
+         block_size = 8192
+
+         with open(path, "wb") as f, tqdm(
+             desc="Downloading",
+             total=total_size,
+             unit="iB",
+             unit_scale=True,
+             unit_divisor=1024,
+         ) as progress_bar:
+             for data in response.iter_content(block_size):
+                 size = f.write(data)
+                 progress_bar.update(size)
+
+         print(f"Dataset downloaded and saved to {path}")
+     except requests.RequestException as e:
+         raise Exception(f"Failed to download dataset: {e}")
+
+
+ def set_ulimit(target_soft_limit=65535):
+     resource_type = resource.RLIMIT_NOFILE
+     current_soft, current_hard = resource.getrlimit(resource_type)
+
+     if current_soft < target_soft_limit:
+         try:
+             resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+         except ValueError as e:
+             print(f"Failed to set RLIMIT_NOFILE: {e}")
sglang/test/simple_eval_humaneval.py
@@ -0,0 +1,139 @@
+ # Adapted from https://github.com/openai/simple-evals/
+
+ """
+ HumanEval: Evaluating Large Language Models Trained on Code
+ Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba
+ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
+ """
+
+ import json
+ import logging
+ import multiprocessing
+ import random
+ import re
+ from collections import Counter, defaultdict
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from io import BytesIO
+ from typing import Any, Tuple
+
+ import blobfile as bf
+ import tqdm
+
+ try:
+     from human_eval.data import HUMAN_EVAL, read_problems
+     from human_eval.evaluation import estimate_pass_at_k
+     from human_eval.execution import check_correctness  # , unsafe_execute
+ except (ImportError, ModuleNotFoundError):
+     print("\nPlease install human-eval from https://github.com/openai/human-eval.\n")
+     raise
+
+ from sglang.test import simple_eval_common as common
+ from sglang.test.simple_eval_common import (
+     HTML_JINJA,
+     Eval,
+     EvalResult,
+     SamplerBase,
+     SingleEvalResult,
+ )
+
+
+ def evaluate_functional_correctness(
+     sample: dict[str, str],
+     completions: list[str],
+     n_workers: int = 4,
+     timeout: float = 3.0,
+ ):
+     """
+     Evaluates the functional correctness of the generated completions for one
+     sample and returns a list of 0/1 pass flags, one per completion.
+     """
+     # Check the generated samples against test suites.
+     with ThreadPoolExecutor(max_workers=n_workers) as executor:
+         futures = []
+         for i, completion in enumerate(completions):
+             args = (sample, completion, timeout, i)
+             future = executor.submit(check_correctness, *args)
+             futures.append(future)
+         results = []
+         for future in as_completed(futures):
+             result = future.result()
+             results.append(result)
+     passed = [int(r["passed"]) for r in results]
+     return passed
+
+
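For reference, estimate_pass_at_k (used below) implements the unbiased estimator from the HumanEval paper: for n samples with c correct, pass@k = 1 - C(n-c, k) / C(n, k). A self-contained sketch of the same computation (a hypothetical helper, shown only to make the metric concrete, not part of the diff):

    from math import comb

    def pass_at_k(n: int, c: int, k: int) -> float:
        # probability that at least one of k completions, drawn without
        # replacement from n samples of which c pass, is correct
        if n - c < k:
            return 1.0  # every size-k draw must include a passing sample
        return 1.0 - comb(n - c, k) / comb(n, k)

    pass_at_k(5, 2, 1)  # 0.4: two of five samples pass, so pass@1 is 2/5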
+ class HumanEval(Eval):
+     def __init__(
+         self,
+         num_examples: int | None,
+         num_threads: int,
+         num_samples_per_task: int = 5,
+         ks_passes: list[int] = [1, 2, 5],
+         timeout: int = 120,
+     ):
+         self.seed = 0
+         self.examples = read_problems()
+         self.examples = list(self.examples.values())
+
+         self._num_examples = num_examples
+         if self._num_examples:
+             self.examples = random.Random(self.seed).sample(self.examples, num_examples)
+         self._num_samples_per_task = num_samples_per_task
+         self._ks_passes = ks_passes
+         self._timeout = timeout
+         self._num_threads = num_threads
+
+     def __call__(self, sampler: SamplerBase) -> EvalResult:
+         instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n"
+
+         def find_code(completion):
+             pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
+             matches = pattern.findall(completion)
+             extracted_answer = matches[0] if len(matches) >= 1 else completion
+             extracted_answer = extracted_answer[
+                 extracted_answer.find(":\n    ") + 2 :
+             ]  # remove signature
+             return extracted_answer
+
+         def fn(sample: dict[str, str]):
+             prompt_messages = [
+                 sampler._pack_message(
+                     role="user", content=instruction + sample["prompt"]
+                 )
+             ]
+             completions = [
+                 find_code(sampler(prompt_messages))
+                 for _ in range(self._num_samples_per_task)
+             ]
+             results = evaluate_functional_correctness(sample, completions)
+             total = len(results)
+             correct = sum(results)
+             score = sum(results) / len(results)
+             html = common.jinja_env.from_string(HTML_JINJA).render(
+                 prompt_messages=prompt_messages,
+                 next_message=dict(content=completions[0], role="assistant"),
+                 score=score,
+                 correct_answer=[1] * len(results),
+                 extracted_answer=results,
+             )
+             convo = prompt_messages + [
+                 dict(content=completion, role="assistant") for completion in completions
+             ]
+             return SingleEvalResult(
+                 html=html,
+                 score=score,
+                 convo=convo,
+                 metrics={
+                     f"pass@{k}": estimate_pass_at_k([total], [correct], k)
+                     # aggregated later, so no need to call .mean() here
+                     for k in self._ks_passes
+                     if total >= k
+                 },
+             )
+
+         results = common.map_with_progress(
+             fn, self.examples, num_threads=self._num_threads
+         )
+         return common.aggregate_results(results)
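A usage sketch for the evaluator (not part of the package diff; the endpoint URL is an assumption, and human-eval must be installed):

    from sglang.test.simple_eval_common import ChatCompletionSampler
    from sglang.test.simple_eval_humaneval import HumanEval

    sampler = ChatCompletionSampler(base_url="http://127.0.0.1:30000/v1")  # assumed endpoint
    humaneval = HumanEval(num_examples=10, num_threads=4)  # 5 completions per task by default
    result = humaneval(sampler)
    print(result.score, result.metrics)  # mean score plus pass@k metrics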