sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +234 -74
  4. sglang/check_env.py +25 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -40
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/hf_transformers_utils.py +2 -2
  13. sglang/srt/layers/extend_attention.py +59 -7
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/radix_attention.py +24 -14
  16. sglang/srt/layers/token_attention.py +28 -2
  17. sglang/srt/managers/io_struct.py +9 -4
  18. sglang/srt/managers/schedule_batch.py +98 -323
  19. sglang/srt/managers/tokenizer_manager.py +34 -16
  20. sglang/srt/managers/tp_worker.py +20 -22
  21. sglang/srt/mem_cache/memory_pool.py +74 -38
  22. sglang/srt/model_config.py +11 -0
  23. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  24. sglang/srt/model_executor/forward_batch_info.py +256 -0
  25. sglang/srt/model_executor/model_runner.py +51 -26
  26. sglang/srt/models/chatglm.py +1 -1
  27. sglang/srt/models/commandr.py +1 -1
  28. sglang/srt/models/dbrx.py +1 -1
  29. sglang/srt/models/deepseek.py +1 -1
  30. sglang/srt/models/deepseek_v2.py +199 -17
  31. sglang/srt/models/gemma.py +1 -1
  32. sglang/srt/models/gemma2.py +1 -1
  33. sglang/srt/models/gpt_bigcode.py +1 -1
  34. sglang/srt/models/grok.py +1 -1
  35. sglang/srt/models/internlm2.py +1 -1
  36. sglang/srt/models/llama2.py +1 -1
  37. sglang/srt/models/llama_classification.py +1 -1
  38. sglang/srt/models/llava.py +1 -2
  39. sglang/srt/models/llavavid.py +1 -2
  40. sglang/srt/models/minicpm.py +1 -1
  41. sglang/srt/models/mixtral.py +1 -1
  42. sglang/srt/models/mixtral_quant.py +1 -1
  43. sglang/srt/models/qwen.py +1 -1
  44. sglang/srt/models/qwen2.py +1 -1
  45. sglang/srt/models/qwen2_moe.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/openai_api/adapter.py +151 -29
  48. sglang/srt/openai_api/protocol.py +7 -1
  49. sglang/srt/server.py +111 -84
  50. sglang/srt/server_args.py +12 -2
  51. sglang/srt/utils.py +25 -20
  52. sglang/test/run_eval.py +21 -10
  53. sglang/test/runners.py +237 -0
  54. sglang/test/simple_eval_common.py +12 -12
  55. sglang/test/simple_eval_gpqa.py +92 -0
  56. sglang/test/simple_eval_humaneval.py +5 -5
  57. sglang/test/simple_eval_math.py +72 -0
  58. sglang/test/test_utils.py +95 -14
  59. sglang/utils.py +15 -37
  60. sglang/version.py +1 -1
  61. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
  62. sglang-0.2.11.dist-info/RECORD +102 -0
  63. sglang-0.2.9.post1.dist-info/RECORD +0 -97
  64. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  65. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  66. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/test/runners.py ADDED
@@ -0,0 +1,237 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import json
+ import multiprocessing
+ from dataclasses import dataclass
+ from typing import List, Union
+
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from sglang.srt.server import Runtime
+
+ DEFAULT_PROMPTS = [
+     "The capital of France is",
+     "The capital of the United Kindom is",
+     "Today is a sunny day and I like",
+ ]
+
+ NUM_TOP_LOGPROBS = 5
+
+
+ def is_embedding_model(model_path):
+     # FIXME incomplete list
+     if "e5-mistral-7b-instruct" in model_path.lower():
+         return True
+     return False
+
+
+ def get_dtype_str(torch_dtype):
+     if torch_dtype is torch.float16:
+         return "float16"
+     else:
+         raise NotImplementedError()
+
+
+ @dataclass
+ class ModelOutput:
+     output_strs: str = None
+     top_input_logprobs: torch.Tensor = None
+     top_output_logprobs: torch.Tensor = None
+     embed_logits: torch.Tensor = None
+
+
+ class HFRunner:
+     def __init__(
+         self,
+         model_path,
+         torch_dtype=torch.float16,
+         is_embedding_model=None,
+     ):
+         self.in_queue = multiprocessing.Queue()
+         self.out_queue = multiprocessing.Queue()
+
+         self.model_proc = multiprocessing.Process(
+             target=self.start_model_process,
+             args=(
+                 self.in_queue,
+                 self.out_queue,
+                 model_path,
+                 torch_dtype,
+                 is_embedding_model,
+             ),
+         )
+         self.model_proc.start()
+
+     def start_model_process(
+         self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+     ):
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_path,
+             torch_dtype=torch_dtype,
+             trust_remote_code=True,
+         )
+
+         self.is_embedding_model = (
+             is_embedding_model(model_path)
+             if is_embedding_model is None
+             else is_embedding_model
+         )
+         if not self.is_embedding_model:
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_path,
+                 torch_dtype=torch_dtype,
+                 low_cpu_mem_usage=True,
+                 trust_remote_code=True,
+             ).cuda()
+         else:
+             from sentence_transformers import SentenceTransformer
+
+             self.model = SentenceTransformer(
+                 model_path,
+                 device="cpu",
+             ).to(dtype=torch_dtype)
+
+         while True:
+             prompts, max_new_tokens = in_queue.get()
+             if prompts is not None:
+                 if not self.is_embedding_model:
+                     output_strs = []
+                     prefill_logprobs = []
+                     for p in prompts:
+                         if isinstance(p, str):
+                             input_ids = self.tokenizer.encode(
+                                 p, return_tensors="pt"
+                             ).cuda()
+                         else:
+                             input_ids = torch.tensor([p], device="cuda")
+
+                         output_ids = self.model.generate(
+                             input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                         )
+                         output_strs.append(self.tokenizer.decode(output_ids[0]))
+
+                         logits = self.model.forward(input_ids).logits[0]
+                         logprobs = F.log_softmax(
+                             logits, dim=-1, dtype=torch.float32
+                         ).tolist()
+                         # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
+                         # print("index", index_of_max)
+                         logprobs = [
+                             sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
+                             for token_logprobs in logprobs
+                         ]
+                         prefill_logprobs.append(logprobs)
+
+                     out_queue.put(
+                         ModelOutput(
+                             output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                         )
+                     )
+
+                 else:
+                     assert isinstance(prompts, List[str])
+                     logits = self.model.encode(prompts).tolist()
+
+                     out_queue.put(ModelOutput(embed_logits=logits))
+
+     def forward(
+         self,
+         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+         max_new_tokens=64,
+     ):
+         self.in_queue.put((prompts, max_new_tokens))
+         return self.out_queue.get()
+
+     def terminate(self):
+         self.model_proc.terminate()
+         self.in_queue = self.out_queue = None
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.model_proc.terminate()
+         self.in_queue = self.out_queue = None
+
+
+ class SRTRunner:
+     def __init__(
+         self,
+         model_path,
+         tp_size=1,
+         torch_dtype=torch.float16,
+         is_embedding_model=None,
+     ):
+         self.is_embedding_model = (
+             is_embedding_model(model_path)
+             if is_embedding_model is None
+             else is_embedding_model
+         )
+         if self.is_embedding_model:
+             raise NotImplementedError()
+
+         self.runtime = Runtime(
+             model_path=model_path,
+             tp_size=tp_size,
+             dtype=get_dtype_str(torch_dtype),
+         )
+
+     def forward(
+         self,
+         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+         max_new_tokens=64,
+     ):
+         # the return value contains logprobs from prefill
+         output_strs = []
+         top_input_logprobs = []
+         sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+         for prompt in prompts:
+             response = self.runtime.generate(
+                 prompt,
+                 sampling_params=sampling_params,
+                 return_logprob=True,
+                 top_logprobs_num=NUM_TOP_LOGPROBS,
+             )
+             response = json.loads(response)
+             output_strs.append(response["text"])
+             top_input_logprobs.append(
+                 [
+                     [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                     for x in response["meta_info"]["input_top_logprobs"][1:]
+                 ]
+                 + [
+                     [
+                         tup[0]
+                         for tup in response["meta_info"]["output_top_logprobs"][0][
+                             :NUM_TOP_LOGPROBS
+                         ]
+                     ]
+                 ]
+             )
+             # print(response["meta_info"]["output_top_logprobs"][0])
+
+         return ModelOutput(
+             output_strs=output_strs, top_input_logprobs=top_input_logprobs
+         )
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.runtime.shutdown()
+         del self.runtime
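
The new HFRunner and SRTRunner expose the same forward() interface, so a correctness test can feed both the same prompts and compare greedy outputs and top-5 prefill logprobs. A minimal sketch of such a comparison (not part of the package; the model path is illustrative, and is_embedding_model=False is passed explicitly so the constructors skip the embedding-detection branch):

    import torch

    from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

    MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # illustrative checkpoint

    # Run the HuggingFace reference implementation.
    with HFRunner(MODEL_PATH, torch_dtype=torch.float16, is_embedding_model=False) as hf:
        hf_out = hf.forward(DEFAULT_PROMPTS, max_new_tokens=32)

    # Run the same prompts through the SGLang runtime.
    with SRTRunner(MODEL_PATH, tp_size=1, torch_dtype=torch.float16, is_embedding_model=False) as srt:
        srt_out = srt.forward(DEFAULT_PROMPTS, max_new_tokens=32)

    # Compare the top-5 prefill logprobs position by position.
    for i, (hf_lp, srt_lp) in enumerate(zip(hf_out.top_input_logprobs, srt_out.top_input_logprobs)):
        hf_t, srt_t = torch.tensor(hf_lp), torch.tensor(srt_lp)
        if hf_t.shape == srt_t.shape:
            print(f"prompt {i}: max abs logprob diff = {(hf_t - srt_t).abs().max():.4f}")
        else:
            print(f"prompt {i}: shape mismatch {hf_t.shape} vs {srt_t.shape}")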
sglang/test/simple_eval_common.py CHANGED
@@ -7,7 +7,7 @@ import time
  from collections import defaultdict
  from dataclasses import dataclass, field
  from multiprocessing.pool import ThreadPool
- from typing import Any
+ from typing import Any, Dict, List, Tuple

  import httpx
  import jinja2
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
  )


- Message = dict[str, Any]  # keys role, content
- MessageList = list[Message]
+ Message = Dict[str, Any]  # keys role, content
+ MessageList = List[Message]


  class SamplerBase:
@@ -45,9 +45,9 @@ class EvalResult:
      """

      score: float | None  # top-line metric
-     metrics: dict[str, float] | None  # other metrics
-     htmls: list[str]  # strings of valid HTML
-     convos: list[MessageList]  # sampled conversations
+     metrics: Dict[str, float] | None  # other metrics
+     htmls: List[str]  # strings of valid HTML
+     convos: List[MessageList]  # sampled conversations


  @dataclass
@@ -57,7 +57,7 @@ class SingleEvalResult:
      """

      score: float | None
-     metrics: dict[str, float] = field(default_factory=dict)
+     metrics: Dict[str, float] = field(default_factory=dict)
      html: str | None = None
      convo: MessageList | None = None  # sampled conversation

@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):


  def aggregate_results(
-     single_eval_results: list[SingleEvalResult],
-     default_stats: tuple[str] = ("mean", "std"),
-     name2stats: dict[str, tuple[str]] | None = None,
+     single_eval_results: List[SingleEvalResult],
+     default_stats: Tuple[str] = ("mean", "std"),
+     name2stats: Dict[str, Tuple[str]] | None = None,
  ) -> EvalResult:
      """
      Aggregate results from multiple evaluations into a single EvalResult.
@@ -302,7 +302,7 @@ def aggregate_results(
      )


- def map_with_progress(f: callable, xs: list[Any], num_threads: int):
+ def map_with_progress(f: callable, xs: List[Any], num_threads: int):
      """
      Apply f to each element of xs, using a ThreadPool, and show progress.
      """
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
      )


- def make_report_from_example_htmls(htmls: list[str]):
+ def make_report_from_example_htmls(htmls: List[str]):
      """
      Create a standalone HTML report from a list of example htmls
      """
sglang/test/simple_eval_gpqa.py ADDED
@@ -0,0 +1,92 @@
+ # Adapted from https://github.com/openai/simple-evals/
+
+ """
+ GPQA: A Graduate-Level Google-Proof Q&A Benchmark
+ David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
+ https://arxiv.org/abs/2311.12022
+ """
+
+ import random
+ import re
+
+ import pandas
+
+ from sglang.test import simple_eval_common as common
+ from sglang.test.simple_eval_common import (
+     ANSWER_PATTERN_MULTICHOICE,
+     HTML_JINJA,
+     Eval,
+     EvalResult,
+     MessageList,
+     SamplerBase,
+     SingleEvalResult,
+     format_multichoice_question,
+ )
+
+
+ class GPQAEval(Eval):
+     def __init__(
+         self,
+         filename: str,
+         num_examples: int | None,
+         num_threads: int,
+         n_repeats: int = 1,
+     ):
+         df = pandas.read_csv(filename)
+         examples = [row.to_dict() for _, row in df.iterrows()]
+         rng = random.Random(0)
+         if num_examples:
+             assert n_repeats == 1, "n_repeats only supported for num_examples"
+             examples = rng.sample(examples, num_examples)
+         examples = examples * n_repeats
+         examples = [
+             example | {"permutation": rng.sample(range(4), 4)} for example in examples
+         ]
+         self.examples = examples
+         self.n_repeats = n_repeats
+         self.num_threads = num_threads
+
+     def __call__(self, sampler: SamplerBase) -> EvalResult:
+         def fn(row: dict):
+             choices = [
+                 row["Correct Answer"],
+                 row["Incorrect Answer 1"],
+                 row["Incorrect Answer 2"],
+                 row["Incorrect Answer 3"],
+             ]
+             choices = [choices[i] for i in row["permutation"]]
+             correct_index = choices.index(row["Correct Answer"])
+             correct_answer = "ABCD"[correct_index]
+             choices_dict = dict(
+                 A=choices[0],
+                 B=choices[1],
+                 C=choices[2],
+                 D=choices[3],
+                 Question=row["Question"],
+             )
+             prompt_messages = [
+                 sampler._pack_message(
+                     content=format_multichoice_question(choices_dict), role="user"
+                 )
+             ]
+             response_text = sampler(prompt_messages)
+             match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+             extracted_answer = match.group(1) if match else None
+             score = 1.0 if extracted_answer == correct_answer else 0.0
+             html = common.jinja_env.from_string(HTML_JINJA).render(
+                 prompt_messages=prompt_messages,
+                 next_message=dict(content=response_text, role="assistant"),
+                 score=score,
+                 correct_answer=correct_answer,
+                 extracted_answer=extracted_answer,
+             )
+             convo = prompt_messages + [dict(content=response_text, role="assistant")]
+             return SingleEvalResult(
+                 html=html,
+                 score=score,
+                 convo=convo,
+                 metrics={"chars": len(response_text)},
+             )
+
+         results = common.map_with_progress(fn, self.examples, self.num_threads)
+         return common.aggregate_results(results)
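
GPQAEval only depends on two sampler hooks that are visible above: _pack_message and __call__. A minimal sketch of driving it with a stub sampler (EchoSampler and the CSV path are hypothetical; a real run would use the OpenAI-compatible sampler from simple_eval_common against a served model):

    from sglang.test.simple_eval_common import SamplerBase
    from sglang.test.simple_eval_gpqa import GPQAEval


    class EchoSampler(SamplerBase):
        # Hypothetical stand-in that always answers "A"; enough to exercise
        # the permutation, regex extraction, and scoring paths above.
        def _pack_message(self, content: str, role: str) -> dict:
            return {"role": role, "content": content}

        def __call__(self, message_list) -> str:
            return "Answer: A"


    eval_obj = GPQAEval(filename="gpqa_diamond.csv", num_examples=8, num_threads=4)
    result = eval_obj(EchoSampler())
    print(result.score, result.metrics)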
sglang/test/simple_eval_humaneval.py CHANGED
@@ -14,7 +14,7 @@ import re
  from collections import Counter, defaultdict
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from io import BytesIO
- from typing import Any, Tuple
+ from typing import Any, Dict, List, Tuple

  import blobfile as bf
  import tqdm
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (


  def evaluate_functional_correctness(
-     sample: dict[str, str],
-     completions: list[str],
+     sample: Dict[str, str],
+     completions: List[str],
      n_workers: int = 4,
      timeout: float = 3.0,
  ):
@@ -70,7 +70,7 @@ class HumanEval(Eval):
          num_examples: int | None,
          num_threads: int,
          num_samples_per_task: int = 5,
-         ks_passes: list[int] = [1, 2, 5],
+         ks_passes: List[int] = [1, 2, 5],
          timeout: int = 120,
      ):
          self.seed = 0
@@ -97,7 +97,7 @@ class HumanEval(Eval):
              ]  # remove signature
              return extracted_answer

-         def fn(sample: dict[str, str]):
+         def fn(sample: Dict[str, str]):
              prompt_messages = [
                  sampler._pack_message(
                      role="user", content=instruction + sample["prompt"]
sglang/test/simple_eval_math.py ADDED
@@ -0,0 +1,72 @@
+ # Adapted from https://github.com/openai/simple-evals/
+
+ """
+ Measuring Mathematical Problem Solving With the MATH Dataset
+ Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
+ https://arxiv.org/abs/2103.03874
+ """
+
+ import random
+ import re
+
+ import pandas
+
+ from sglang.test import simple_eval_common as common
+ from sglang.test.simple_eval_common import (
+     ANSWER_PATTERN,
+     HTML_JINJA,
+     Eval,
+     EvalResult,
+     SamplerBase,
+     SingleEvalResult,
+     check_equality,
+ )
+
+ QUERY_TEMPLATE = """
+ Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
+
+ {Question}
+
+ Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
+ """.strip()
+
+
+ class MathEval(Eval):
+     def __init__(
+         self,
+         filename: str,
+         equality_checker: SamplerBase,
+         num_examples: int | None,
+         num_threads: int,
+     ):
+         df = pandas.read_csv(filename)
+         examples = [row.to_dict() for _, row in df.iterrows()]
+         if num_examples:
+             examples = random.Random(0).sample(examples, num_examples)
+         self.examples = examples
+         self.equality_checker = equality_checker
+         self.num_threads = num_threads
+
+     def __call__(self, sampler: SamplerBase) -> EvalResult:
+         def fn(row: dict):
+             prompt_messages = [
+                 sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
+             ]
+             response_text = sampler(prompt_messages)
+             match = re.search(ANSWER_PATTERN, response_text)
+             extracted_answer = match.group(1) if match else None
+             score = float(
+                 check_equality(self.equality_checker, row["Answer"], extracted_answer)
+             )
+             html = common.jinja_env.from_string(HTML_JINJA).render(
+                 prompt_messages=prompt_messages,
+                 next_message=dict(content=response_text, role="assistant"),
+                 score=score,
+                 correct_answer=row["Answer"],
+                 extracted_answer=extracted_answer,
+             )
+             convo = prompt_messages + [dict(content=response_text, role="assistant")]
+             return SingleEvalResult(html=html, score=score, convo=convo)
+
+         results = common.map_with_progress(fn, self.examples, self.num_threads)
+         return common.aggregate_results(results)
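
Unlike the GPQA eval, MathEval takes two model wrappers: the sampler that solves the problems and an equality_checker that check_equality consults to judge whether the extracted answer matches the reference. A minimal sketch with hypothetical stub samplers and a hypothetical local CSV path (a real run would point both at served models):

    from sglang.test.simple_eval_common import SamplerBase
    from sglang.test.simple_eval_math import MathEval


    class ConstantSampler(SamplerBase):
        # Hypothetical stub that always returns the same reply.
        def __init__(self, reply: str):
            self.reply = reply

        def _pack_message(self, content: str, role: str) -> dict:
            return {"role": role, "content": content}

        def __call__(self, message_list) -> str:
            return self.reply


    solver = ConstantSampler("Answer: 42")
    judge = ConstantSampler("Yes")  # the equality checker's verdict, parsed by check_equality

    math_eval = MathEval(
        filename="math_test.csv",  # hypothetical local copy of the dataset
        equality_checker=judge,
        num_examples=8,
        num_threads=4,
    )
    print(math_eval(solver).score)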
sglang/test/test_utils.py CHANGED
@@ -1,9 +1,14 @@
  """Common utilities for testing and benchmarking"""

+ import argparse
  import asyncio
+ import multiprocessing
  import subprocess
+ import threading
  import time
+ import unittest
  from functools import partial
+ from typing import Callable, List, Optional

  import numpy as np
  import requests
@@ -13,7 +18,7 @@ from sglang.lang.backend.openai import OpenAI
  from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.utils import get_exception_traceback

- MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+ DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"


  def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
@@ -247,7 +252,7 @@ async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=
      return choices.index(answer)


- def add_common_other_args_and_parse(parser):
+ def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
      parser.add_argument("--parallel", type=int, default=64)
      parser.add_argument("--host", type=str, default="http://127.0.0.1")
      parser.add_argument("--port", type=int, default=None)
@@ -286,7 +291,7 @@ def add_common_other_args_and_parse(parser):
      return args


- def add_common_sglang_args_and_parse(parser):
+ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
      parser.add_argument("--parallel", type=int, default=64)
      parser.add_argument("--host", type=str, default="http://127.0.0.1")
      parser.add_argument("--port", type=int, default=30000)
@@ -296,7 +301,7 @@ def add_common_sglang_args_and_parse(parser):
      return args


- def select_sglang_backend(args):
+ def select_sglang_backend(args: argparse.Namespace):
      if args.backend.startswith("srt"):
          if args.backend == "srt-no-parallel":
              global_config.enable_parallel_decoding = False
@@ -309,7 +314,7 @@ def select_sglang_backend(args):
      return backend


- def _get_call_generate(args):
+ def _get_call_generate(args: argparse.Namespace):
      if args.backend == "lightllm":
          return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
      elif args.backend == "vllm":
@@ -336,7 +341,7 @@ def _get_call_generate(args):
          raise ValueError(f"Invalid backend: {args.backend}")


- def _get_call_select(args):
+ def _get_call_select(args: argparse.Namespace):
      if args.backend == "lightllm":
          return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
      elif args.backend == "vllm":
@@ -359,7 +364,7 @@ def _get_call_select(args):
          raise ValueError(f"Invalid backend: {args.backend}")


- def get_call_generate(args):
+ def get_call_generate(args: argparse.Namespace):
      call_generate = _get_call_generate(args)

      def func(*args, **kwargs):
@@ -372,7 +377,7 @@ def get_call_generate(args):
      return func


- def get_call_select(args):
+ def get_call_select(args: argparse.Namespace):
      call_select = _get_call_select(args)

      def func(*args, **kwargs):
@@ -385,7 +390,16 @@ def get_call_select(args):
      return func


- def popen_launch_server(model, port, timeout, *args):
+ def popen_launch_server(
+     model: str,
+     base_url: str,
+     timeout: float,
+     api_key: Optional[str] = None,
+     other_args: tuple = (),
+ ):
+     _, host, port = base_url.split(":")
+     host = host[2:]
+
      command = [
          "python3",
          "-m",
@@ -393,21 +407,88 @@ def popen_launch_server(model, port, timeout, *args):
          "--model-path",
          model,
          "--host",
-         "localhost",
+         host,
          "--port",
-         str(port),
-         *args,
+         port,
+         *other_args,
      ]
+     if api_key:
+         command += ["--api-key", api_key]
+
      process = subprocess.Popen(command, stdout=None, stderr=None)
-     base_url = f"http://localhost:{port}/v1"

      start_time = time.time()
      while time.time() - start_time < timeout:
          try:
-             response = requests.get(f"{base_url}/models")
+             headers = {
+                 "Content-Type": "application/json; charset=utf-8",
+                 "Authorization": f"Bearer {api_key}",
+             }
+             response = requests.get(f"{base_url}/v1/models", headers=headers)
              if response.status_code == 200:
                  return process
          except requests.RequestException:
              pass
          time.sleep(10)
      raise TimeoutError("Server failed to start within the timeout period.")
+
+
+ def run_with_timeout(
+     func: Callable,
+     args: tuple = (),
+     kwargs: Optional[dict] = None,
+     timeout: float = None,
+ ):
+     """Run a function with timeout."""
+     ret_value = []
+
+     def _target_func():
+         ret_value.append(func(*args, **(kwargs or {})))
+
+     t = threading.Thread(target=_target_func)
+     t.start()
+     t.join(timeout=timeout)
+     if t.is_alive():
+         raise TimeoutError()
+
+     if not ret_value:
+         raise RuntimeError()
+
+     return ret_value[0]
+
+
+ def run_unittest_files(files: List[str], timeout_per_file: float):
+     tic = time.time()
+     success = True
+
+     for filename in files:
+
+         def func():
+             print(f"\n\nRun {filename}\n\n")
+             ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
+
+         p = multiprocessing.Process(target=func)
+
+         def run_one_file():
+             p.start()
+             p.join()
+
+         try:
+             run_with_timeout(run_one_file, timeout=timeout_per_file)
+             if p.exitcode != 0:
+                 success = False
+                 break
+         except TimeoutError:
+             p.terminate()
+             time.sleep(5)
+             print(
+                 "\nTimeout after {timeout_per_file} seconds when running {filename}\n"
+             )
+             return False
+
+     if success:
+         print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
+     else:
+         print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
+
+     return 0 if success else -1
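
The launcher change is the most visible one for downstream tests: popen_launch_server now takes a full base URL (scheme://host:port) instead of a bare port, and can forward an API key plus extra CLI flags. A minimal sketch of the new call pattern (values are illustrative, not taken from this diff):

    from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server

    base_url = "http://127.0.0.1:30000"
    process = popen_launch_server(
        DEFAULT_MODEL_NAME_FOR_TEST,
        base_url,
        timeout=300,
        api_key="sk-test-key",          # optional; forwarded as --api-key
        other_args=("--tp-size", "1"),  # appended verbatim to the launch command
    )
    try:
        ...  # exercise the server, e.g. the OpenAI-compatible API at f"{base_url}/v1"
    finally:
        process.terminate()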