sglang 0.2.9.post1__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/test/run_eval.py CHANGED
@@ -10,7 +10,6 @@ import time
 
 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
-    download_dataset,
     make_report,
     set_ulimit,
 )
@@ -27,14 +26,26 @@ def run_eval(args):
     if args.eval_name == "mmlu":
         from sglang.test.simple_eval_mmlu import MMLUEval
 
-        dataset_path = "mmlu.csv"
-
-        if not os.path.exists(dataset_path):
-            download_dataset(
-                dataset_path,
-                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
-            )
-        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+        filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
+        eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
+    elif args.eval_name == "math":
+        from sglang.test.simple_eval_math import MathEval
+
+        equality_checker = ChatCompletionSampler(model="gpt-4-turbo")
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
+        )
+        eval_obj = MathEval(
+            filename, equality_checker, args.num_examples, args.num_threads
+        )
+    elif args.eval_name == "gpqa":
+        from sglang.test.simple_eval_gpqa import GPQAEval
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
+        )
+        eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
     elif args.eval_name == "humaneval":
         from sglang.test.simple_eval_humaneval import HumanEval
 
@@ -97,7 +108,7 @@ if __name__ == "__main__":
     )
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
-    parser.add_argument("--num-threads", type=int, default=64)
+    parser.add_argument("--num-threads", type=int, default=512)
     set_ulimit()
     args = parser.parse_args()
 
sglang/test/runners.py ADDED
@@ -0,0 +1,237 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+import multiprocessing
+from dataclasses import dataclass
+from typing import List, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from sglang.srt.server import Runtime
+
+DEFAULT_PROMPTS = [
+    "The capital of France is",
+    "The capital of the United Kindom is",
+    "Today is a sunny day and I like",
+]
+
+NUM_TOP_LOGPROBS = 5
+
+
+def is_embedding_model(model_path):
+    # FIXME incomplete list
+    if "e5-mistral-7b-instruct" in model_path.lower():
+        return True
+    return False
+
+
+def get_dtype_str(torch_dtype):
+    if torch_dtype is torch.float16:
+        return "float16"
+    else:
+        raise NotImplementedError()
+
+
+@dataclass
+class ModelOutput:
+    output_strs: str = None
+    top_input_logprobs: torch.Tensor = None
+    top_output_logprobs: torch.Tensor = None
+    embed_logits: torch.Tensor = None
+
+
+class HFRunner:
+    def __init__(
+        self,
+        model_path,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.in_queue = multiprocessing.Queue()
+        self.out_queue = multiprocessing.Queue()
+
+        self.model_proc = multiprocessing.Process(
+            target=self.start_model_process,
+            args=(
+                self.in_queue,
+                self.out_queue,
+                model_path,
+                torch_dtype,
+                is_embedding_model,
+            ),
+        )
+        self.model_proc.start()
+
+    def start_model_process(
+        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+    ):
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+        )
+
+        self.is_embedding_model = (
+            is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if not self.is_embedding_model:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+            ).cuda()
+        else:
+            from sentence_transformers import SentenceTransformer
+
+            self.model = SentenceTransformer(
+                model_path,
+                device="cpu",
+            ).to(dtype=torch_dtype)
+
+        while True:
+            prompts, max_new_tokens = in_queue.get()
+            if prompts is not None:
+                if not self.is_embedding_model:
+                    output_strs = []
+                    prefill_logprobs = []
+                    for p in prompts:
+                        if isinstance(p, str):
+                            input_ids = self.tokenizer.encode(
+                                p, return_tensors="pt"
+                            ).cuda()
+                        else:
+                            input_ids = torch.tensor([p], device="cuda")
+
+                        output_ids = self.model.generate(
+                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        )
+                        output_strs.append(self.tokenizer.decode(output_ids[0]))
+
+                        logits = self.model.forward(input_ids).logits[0]
+                        logprobs = F.log_softmax(
+                            logits, dim=-1, dtype=torch.float32
+                        ).tolist()
+                        # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
+                        # print("index", index_of_max)
+                        logprobs = [
+                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
+                            for token_logprobs in logprobs
+                        ]
+                        prefill_logprobs.append(logprobs)
+
+                    out_queue.put(
+                        ModelOutput(
+                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                        )
+                    )
+
+                else:
+                    assert isinstance(prompts, List[str])
+                    logits = self.model.encode(prompts).tolist()
+
+                    out_queue.put(ModelOutput(embed_logits=logits))
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        self.in_queue.put((prompts, max_new_tokens))
+        return self.out_queue.get()
+
+    def terminate(self):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.model_proc.terminate()
+        self.in_queue = self.out_queue = None
+
+
+class SRTRunner:
+    def __init__(
+        self,
+        model_path,
+        tp_size=1,
+        torch_dtype=torch.float16,
+        is_embedding_model=None,
+    ):
+        self.is_embedding_model = (
+            is_embedding_model(model_path)
+            if is_embedding_model is None
+            else is_embedding_model
+        )
+        if self.is_embedding_model:
+            raise NotImplementedError()
+
+        self.runtime = Runtime(
+            model_path=model_path,
+            tp_size=tp_size,
+            dtype=get_dtype_str(torch_dtype),
+        )
+
+    def forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=64,
+    ):
+        # the return value contains logprobs from prefill
+        output_strs = []
+        top_input_logprobs = []
+        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+        for prompt in prompts:
+            response = self.runtime.generate(
+                prompt,
+                sampling_params=sampling_params,
+                return_logprob=True,
+                top_logprobs_num=NUM_TOP_LOGPROBS,
+            )
+            response = json.loads(response)
+            output_strs.append(response["text"])
+            top_input_logprobs.append(
+                [
+                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                    for x in response["meta_info"]["input_top_logprobs"][1:]
+                ]
+                + [
+                    [
+                        tup[0]
+                        for tup in response["meta_info"]["output_top_logprobs"][0][
+                            :NUM_TOP_LOGPROBS
+                        ]
+                    ]
+                ]
+            )
+            # print(response["meta_info"]["output_top_logprobs"][0])
+
+        return ModelOutput(
+            output_strs=output_strs, top_input_logprobs=top_input_logprobs
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.runtime.shutdown()
+        del self.runtime
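
The two runners are meant to be driven from tests: HFRunner generates with HuggingFace transformers in a spawned process, while SRTRunner wraps an in-process sglang Runtime, and both return a ModelOutput whose output_strs and top_input_logprobs can be compared. A minimal usage sketch, not part of the diff; the model path is a placeholder and a CUDA GPU is assumed:

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder: any HF causal LM

# Pass is_embedding_model=False explicitly to select the generation path.
with HFRunner(model_path, torch_dtype=torch.float16, is_embedding_model=False) as hf:
    hf_out = hf.forward(DEFAULT_PROMPTS, max_new_tokens=32)

with SRTRunner(model_path, tp_size=1, torch_dtype=torch.float16, is_embedding_model=False) as srt:
    srt_out = srt.forward(DEFAULT_PROMPTS, max_new_tokens=32)

# output_strs holds the decoded completions; top_input_logprobs holds the top-5
# prefill logprobs per token, so a test can compare backends within a tolerance.
print(hf_out.output_strs)
print(srt_out.output_strs)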
sglang/test/simple_eval_common.py CHANGED
@@ -7,7 +7,7 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from multiprocessing.pool import ThreadPool
-from typing import Any
+from typing import Any, Dict, List, Tuple
 
 import httpx
 import jinja2
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
 )
 
 
-Message = dict[str, Any]  # keys role, content
-MessageList = list[Message]
+Message = Dict[str, Any]  # keys role, content
+MessageList = List[Message]
 
 
 class SamplerBase:
@@ -45,9 +45,9 @@ class EvalResult:
     """
 
     score: float | None  # top-line metric
-    metrics: dict[str, float] | None  # other metrics
-    htmls: list[str]  # strings of valid HTML
-    convos: list[MessageList]  # sampled conversations
+    metrics: Dict[str, float] | None  # other metrics
+    htmls: List[str]  # strings of valid HTML
+    convos: List[MessageList]  # sampled conversations
 
 
 @dataclass
@@ -57,7 +57,7 @@ class SingleEvalResult:
     """
 
     score: float | None
-    metrics: dict[str, float] = field(default_factory=dict)
+    metrics: Dict[str, float] = field(default_factory=dict)
     html: str | None = None
     convo: MessageList | None = None  # sampled conversation
 
@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
 
 
 def aggregate_results(
-    single_eval_results: list[SingleEvalResult],
-    default_stats: tuple[str] = ("mean", "std"),
-    name2stats: dict[str, tuple[str]] | None = None,
+    single_eval_results: List[SingleEvalResult],
+    default_stats: Tuple[str] = ("mean", "std"),
+    name2stats: Dict[str, Tuple[str]] | None = None,
 ) -> EvalResult:
     """
     Aggregate results from multiple evaluations into a single EvalResult.
@@ -302,7 +302,7 @@ def aggregate_results(
     )
 
 
-def map_with_progress(f: callable, xs: list[Any], num_threads: int):
+def map_with_progress(f: callable, xs: List[Any], num_threads: int):
     """
     Apply f to each element of xs, using a ThreadPool, and show progress.
     """
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
     )
 
 
-def make_report_from_example_htmls(htmls: list[str]):
+def make_report_from_example_htmls(htmls: List[str]):
     """
     Create a standalone HTML report from a list of example htmls
    """
sglang/test/simple_eval_gpqa.py ADDED
@@ -0,0 +1,92 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+GPQA: A Graduate-Level Google-Proof Q&A Benchmark
+David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
+https://arxiv.org/abs/2311.12022
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN_MULTICHOICE,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    MessageList,
+    SamplerBase,
+    SingleEvalResult,
+    format_multichoice_question,
+)
+
+
+class GPQAEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        num_examples: int | None,
+        num_threads: int,
+        n_repeats: int = 1,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        rng = random.Random(0)
+        if num_examples:
+            assert n_repeats == 1, "n_repeats only supported for num_examples"
+            examples = rng.sample(examples, num_examples)
+        examples = examples * n_repeats
+        examples = [
+            example | {"permutation": rng.sample(range(4), 4)} for example in examples
+        ]
+        self.examples = examples
+        self.n_repeats = n_repeats
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            choices = [
+                row["Correct Answer"],
+                row["Incorrect Answer 1"],
+                row["Incorrect Answer 2"],
+                row["Incorrect Answer 3"],
+            ]
+            choices = [choices[i] for i in row["permutation"]]
+            correct_index = choices.index(row["Correct Answer"])
+            correct_answer = "ABCD"[correct_index]
+            choices_dict = dict(
+                A=choices[0],
+                B=choices[1],
+                C=choices[2],
+                D=choices[3],
+                Question=row["Question"],
+            )
+            prompt_messages = [
+                sampler._pack_message(
+                    content=format_multichoice_question(choices_dict), role="user"
+                )
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = 1.0 if extracted_answer == correct_answer else 0.0
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=correct_answer,
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={"chars": len(response_text)},
+            )
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
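
A sketch of how the new GPQA eval is driven, mirroring the run_eval.py changes above. The sampler construction only echoes the ChatCompletionSampler(model="gpt-4-turbo") call visible elsewhere in this diff; run_eval.py points the sampler under test at the launched server with arguments not shown here:

from sglang.test.simple_eval_common import ChatCompletionSampler, make_report
from sglang.test.simple_eval_gpqa import GPQAEval

filename = "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
eval_obj = GPQAEval(filename, num_examples=8, num_threads=32)

sampler = ChatCompletionSampler(model="gpt-4-turbo")  # stand-in for the model under test
result = eval_obj(sampler)         # EvalResult: score, metrics, htmls, convos
report_html = make_report(result)  # standalone HTML report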
sglang/test/simple_eval_humaneval.py CHANGED
@@ -14,7 +14,7 @@ import re
 from collections import Counter, defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from io import BytesIO
-from typing import Any, Tuple
+from typing import Any, Dict, List, Tuple
 
 import blobfile as bf
 import tqdm
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
 
 
 def evaluate_functional_correctness(
-    sample: dict[str, str],
-    completions: list[str],
+    sample: Dict[str, str],
+    completions: List[str],
     n_workers: int = 4,
     timeout: float = 3.0,
 ):
@@ -70,7 +70,7 @@ class HumanEval(Eval):
         num_examples: int | None,
         num_threads: int,
         num_samples_per_task: int = 5,
-        ks_passes: list[int] = [1, 2, 5],
+        ks_passes: List[int] = [1, 2, 5],
         timeout: int = 120,
     ):
         self.seed = 0
@@ -97,7 +97,7 @@ class HumanEval(Eval):
             ]  # remove signature
             return extracted_answer
 
-        def fn(sample: dict[str, str]):
+        def fn(sample: Dict[str, str]):
             prompt_messages = [
                 sampler._pack_message(
                     role="user", content=instruction + sample["prompt"]
sglang/test/simple_eval_math.py ADDED
@@ -0,0 +1,72 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Mathematical Problem Solving With the MATH Dataset
+Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2103.03874
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    check_equality,
+)
+
+QUERY_TEMPLATE = """
+Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
+
+{Question}
+
+Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
+""".strip()
+
+
+class MathEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        equality_checker: SamplerBase,
+        num_examples: int | None,
+        num_threads: int,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        if num_examples:
+            examples = random.Random(0).sample(examples, num_examples)
+        self.examples = examples
+        self.equality_checker = equality_checker
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            prompt_messages = [
+                sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = float(
+                check_equality(self.equality_checker, row["Answer"], extracted_answer)
+            )
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=row["Answer"],
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(html=html, score=score, convo=convo)
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
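
MathEval differs from the other evals in that grading is delegated to a second sampler: run_eval.py above constructs a gpt-4-turbo ChatCompletionSampler as the equality checker and passes it alongside the dataset URL. A minimal wiring sketch, not part of the diff; the sampler under test is a stand-in:

from sglang.test.simple_eval_common import ChatCompletionSampler
from sglang.test.simple_eval_math import MathEval

filename = "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
equality_checker = ChatCompletionSampler(model="gpt-4-turbo")  # judges answer equivalence

eval_obj = MathEval(filename, equality_checker, num_examples=16, num_threads=32)
sampler = ChatCompletionSampler(model="gpt-4-turbo")  # stand-in for the model under test
result = eval_obj(sampler)
print(result.score)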