sglang 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/srt/openai_api/adapter.py CHANGED
@@ -43,7 +43,9 @@ from sglang.srt.openai_api.protocol import (
  ChatCompletionResponseChoice,
  ChatCompletionResponseStreamChoice,
  ChatCompletionStreamResponse,
+ ChatCompletionTokenLogprob,
  ChatMessage,
+ ChoiceLogprobs,
  CompletionRequest,
  CompletionResponse,
  CompletionResponseChoice,
@@ -54,6 +56,7 @@ from sglang.srt.openai_api.protocol import (
  FileRequest,
  FileResponse,
  LogProbs,
+ TopLogprob,
  UsageInfo,
  )

@@ -70,7 +73,7 @@ class FileMetadata:
  batch_storage: Dict[str, BatchResponse] = {}
  file_id_request: Dict[str, FileMetadata] = {}
  file_id_response: Dict[str, FileResponse] = {}
- ## map file id to file path in SGlang backend
+ # map file id to file path in SGlang backend
  file_id_storage: Dict[str, str] = {}


@@ -261,7 +264,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  failed_requests += len(file_request_list)

  for idx, response in enumerate(responses):
- ## the batch_req here can be changed to be named within a batch granularity
+ # the batch_req here can be changed to be named within a batch granularity
  response_json = {
  "id": f"batch_req_{uuid.uuid4()}",
  "custom_id": file_request_list[idx].get("custom_id"),
@@ -333,6 +336,8 @@ def v1_generate_request(all_requests):

  prompts = []
  sampling_params_list = []
+ return_logprobs = []
+ top_logprobs_nums = []
  first_prompt_type = type(all_requests[0].prompt)
  for request in all_requests:
  prompt = request.prompt
@@ -340,6 +345,10 @@ def v1_generate_request(all_requests):
  type(prompt) == first_prompt_type
  ), "All prompts must be of the same type in file input settings"
  prompts.append(prompt)
+ return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+ top_logprobs_nums.append(
+ request.logprobs if request.logprobs is not None else 0
+ )
  sampling_params_list.append(
  {
  "temperature": request.temperature,
@@ -361,7 +370,9 @@ def v1_generate_request(all_requests):
  if len(all_requests) == 1:
  prompt = prompts[0]
  sampling_params_list = sampling_params_list[0]
- if isinstance(prompts, str) or isinstance(prompts[0], str):
+ return_logprobs = return_logprobs[0]
+ top_logprobs_nums = top_logprobs_nums[0]
+ if isinstance(prompt, str) or isinstance(prompt[0], str):
  prompt_kwargs = {"text": prompt}
  else:
  prompt_kwargs = {"input_ids": prompt}
@@ -370,15 +381,11 @@ def v1_generate_request(all_requests):
  prompt_kwargs = {"text": prompts}
  else:
  prompt_kwargs = {"input_ids": prompts}
-
  adapted_request = GenerateReqInput(
  **prompt_kwargs,
  sampling_params=sampling_params_list,
- return_logprob=all_requests[0].logprobs is not None
- and all_requests[0].logprobs > 0,
- top_logprobs_num=(
- all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
- ),
+ return_logprob=return_logprobs,
+ top_logprobs_num=top_logprobs_nums,
  return_text_in_logprobs=True,
  stream=all_requests[0].stream,
  )
@@ -430,7 +437,7 @@ def v1_generate_response(request, ret, to_file=False):
  logprobs = None

  if to_file:
- ## to make the choise data json serializable
+ # to make the choise data json serializable
  choice_data = {
  "index": 0,
  "text": text,
@@ -454,7 +461,7 @@ def v1_generate_response(request, ret, to_file=False):
  "status_code": 200,
  "request_id": ret[i]["meta_info"]["id"],
  "body": {
- ## remain the same but if needed we can change that
+ # remain the same but if needed we can change that
  "id": ret[i]["meta_info"]["id"],
  "object": "text_completion",
  "created": int(time.time()),
@@ -587,9 +594,11 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

  def v1_chat_generate_request(all_requests, tokenizer_manager):

- texts = []
+ input_ids = []
  sampling_params_list = []
  image_data_list = []
+ return_logprobs = []
+ top_logprobs_nums = []
  for request in all_requests:
  # Prep the data needed for the underlying GenerateReqInput:
  # - prompt: The full prompt string.
@@ -599,8 +608,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  if not isinstance(request.messages, str):
  # Apply chat template and its stop strings.
  if chat_template_name is None:
- prompt = tokenizer_manager.tokenizer.apply_chat_template(
- request.messages, tokenize=False, add_generation_prompt=True
+ prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+ request.messages, tokenize=True, add_generation_prompt=True
  )
  stop = request.stop
  image_data = None
@@ -614,12 +623,15 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  stop.append(request.stop)
  else:
  stop.extend(request.stop)
+ prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
  else:
  # Use the raw prompt and stop strings if the messages is already a string.
  prompt = request.messages
  stop = request.stop
  image_data = None
- texts.append(prompt)
+ input_ids.append(prompt_ids)
+ return_logprobs.append(request.logprobs)
+ top_logprobs_nums.append(request.top_logprobs)
  sampling_params_list.append(
  {
  "temperature": request.temperature,
@@ -634,14 +646,19 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  )
  image_data_list.append(image_data)
  if len(all_requests) == 1:
- texts = texts[0]
+ input_ids = input_ids[0]
  sampling_params_list = sampling_params_list[0]
  image_data = image_data_list[0]
+ return_logprobs = return_logprobs[0]
+ top_logprobs_nums = top_logprobs_nums[0]
  adapted_request = GenerateReqInput(
- text=texts,
+ input_ids=input_ids,
  image_data=image_data,
  sampling_params=sampling_params_list,
- stream=request.stream,
+ return_logprob=return_logprobs,
+ top_logprobs_num=top_logprobs_nums,
+ stream=all_requests[0].stream,
+ return_text_in_logprobs=True,
  )
  if len(all_requests) == 1:
  return adapted_request, all_requests[0]
@@ -654,26 +671,63 @@ def v1_chat_generate_response(request, ret, to_file=False):
  total_completion_tokens = 0

  for idx, ret_item in enumerate(ret):
+ logprobs = False
+ if isinstance(request, List) and request[idx].logprobs:
+ logprobs = True
+ elif (not isinstance(request, List)) and request.logprobs:
+ logprobs = True
+ if logprobs:
+ logprobs = to_openai_style_logprobs(
+ output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+ output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+ )
+ token_logprobs = []
+ for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+ token_bytes = list(token.encode("utf-8"))
+ top_logprobs = []
+ if logprobs.top_logprobs:
+ for top_token, top_logprob in logprobs.top_logprobs[0].items():
+ top_token_bytes = list(top_token.encode("utf-8"))
+ top_logprobs.append(
+ TopLogprob(
+ token=top_token,
+ bytes=top_token_bytes,
+ logprob=top_logprob,
+ )
+ )
+ token_logprobs.append(
+ ChatCompletionTokenLogprob(
+ token=token,
+ bytes=token_bytes,
+ logprob=logprob,
+ top_logprobs=top_logprobs,
+ )
+ )
+
+ choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+ else:
+ choice_logprobs = None
  prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
  completion_tokens = ret_item["meta_info"]["completion_tokens"]

  if to_file:
- ## to make the choice data json serializable
+ # to make the choice data json serializable
  choice_data = {
  "index": 0,
  "message": {"role": "assistant", "content": ret_item["text"]},
- "logprobs": None,
+ "logprobs": choice_logprobs,
  "finish_reason": ret_item["meta_info"]["finish_reason"],
  }
  else:
  choice_data = ChatCompletionResponseChoice(
  index=idx,
  message=ChatMessage(role="assistant", content=ret_item["text"]),
+ logprobs=choice_logprobs,
  finish_reason=ret_item["meta_info"]["finish_reason"],
  )

  choices.append(choice_data)
- total_prompt_tokens = prompt_tokens
+ total_prompt_tokens += prompt_tokens
  total_completion_tokens += completion_tokens
  if to_file:
  responses = []
@@ -683,7 +737,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
  "status_code": 200,
  "request_id": ret[i]["meta_info"]["id"],
  "body": {
- ## remain the same but if needed we can change that
+ # remain the same but if needed we can change that
  "id": ret[i]["meta_info"]["id"],
  "object": "chat.completion",
  "created": int(time.time()),
sglang/srt/openai_api/protocol.py CHANGED
@@ -54,6 +54,24 @@ class LogProbs(BaseModel):
  top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


+ class TopLogprob(BaseModel):
+ token: str
+ bytes: List[int]
+ logprob: float
+
+
+ class ChatCompletionTokenLogprob(BaseModel):
+ token: str
+ bytes: List[int]
+ logprob: float
+ top_logprobs: List[TopLogprob]
+
+
+ class ChoiceLogprobs(BaseModel):
+ # build for v1/chat/completions response
+ content: List[ChatCompletionTokenLogprob]
+
+
  class UsageInfo(BaseModel):
  prompt_tokens: int = 0
  total_tokens: int = 0
@@ -239,8 +257,8 @@ class ChatMessage(BaseModel):
  class ChatCompletionResponseChoice(BaseModel):
  index: int
  message: ChatMessage
- logprobs: Optional[LogProbs] = None
- finish_reason: Optional[str] = None
+ logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
+ finish_reason: str


  class ChatCompletionResponse(BaseModel):
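
The three new pydantic models mirror the logprob objects in OpenAI's chat completions schema. A small sketch of how they nest, with invented values; model_dump_json assumes pydantic v2 (use .json() on v1):

    from sglang.srt.openai_api.protocol import (
        ChatCompletionTokenLogprob,
        ChoiceLogprobs,
        TopLogprob,
    )

    choice_logprobs = ChoiceLogprobs(
        content=[
            ChatCompletionTokenLogprob(
                token="Hello",
                bytes=list("Hello".encode("utf-8")),
                logprob=-0.12,
                top_logprobs=[
                    TopLogprob(token="Hello", bytes=list("Hello".encode("utf-8")), logprob=-0.12),
                    TopLogprob(token="Hi", bytes=list("Hi".encode("utf-8")), logprob=-2.3),
                ],
            )
        ]
    )
    print(choice_logprobs.model_dump_json(indent=2))
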
sglang/srt/server.py CHANGED
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
  allocate_init_ports,
  assert_pkg_version,
  enable_show_time_cost,
+ kill_child_process,
  maybe_set_triton_cache_manager,
  set_ulimit,
  )
@@ -189,10 +190,10 @@ async def retrieve_file_content(file_id: str):
  @app.get("/v1/models")
  def available_models():
  """Show available models."""
- model_names = [tokenizer_manager.model_path]
+ served_model_names = [tokenizer_manager.served_model_name]
  model_cards = []
- for model_name in model_names:
- model_cards.append(ModelCard(id=model_name, root=model_name))
+ for served_model_name in served_model_names:
+ model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
  return ModelList(data=model_cards)


@@ -260,7 +261,7 @@ def launch_server(
  if not server_args.disable_flashinfer:
  assert_pkg_version(
  "flashinfer",
- "0.1.2",
+ "0.1.3",
  "Please uninstall the old version and "
  "reinstall the latest version by following the instructions "
  "at https://docs.flashinfer.ai/installation.html.",
@@ -467,18 +468,12 @@ class Runtime:

  def shutdown(self):
  if self.pid is not None:
- try:
- parent = psutil.Process(self.pid)
- except psutil.NoSuchProcess:
- return
- children = parent.children(recursive=True)
- for child in children:
- child.kill()
- psutil.wait_procs(children, timeout=5)
- parent.kill()
- parent.wait(timeout=5)
+ kill_child_process(self.pid)
  self.pid = None

+ def cache_prefix(self, prefix: str):
+ self.endpoint.cache_prefix(prefix)
+
  def get_tokenizer(self):
  return get_tokenizer(
  self.server_args.tokenizer_path,
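
The /v1/models handler now reports served_model_name instead of the raw model_path, and Runtime.shutdown() delegates to the shared kill_child_process helper. A quick check of the endpoint, assuming a server already running on localhost:30000 and the requests package:

    import requests

    models = requests.get("http://localhost:30000/v1/models").json()
    for card in models["data"]:
        print(card["id"], card["root"])  # both reflect --served-model-name when it is set
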
sglang/srt/server_args.py CHANGED
@@ -32,6 +32,7 @@ class ServerArgs:
  trust_remote_code: bool = True
  context_length: Optional[int] = None
  quantization: Optional[str] = None
+ served_model_name: Optional[str] = None
  chat_template: Optional[str] = None

  # Port
@@ -44,6 +45,7 @@ class ServerArgs:
  max_prefill_tokens: Optional[int] = None
  max_running_requests: Optional[int] = None
  max_num_reqs: Optional[int] = None
+ max_total_tokens: Optional[int] = None
  schedule_policy: str = "lpm"
  schedule_conservativeness: float = 1.0

@@ -89,6 +91,10 @@ class ServerArgs:
  def __post_init__(self):
  if self.tokenizer_path is None:
  self.tokenizer_path = self.model_path
+
+ if self.served_model_name is None:
+ self.served_model_name = self.model_path
+
  if self.mem_fraction_static is None:
  if self.tp_size >= 16:
  self.mem_fraction_static = 0.79
@@ -201,6 +207,12 @@ class ServerArgs:
  ],
  help="The quantization method.",
  )
+ parser.add_argument(
+ "--served-model-name",
+ type=str,
+ default=ServerArgs.served_model_name,
+ help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
+ )
  parser.add_argument(
  "--chat-template",
  type=str,
@@ -231,6 +243,12 @@ class ServerArgs:
  default=ServerArgs.max_num_reqs,
  help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
  )
+ parser.add_argument(
+ "--max-total-tokens",
+ type=int,
+ default=ServerArgs.max_total_tokens,
+ help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
+ )
  parser.add_argument(
  "--schedule-policy",
  type=str,
@@ -412,10 +430,6 @@ class ServerArgs:
  self.dp_size > 1 and self.node_rank is not None
  ), "multi-node data parallel is not supported"

- assert not (
- self.chunked_prefill_size is not None and self.disable_radix_cache
- ), "chunked prefill is not supported with radix cache disabled currently"
-

  @dataclasses.dataclass
  class PortArgs:
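
The two new ServerArgs fields back the --served-model-name and --max-total-tokens flags added above. A minimal sketch of setting them programmatically; the model path and alias are placeholders:

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="your-org/your-model",    # placeholder
        served_model_name="my-model-alias",  # id returned by /v1/models instead of model_path
        max_total_tokens=8192,               # caps the token memory pool; mainly for development/debugging
    )
    print(args.served_model_name, args.max_total_tokens)
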
sglang/srt/utils.py CHANGED
@@ -366,6 +366,26 @@ def kill_parent_process():
  os.kill(parent_process.pid, 9)


+ def kill_child_process(pid, including_parent=True):
+ try:
+ parent = psutil.Process(pid)
+ except psutil.NoSuchProcess:
+ return
+
+ children = parent.children(recursive=True)
+ for child in children:
+ try:
+ child.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+ if including_parent:
+ try:
+ parent.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+
  def monkey_patch_vllm_p2p_access_check(gpu_id: int):
  """
  Monkey patch the slow p2p access check in vllm.
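
kill_child_process generalizes the teardown logic that Runtime.shutdown previously inlined (see the server.py hunk above). A small sketch of calling it directly, assuming a POSIX system for the throwaway sleep child:

    import subprocess
    import time

    from sglang.srt.utils import kill_child_process

    proc = subprocess.Popen(["sleep", "30"])  # stand-in for a launched server process
    time.sleep(1)
    kill_child_process(proc.pid, including_parent=True)  # kills descendants, then the pid itself
    proc.wait()
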
sglang/test/run_eval.py ADDED
@@ -0,0 +1,104 @@
+ """
+ Usage:
+ python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
+ """
+
+ import argparse
+ import json
+ import os
+ import time
+
+ from sglang.test.simple_eval_common import (
+ ChatCompletionSampler,
+ download_dataset,
+ make_report,
+ set_ulimit,
+ )
+
+
+ def run_eval(args):
+ if "OPENAI_API_KEY" not in os.environ:
+ os.environ["OPENAI_API_KEY"] = "EMPTY"
+
+ base_url = (
+ f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+ )
+
+ if args.eval_name == "mmlu":
+ from sglang.test.simple_eval_mmlu import MMLUEval
+
+ dataset_path = "mmlu.csv"
+
+ if not os.path.exists(dataset_path):
+ download_dataset(
+ dataset_path,
+ "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+ )
+ eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+ elif args.eval_name == "humaneval":
+ from sglang.test.simple_eval_humaneval import HumanEval
+
+ eval_obj = HumanEval(args.num_examples, args.num_threads)
+ else:
+ raise ValueError(f"Invalid eval name: {args.eval_name}")
+
+ sampler = ChatCompletionSampler(
+ model=args.model,
+ max_tokens=2048,
+ base_url=base_url,
+ )
+
+ # Run eval
+ tic = time.time()
+ result = eval_obj(sampler)
+ latency = time.time() - tic
+
+ # Dump reports
+ metrics = result.metrics | {"score": result.score}
+ file_stem = f"{args.eval_name}_{sampler.model.replace('/', '_')}"
+ report_filename = f"/tmp/{file_stem}.html"
+ print(f"Writing report to {report_filename}")
+ with open(report_filename, "w") as fh:
+ fh.write(make_report(result))
+ metrics = result.metrics | {"score": result.score}
+ print(metrics)
+ result_filename = f"/tmp/{file_stem}.json"
+ with open(result_filename, "w") as f:
+ f.write(json.dumps(metrics, indent=2))
+ print(f"Writing results to {result_filename}")
+
+ # Print results
+ print(f"Total latency: {latency:.3f} s")
+ print(f"Score: {metrics['score']:.3f}")
+
+ return metrics
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--base-url",
+ type=str,
+ default=None,
+ help="Server or API base url if not using http host and port.",
+ )
+ parser.add_argument(
+ "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+ )
+ parser.add_argument(
+ "--port",
+ type=int,
+ help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+ )
+ parser.add_argument("--eval-name", type=str, default="mmlu")
+ parser.add_argument("--num-examples", type=int)
+ parser.add_argument("--num-threads", type=int, default=64)
+ set_ulimit()
+ args = parser.parse_args()
+
+ run_eval(args)
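
Besides the CLI shown in the module docstring, run_eval can be driven programmatically; the namespace below mirrors the parser's arguments, and the server location and model name are placeholders:

    import argparse

    from sglang.test.run_eval import run_eval

    args = argparse.Namespace(
        base_url=None,
        host="127.0.0.1",
        port=30000,
        model="default",
        eval_name="mmlu",
        num_examples=10,
        num_threads=8,
    )
    metrics = run_eval(args)
    print(metrics["score"])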