sglang 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- sglang/bench_serving.py +3 -5
- sglang/lang/interpreter.py +2 -1
- sglang/lang/ir.py +0 -1
- sglang/srt/constrained/{base_cache.py → base_tool_cache.py} +2 -2
- sglang/srt/constrained/fsm_cache.py +2 -2
- sglang/srt/constrained/jump_forward.py +2 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +29 -9
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/managers/tp_worker.py +29 -6
- sglang/srt/mem_cache/base_cache.py +43 -0
- sglang/srt/mem_cache/chunk_cache.py +60 -0
- sglang/srt/mem_cache/radix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +17 -2
- sglang/srt/models/llama2.py +5 -21
- sglang/srt/openai_api/adapter.py +76 -22
- sglang/srt/openai_api/protocol.py +20 -2
- sglang/srt/server.py +9 -14
- sglang/srt/server_args.py +18 -4
- sglang/srt/utils.py +20 -0
- sglang/test/run_eval.py +104 -0
- sglang/test/simple_eval_common.py +467 -0
- sglang/test/simple_eval_humaneval.py +139 -0
- sglang/test/simple_eval_mmlu.py +120 -0
- sglang/test/test_programs.py +12 -9
- sglang/test/test_utils.py +32 -0
- sglang/version.py +1 -1
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/METADATA +4 -4
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/RECORD +32 -28
- sglang/test/test_conversation.py +0 -46
- sglang/test/test_openai_protocol.py +0 -51
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/LICENSE +0 -0
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/WHEEL +0 -0
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py
CHANGED
@@ -43,7 +43,9 @@ from sglang.srt.openai_api.protocol import (
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
+    ChatCompletionTokenLogprob,
     ChatMessage,
+    ChoiceLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -54,6 +56,7 @@ from sglang.srt.openai_api.protocol import (
     FileRequest,
     FileResponse,
     LogProbs,
+    TopLogprob,
     UsageInfo,
 )
 
@@ -70,7 +73,7 @@ class FileMetadata:
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-
+# map file id to file path in SGlang backend
 file_id_storage: Dict[str, str] = {}
 
 
@@ -261,7 +264,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             failed_requests += len(file_request_list)
 
         for idx, response in enumerate(responses):
-
+            # the batch_req here can be changed to be named within a batch granularity
             response_json = {
                 "id": f"batch_req_{uuid.uuid4()}",
                 "custom_id": file_request_list[idx].get("custom_id"),
@@ -333,6 +336,8 @@ def v1_generate_request(all_requests):
 
     prompts = []
     sampling_params_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
         prompt = request.prompt
@@ -340,6 +345,10 @@ def v1_generate_request(all_requests):
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
         prompts.append(prompt)
+        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        top_logprobs_nums.append(
+            request.logprobs if request.logprobs is not None else 0
+        )
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -361,7 +370,9 @@ def v1_generate_request(all_requests):
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
-
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
         else:
             prompt_kwargs = {"input_ids": prompt}
@@ -370,15 +381,11 @@ def v1_generate_request(all_requests):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
-
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
-        return_logprob=
-
-        top_logprobs_num=(
-            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
-        ),
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
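With `return_logprob` and `top_logprobs_num` now collected per request, a batched completions call can mix prompts with and without logprobs instead of inheriting the first request's setting. A minimal sketch of exercising this against a locally running server; the `localhost:30000` endpoint and the model name `"default"` are assumptions, not values mandated by this diff:

```python
# Sketch: query the OpenAI-compatible /v1/completions endpoint with logprobs.
# Assumes an sglang server is already running on localhost:30000 and accepts
# the model name "default"; adjust both to your deployment.
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",
        "prompt": ["The capital of France is", "1 + 1 ="],
        "max_tokens": 8,
        "logprobs": 2,  # mapped to return_logprob / top_logprobs_num above
    },
)
for choice in resp.json()["choices"]:
    print(choice["text"], choice.get("logprobs"))
```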
@@ -430,7 +437,7 @@ def v1_generate_response(request, ret, to_file=False):
         logprobs = None
 
         if to_file:
-
+            # to make the choise data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
@@ -454,7 +461,7 @@ def v1_generate_response(request, ret, to_file=False):
                 "status_code": 200,
                 "request_id": ret[i]["meta_info"]["id"],
                 "body": {
-
+                    # remain the same but if needed we can change that
                     "id": ret[i]["meta_info"]["id"],
                     "object": "text_completion",
                     "created": int(time.time()),
@@ -587,9 +594,11 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
 def v1_chat_generate_request(all_requests, tokenizer_manager):
 
-
+    input_ids = []
     sampling_params_list = []
     image_data_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         #  - prompt: The full prompt string.
@@ -599,8 +608,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             if chat_template_name is None:
-
-                    request.messages, tokenize=
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=True, add_generation_prompt=True
                 )
                 stop = request.stop
                 image_data = None
@@ -614,12 +623,15 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
                     stop.append(request.stop)
                 else:
                     stop.extend(request.stop)
+            prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
             prompt = request.messages
             stop = request.stop
             image_data = None
-
+        input_ids.append(prompt_ids)
+        return_logprobs.append(request.logprobs)
+        top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -634,14 +646,19 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         )
         image_data_list.append(image_data)
     if len(all_requests) == 1:
-
+        input_ids = input_ids[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
-
+        input_ids=input_ids,
         image_data=image_data,
         sampling_params=sampling_params_list,
-
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
+        stream=all_requests[0].stream,
+        return_text_in_logprobs=True,
     )
     if len(all_requests) == 1:
        return adapted_request, all_requests[0]
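Chat requests are now tokenized on the server side: `apply_chat_template(..., tokenize=True, add_generation_prompt=True)` returns token ids directly, which the adapter forwards as `GenerateReqInput(input_ids=...)`. A minimal sketch of that template step using a Hugging Face tokenizer; the model name below is only an example:

```python
# Sketch of the chat-template step the adapter now performs server-side.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
messages = [{"role": "user", "content": "What is the capital of France?"}]

# With tokenize=True this returns a list of token ids, not a prompt string.
prompt_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)
print(prompt_ids[:10])
```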
@@ -654,26 +671,63 @@ def v1_chat_generate_response(request, ret, to_file=False):
     total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            logprobs = to_openai_style_logprobs(
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+            token_logprobs = []
+            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+                token_bytes = list(token.encode("utf-8"))
+                top_logprobs = []
+                if logprobs.top_logprobs:
+                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
+                        top_token_bytes = list(top_token.encode("utf-8"))
+                        top_logprobs.append(
+                            TopLogprob(
+                                token=top_token,
+                                bytes=top_token_bytes,
+                                logprob=top_logprob,
+                            )
+                        )
+                token_logprobs.append(
+                    ChatCompletionTokenLogprob(
+                        token=token,
+                        bytes=token_bytes,
+                        logprob=logprob,
+                        top_logprobs=top_logprobs,
+                    )
+                )
+
+            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+        else:
+            choice_logprobs = None
         prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
         completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
-
+            # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
-                "logprobs":
+                "logprobs": choice_logprobs,
                 "finish_reason": ret_item["meta_info"]["finish_reason"],
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
+                logprobs=choice_logprobs,
                 finish_reason=ret_item["meta_info"]["finish_reason"],
             )
 
         choices.append(choice_data)
-        total_prompt_tokens
+        total_prompt_tokens += prompt_tokens
         total_completion_tokens += completion_tokens
     if to_file:
         responses = []
@@ -683,7 +737,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "status_code": 200,
                 "request_id": ret[i]["meta_info"]["id"],
                 "body": {
-
+                    # remain the same but if needed we can change that
                     "id": ret[i]["meta_info"]["id"],
                     "object": "chat.completion",
                     "created": int(time.time()),
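Because the chat choice now carries an OpenAI-style `logprobs` field built from `ChatCompletionTokenLogprob` entries, a standard OpenAI client can read it directly. A minimal sketch using the `openai>=1.0` Python client pointed at a local sglang server; the URL and model name are placeholders for your deployment:

```python
# Sketch: reading the new per-token logprobs from a chat completion.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Say hi"}],
    logprobs=True,
    top_logprobs=2,
    max_tokens=4,
)
# Each entry has .token, .bytes, .logprob, and a .top_logprobs list.
for item in resp.choices[0].logprobs.content:
    print(item.token, item.logprob, [t.token for t in item.top_logprobs])
```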
sglang/srt/openai_api/protocol.py
CHANGED
@@ -54,6 +54,24 @@ class LogProbs(BaseModel):
     top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
+class TopLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+
+
+class ChatCompletionTokenLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+    top_logprobs: List[TopLogprob]
+
+
+class ChoiceLogprobs(BaseModel):
+    # build for v1/chat/completions response
+    content: List[ChatCompletionTokenLogprob]
+
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
@@ -239,8 +257,8 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason:
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
+    finish_reason: str
 
 
 class ChatCompletionResponse(BaseModel):
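The three new pydantic models nest the same way the adapter builds them: `ChoiceLogprobs.content` holds one `ChatCompletionTokenLogprob` per generated token, each carrying its `TopLogprob` alternatives. A minimal construction sketch, assuming sglang 0.2.9 is installed so the module below is importable:

```python
# Sketch of how the new response models nest.
from sglang.srt.openai_api.protocol import (
    ChatCompletionTokenLogprob,
    ChoiceLogprobs,
    TopLogprob,
)

token_entry = ChatCompletionTokenLogprob(
    token="Paris",
    bytes=list("Paris".encode("utf-8")),
    logprob=-0.12,
    top_logprobs=[TopLogprob(token="Paris", bytes=list(b"Paris"), logprob=-0.12)],
)
choice_logprobs = ChoiceLogprobs(content=[token_entry])
print(choice_logprobs)  # mirrors the `logprobs` field of a chat completion choice
```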
sglang/srt/server.py
CHANGED
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    kill_child_process,
     maybe_set_triton_cache_manager,
     set_ulimit,
 )
@@ -189,10 +190,10 @@ async def retrieve_file_content(file_id: str):
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
-
+    served_model_names = [tokenizer_manager.served_model_name]
     model_cards = []
-    for
-        model_cards.append(ModelCard(id=
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
     return ModelList(data=model_cards)
 
 
@@ -260,7 +261,7 @@ def launch_server(
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.
+            "0.1.3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -467,18 +468,12 @@ class Runtime:
 
     def shutdown(self):
         if self.pid is not None:
-
-                parent = psutil.Process(self.pid)
-            except psutil.NoSuchProcess:
-                return
-            children = parent.children(recursive=True)
-            for child in children:
-                child.kill()
-            psutil.wait_procs(children, timeout=5)
-            parent.kill()
-            parent.wait(timeout=5)
+            kill_child_process(self.pid)
         self.pid = None
 
+    def cache_prefix(self, prefix: str):
+        self.endpoint.cache_prefix(prefix)
+
     def get_tokenizer(self):
         return get_tokenizer(
             self.server_args.tokenizer_path,
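`/v1/models` now reports `tokenizer_manager.served_model_name`, so the name advertised to clients can differ from the model path. A minimal sketch of checking the override; the launch command in the comment and the host/port/model names are examples:

```python
# Sketch: with the server launched as, e.g.,
#   python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf \
#       --served-model-name my-model --port 30000
# the /v1/models endpoint returns the overridden name instead of the path.
import requests

models = requests.get("http://localhost:30000/v1/models").json()
print([card["id"] for card in models["data"]])  # expected: ["my-model"]
```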
sglang/srt/server_args.py
CHANGED
@@ -32,6 +32,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
+    served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
 
     # Port
@@ -44,6 +45,7 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
+    max_total_tokens: Optional[int] = None
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
 
@@ -89,6 +91,10 @@ class ServerArgs:
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
+
+        if self.served_model_name is None:
+            self.served_model_name = self.model_path
+
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -201,6 +207,12 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--served-model-name",
+            type=str,
+            default=ServerArgs.served_model_name,
+            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
+        )
         parser.add_argument(
             "--chat-template",
             type=str,
@@ -231,6 +243,12 @@ class ServerArgs:
             default=ServerArgs.max_num_reqs,
             help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
         )
+        parser.add_argument(
+            "--max-total-tokens",
+            type=int,
+            default=ServerArgs.max_total_tokens,
+            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
+        )
         parser.add_argument(
             "--schedule-policy",
             type=str,
@@ -412,10 +430,6 @@ class ServerArgs:
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
 
-        assert not (
-            self.chunked_prefill_size is not None and self.disable_radix_cache
-        ), "chunked prefill is not supported with radix cache disabled currently"
-
 
 @dataclasses.dataclass
 class PortArgs:
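The new fields default sensibly in `__post_init__`: `served_model_name` falls back to `model_path`, and `max_total_tokens` stays `None` so the pool size is derived from the memory fraction. A small sketch of the dataclass behavior, assuming sglang 0.2.9 is installed and using an example model path:

```python
# Sketch of the new ServerArgs defaults (no server is launched here).
from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf")
print(args.served_model_name)  # falls back to model_path when not set
print(args.max_total_tokens)   # None -> sized from the memory fraction at runtime

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    served_model_name="my-model",
    max_total_tokens=200_000,
)
print(args.served_model_name)  # "my-model"
```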
sglang/srt/utils.py
CHANGED
@@ -366,6 +366,26 @@ def kill_parent_process():
     os.kill(parent_process.pid, 9)
 
 
+def kill_child_process(pid, including_parent=True):
+    try:
+        parent = psutil.Process(pid)
+    except psutil.NoSuchProcess:
+        return
+
+    children = parent.children(recursive=True)
+    for child in children:
+        try:
+            child.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+    if including_parent:
+        try:
+            parent.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
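This is the helper that `Runtime.shutdown` now delegates to: it kills every descendant of `pid` and, unless `including_parent=False`, the process itself, swallowing races where a process has already exited. A minimal usage sketch on a throwaway subprocess (a stand-in target, not part of the package):

```python
# Sketch: terminating a process tree with the new utility.
import subprocess
import time

from sglang.srt.utils import kill_child_process

proc = subprocess.Popen(["sleep", "60"])  # stand-in for a launched server
time.sleep(0.5)
kill_child_process(proc.pid, including_parent=True)
```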
sglang/test/run_eval.py
ADDED
@@ -0,0 +1,104 @@
+"""
+Usage:
+python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
+"""
+
+import argparse
+import json
+import os
+import time
+
+from sglang.test.simple_eval_common import (
+    ChatCompletionSampler,
+    download_dataset,
+    make_report,
+    set_ulimit,
+)
+
+
+def run_eval(args):
+    if "OPENAI_API_KEY" not in os.environ:
+        os.environ["OPENAI_API_KEY"] = "EMPTY"
+
+    base_url = (
+        f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+    )
+
+    if args.eval_name == "mmlu":
+        from sglang.test.simple_eval_mmlu import MMLUEval
+
+        dataset_path = "mmlu.csv"
+
+        if not os.path.exists(dataset_path):
+            download_dataset(
+                dataset_path,
+                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+            )
+        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+    elif args.eval_name == "humaneval":
+        from sglang.test.simple_eval_humaneval import HumanEval
+
+        eval_obj = HumanEval(args.num_examples, args.num_threads)
+    else:
+        raise ValueError(f"Invalid eval name: {args.eval_name}")
+
+    sampler = ChatCompletionSampler(
+        model=args.model,
+        max_tokens=2048,
+        base_url=base_url,
+    )
+
+    # Run eval
+    tic = time.time()
+    result = eval_obj(sampler)
+    latency = time.time() - tic
+
+    # Dump reports
+    metrics = result.metrics | {"score": result.score}
+    file_stem = f"{args.eval_name}_{sampler.model.replace('/', '_')}"
+    report_filename = f"/tmp/{file_stem}.html"
+    print(f"Writing report to {report_filename}")
+    with open(report_filename, "w") as fh:
+        fh.write(make_report(result))
+    metrics = result.metrics | {"score": result.score}
+    print(metrics)
+    result_filename = f"/tmp/{file_stem}.json"
+    with open(result_filename, "w") as f:
+        f.write(json.dumps(metrics, indent=2))
+    print(f"Writing results to {result_filename}")
+
+    # Print results
+    print(f"Total latency: {latency:.3f} s")
+    print(f"Score: {metrics['score']:.3f}")
+
+    return metrics
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+    )
+    parser.add_argument("--eval-name", type=str, default="mmlu")
+    parser.add_argument("--num-examples", type=int)
+    parser.add_argument("--num-threads", type=int, default=64)
+    set_ulimit()
+    args = parser.parse_args()
+
+    run_eval(args)
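Besides the CLI shown in the file's docstring, `run_eval` can be driven programmatically by handing it any object with the same attribute names as the parsed arguments. A minimal sketch; the field values are examples and assume a server is already listening on port 30000:

```python
# Sketch: invoking the new eval runner from Python instead of the CLI.
import argparse

from sglang.test.run_eval import run_eval

args = argparse.Namespace(
    base_url=None,
    host="127.0.0.1",
    port=30000,
    model="default",
    eval_name="mmlu",
    num_examples=10,
    num_threads=8,
)
metrics = run_eval(args)
print(metrics["score"])
```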