sglang 0.2.8__py3-none-any.whl → 0.2.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
sglang/bench_serving.py CHANGED
@@ -21,7 +21,7 @@ import sys
  import time
  import traceback
  import warnings
- from argparse import ArgumentParser as FlexibleArgumentParser
+ from argparse import ArgumentParser
  from dataclasses import dataclass, field
  from datetime import datetime
  from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535):


  if __name__ == "__main__":
- parser = FlexibleArgumentParser(
- description="Benchmark the online serving throughput."
- )
+ parser = ArgumentParser(description="Benchmark the online serving throughput.")
  parser.add_argument(
  "--backend",
  type=str,
- required=True,
  choices=list(ASYNC_REQUEST_FUNCS.keys()),
+ default="sglang",
  help="Must specify a backend, depending on the LLM Inference Engine.",
  )
  parser.add_argument(
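
Net effect of the two hunks above: the FlexibleArgumentParser alias for ArgumentParser is dropped, and --backend changes from a required flag to one that defaults to "sglang". A minimal, hedged sketch of the new argument behavior (the choices list below is a stand-in for ASYNC_REQUEST_FUNCS.keys(), not the actual set):

    from argparse import ArgumentParser

    parser = ArgumentParser(description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=["sglang", "vllm"],  # stand-in for list(ASYNC_REQUEST_FUNCS.keys())
        default="sglang",
        help="Must specify a backend, depending on the LLM Inference Engine.",
    )

    args = parser.parse_args([])     # no --backend on the command line
    assert args.backend == "sglang"  # the new default kicks in instead of an error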
sglang/check_env.py CHANGED
@@ -30,6 +30,7 @@ PACKAGE_LIST = [
  "zmq",
  "vllm",
  "outlines",
+ "multipart",
  "openai",
  "tiktoken",
  "anthropic",
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
  all_logits = all_logits[:, : self.config.vocab_size].float()

  all_logprobs = all_logits
- del all_logits
+ del all_logits, hidden_states
  all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

  # Get the logprob of top-k tokens
@@ -79,6 +79,7 @@ class TokenizerManager:
  self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

  self.model_path = server_args.model_path
+ self.served_model_name = server_args.served_model_name
  self.hf_config = get_config(
  self.model_path,
  trust_remote_code=server_args.trust_remote_code,
@@ -312,10 +312,12 @@ class ModelRunner:
  self.cuda_graph_runner.capture(batch_size_list)
  except RuntimeError as e:
  raise Exception(
- f"Capture cuda graph failed: {e}. Possible solutions:\n"
- f"1. disable cuda graph by --disable-cuda-graph\n"
- f"2. set --mem-fraction-static to a smaller value\n"
- f"Open an issue on GitHub with reproducible scripts if you need help.\n"
+ f"Capture cuda graph failed: {e}\n"
+ "Possible solutions:\n"
+ "1. disable torch compile by not using --enable-torch-compile\n"
+ "2. disable cuda graph by --disable-cuda-graph\n"
+ "3. set --mem-fraction-static to a smaller value\n"
+ "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
  )

  @torch.inference_mode()
@@ -594,7 +594,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

  def v1_chat_generate_request(all_requests, tokenizer_manager):

- texts = []
+ input_ids = []
  sampling_params_list = []
  image_data_list = []
  return_logprobs = []
@@ -608,8 +608,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  if not isinstance(request.messages, str):
  # Apply chat template and its stop strings.
  if chat_template_name is None:
- prompt = tokenizer_manager.tokenizer.apply_chat_template(
- request.messages, tokenize=False, add_generation_prompt=True
+ prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+ request.messages, tokenize=True, add_generation_prompt=True
  )
  stop = request.stop
  image_data = None
@@ -623,12 +623,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  stop.append(request.stop)
  else:
  stop.extend(request.stop)
+ prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
  else:
  # Use the raw prompt and stop strings if the messages is already a string.
  prompt = request.messages
  stop = request.stop
  image_data = None
- texts.append(prompt)
+ input_ids.append(prompt_ids)
  return_logprobs.append(request.logprobs)
  top_logprobs_nums.append(request.top_logprobs)
  sampling_params_list.append(
@@ -645,13 +646,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  )
  image_data_list.append(image_data)
  if len(all_requests) == 1:
- texts = texts[0]
+ input_ids = input_ids[0]
  sampling_params_list = sampling_params_list[0]
  image_data = image_data_list[0]
  return_logprobs = return_logprobs[0]
  top_logprobs_nums = top_logprobs_nums[0]
  adapted_request = GenerateReqInput(
- text=texts,
+ input_ids=input_ids,
  image_data=image_data,
  sampling_params=sampling_params_list,
  return_logprob=return_logprobs,
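
The three hunks above switch the OpenAI-compatible chat endpoint from sending rendered prompt text to sending token ids: when a chat template is applied, apply_chat_template(..., tokenize=True, ...) returns ids directly, otherwise the rendered prompt is run through tokenizer.encode, and the batch is forwarded via GenerateReqInput(input_ids=...) instead of text=... A small illustrative sketch of the two tokenizer calls (the model name is only an example; any Hugging Face tokenizer with a chat template works):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    messages = [{"role": "user", "content": "Hello!"}]

    # Old path: render the chat template to a string (tokenization happens later).
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # New path: get token ids directly, which is what input_ids now carries.
    prompt_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True
    )

    print(type(prompt).__name__, type(prompt_ids).__name__)  # str list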
sglang/srt/server.py CHANGED
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
  allocate_init_ports,
  assert_pkg_version,
  enable_show_time_cost,
+ kill_child_process,
  maybe_set_triton_cache_manager,
  set_ulimit,
  )
@@ -189,10 +190,10 @@ async def retrieve_file_content(file_id: str):
  @app.get("/v1/models")
  def available_models():
  """Show available models."""
- model_names = [tokenizer_manager.model_path]
+ served_model_names = [tokenizer_manager.served_model_name]
  model_cards = []
- for model_name in model_names:
- model_cards.append(ModelCard(id=model_name, root=model_name))
+ for served_model_name in served_model_names:
+ model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
  return ModelList(data=model_cards)
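
With this change, /v1/models reports served_model_name (see the --served-model-name flag added in server_args.py below) rather than the local model_path. A hedged usage sketch, assuming a server launched with --served-model-name my-model on port 30000 (host and port here are assumptions):

    import requests

    # Query the OpenAI-compatible model listing endpoint.
    resp = requests.get("http://127.0.0.1:30000/v1/models")
    print([card["id"] for card in resp.json()["data"]])  # e.g. ["my-model"] instead of the model path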
@@ -467,16 +468,7 @@ class Runtime:

  def shutdown(self):
  if self.pid is not None:
- try:
- parent = psutil.Process(self.pid)
- except psutil.NoSuchProcess:
- return
- children = parent.children(recursive=True)
- for child in children:
- child.kill()
- psutil.wait_procs(children, timeout=5)
- parent.kill()
- parent.wait(timeout=5)
+ kill_child_process(self.pid)
  self.pid = None

  def cache_prefix(self, prefix: str):
sglang/srt/server_args.py CHANGED
@@ -32,6 +32,7 @@ class ServerArgs:
  trust_remote_code: bool = True
  context_length: Optional[int] = None
  quantization: Optional[str] = None
+ served_model_name: Optional[str] = None
  chat_template: Optional[str] = None

  # Port
@@ -90,6 +91,10 @@ class ServerArgs:
  def __post_init__(self):
  if self.tokenizer_path is None:
  self.tokenizer_path = self.model_path
+
+ if self.served_model_name is None:
+ self.served_model_name = self.model_path
+
  if self.mem_fraction_static is None:
  if self.tp_size >= 16:
  self.mem_fraction_static = 0.79
@@ -202,6 +207,12 @@ class ServerArgs:
  ],
  help="The quantization method.",
  )
+ parser.add_argument(
+ "--served-model-name",
+ type=str,
+ default=ServerArgs.served_model_name,
+ help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
+ )
  parser.add_argument(
  "--chat-template",
  type=str,
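
Taken together, the three server_args.py hunks add a served_model_name field, a --served-model-name CLI flag, and a __post_init__ fallback to model_path when the flag is omitted. A minimal sketch of just the fallback logic (a simplified stand-in dataclass, not the real ServerArgs):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ServerArgsSketch:
        model_path: str
        served_model_name: Optional[str] = None

        def __post_init__(self):
            # Mirror the new fallback: use the model path unless a name is given.
            if self.served_model_name is None:
                self.served_model_name = self.model_path

    print(ServerArgsSketch("my-org/my-model").served_model_name)              # "my-org/my-model"
    print(ServerArgsSketch("my-org/my-model", "my-model").served_model_name)  # "my-model"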
sglang/srt/utils.py CHANGED
@@ -366,6 +366,26 @@ def kill_parent_process():
  os.kill(parent_process.pid, 9)


+ def kill_child_process(pid, including_parent=True):
+ try:
+ parent = psutil.Process(pid)
+ except psutil.NoSuchProcess:
+ return
+
+ children = parent.children(recursive=True)
+ for child in children:
+ try:
+ child.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+ if including_parent:
+ try:
+ parent.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+
  def monkey_patch_vllm_p2p_access_check(gpu_id: int):
  """
  Monkey patch the slow p2p access check in vllm.
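
The new kill_child_process helper centralizes the psutil process-tree cleanup that Runtime.shutdown in server.py previously inlined. A hedged, POSIX-only usage sketch (the sleep child process is just a stand-in):

    import subprocess

    from sglang.srt.utils import kill_child_process

    proc = subprocess.Popen(["sleep", "60"])  # stand-in for a launched server process
    kill_child_process(proc.pid)              # kills its children and, by default, the process itself
    proc.wait()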
sglang/test/run_eval.py ADDED
@@ -0,0 +1,104 @@
+ """
+ Usage:
+ python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
+ """
+
+ import argparse
+ import json
+ import os
+ import time
+
+ from sglang.test.simple_eval_common import (
+ ChatCompletionSampler,
+ download_dataset,
+ make_report,
+ set_ulimit,
+ )
+
+
+ def run_eval(args):
+ if "OPENAI_API_KEY" not in os.environ:
+ os.environ["OPENAI_API_KEY"] = "EMPTY"
+
+ base_url = (
+ f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+ )
+
+ if args.eval_name == "mmlu":
+ from sglang.test.simple_eval_mmlu import MMLUEval
+
+ dataset_path = "mmlu.csv"
+
+ if not os.path.exists(dataset_path):
+ download_dataset(
+ dataset_path,
+ "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+ )
+ eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+ elif args.eval_name == "humaneval":
+ from sglang.test.simple_eval_humaneval import HumanEval
+
+ eval_obj = HumanEval(args.num_examples, args.num_threads)
+ else:
+ raise ValueError(f"Invalid eval name: {args.eval_name}")
+
+ sampler = ChatCompletionSampler(
+ model=args.model,
+ max_tokens=2048,
+ base_url=base_url,
+ )
+
+ # Run eval
+ tic = time.time()
+ result = eval_obj(sampler)
+ latency = time.time() - tic
+
+ # Dump reports
+ metrics = result.metrics | {"score": result.score}
+ file_stem = f"{args.eval_name}_{sampler.model.replace('/', '_')}"
+ report_filename = f"/tmp/{file_stem}.html"
+ print(f"Writing report to {report_filename}")
+ with open(report_filename, "w") as fh:
+ fh.write(make_report(result))
+ metrics = result.metrics | {"score": result.score}
+ print(metrics)
+ result_filename = f"/tmp/{file_stem}.json"
+ with open(result_filename, "w") as f:
+ f.write(json.dumps(metrics, indent=2))
+ print(f"Writing results to {result_filename}")
+
+ # Print results
+ print(f"Total latency: {latency:.3f} s")
+ print(f"Score: {metrics['score']:.3f}")
+
+ return metrics
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--base-url",
+ type=str,
+ default=None,
+ help="Server or API base url if not using http host and port.",
+ )
+ parser.add_argument(
+ "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+ )
+ parser.add_argument(
+ "--port",
+ type=int,
+ help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+ )
+ parser.add_argument("--eval-name", type=str, default="mmlu")
+ parser.add_argument("--num-examples", type=int)
+ parser.add_argument("--num-threads", type=int, default=64)
+ set_ulimit()
+ args = parser.parse_args()
+
+ run_eval(args)
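
The docstring above gives the command-line usage; the entry point can also be driven programmatically, since run_eval only reads attributes off the parsed args. A hedged sketch, assuming an sglang server is already listening on port 30000 (host, port, and the model name are assumptions):

    from types import SimpleNamespace

    from sglang.test.run_eval import run_eval

    args = SimpleNamespace(
        base_url=None,
        host="127.0.0.1",
        port=30000,
        model="default",   # stand-in; the server's served model name also works
        eval_name="mmlu",
        num_examples=10,
        num_threads=64,
    )
    metrics = run_eval(args)
    print(metrics["score"])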