sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +1 -0
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +41 -5
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
  41. sglang/srt/layers/parameter.py +2 -1
  42. sglang/srt/layers/quantization/__init__.py +20 -23
  43. sglang/srt/layers/quantization/fp8.py +6 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  45. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  46. sglang/srt/layers/radix_attention.py +2 -2
  47. sglang/srt/layers/rotary_embedding.py +1179 -31
  48. sglang/srt/layers/sampler.py +39 -1
  49. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  50. sglang/srt/lora/lora.py +1 -9
  51. sglang/srt/managers/configure_logging.py +3 -0
  52. sglang/srt/managers/data_parallel_controller.py +79 -72
  53. sglang/srt/managers/detokenizer_manager.py +23 -6
  54. sglang/srt/managers/image_processor.py +158 -2
  55. sglang/srt/managers/io_struct.py +25 -2
  56. sglang/srt/managers/schedule_batch.py +49 -22
  57. sglang/srt/managers/schedule_policy.py +26 -12
  58. sglang/srt/managers/scheduler.py +277 -178
  59. sglang/srt/managers/session_controller.py +1 -0
  60. sglang/srt/managers/tokenizer_manager.py +206 -121
  61. sglang/srt/managers/tp_worker.py +6 -4
  62. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  63. sglang/srt/managers/utils.py +44 -0
  64. sglang/srt/mem_cache/memory_pool.py +10 -32
  65. sglang/srt/metrics/collector.py +15 -6
  66. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  67. sglang/srt/model_executor/model_runner.py +37 -15
  68. sglang/srt/model_loader/loader.py +8 -6
  69. sglang/srt/model_loader/weight_utils.py +55 -2
  70. sglang/srt/models/baichuan.py +6 -6
  71. sglang/srt/models/chatglm.py +2 -2
  72. sglang/srt/models/commandr.py +3 -3
  73. sglang/srt/models/dbrx.py +4 -4
  74. sglang/srt/models/deepseek.py +3 -3
  75. sglang/srt/models/deepseek_v2.py +8 -8
  76. sglang/srt/models/exaone.py +2 -2
  77. sglang/srt/models/gemma.py +2 -2
  78. sglang/srt/models/gemma2.py +6 -24
  79. sglang/srt/models/gpt2.py +3 -5
  80. sglang/srt/models/gpt_bigcode.py +1 -1
  81. sglang/srt/models/granite.py +2 -2
  82. sglang/srt/models/grok.py +3 -3
  83. sglang/srt/models/internlm2.py +2 -2
  84. sglang/srt/models/llama.py +7 -5
  85. sglang/srt/models/minicpm.py +2 -2
  86. sglang/srt/models/minicpm3.py +6 -6
  87. sglang/srt/models/minicpmv.py +1238 -0
  88. sglang/srt/models/mixtral.py +3 -3
  89. sglang/srt/models/mixtral_quant.py +3 -3
  90. sglang/srt/models/mllama.py +2 -2
  91. sglang/srt/models/olmo.py +3 -3
  92. sglang/srt/models/olmo2.py +4 -4
  93. sglang/srt/models/olmoe.py +7 -13
  94. sglang/srt/models/phi3_small.py +2 -2
  95. sglang/srt/models/qwen.py +2 -2
  96. sglang/srt/models/qwen2.py +41 -4
  97. sglang/srt/models/qwen2_moe.py +3 -3
  98. sglang/srt/models/qwen2_vl.py +22 -122
  99. sglang/srt/models/stablelm.py +2 -2
  100. sglang/srt/models/torch_native_llama.py +3 -3
  101. sglang/srt/models/xverse.py +6 -6
  102. sglang/srt/models/xverse_moe.py +6 -6
  103. sglang/srt/openai_api/protocol.py +2 -0
  104. sglang/srt/sampling/custom_logit_processor.py +38 -0
  105. sglang/srt/sampling/sampling_batch_info.py +139 -4
  106. sglang/srt/sampling/sampling_params.py +3 -1
  107. sglang/srt/server.py +4 -1090
  108. sglang/srt/server_args.py +57 -14
  109. sglang/srt/utils.py +103 -65
  110. sglang/test/runners.py +8 -13
  111. sglang/test/test_programs.py +1 -1
  112. sglang/test/test_utils.py +3 -1
  113. sglang/utils.py +12 -2
  114. sglang/version.py +1 -1
  115. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
  116. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
  117. sglang/launch_server_llavavid.py +0 -25
  118. sglang/srt/constrained/__init__.py +0 -16
  119. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  120. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  121. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
- # SGL API Components
1
+ # SGLang public APIs
2
2
 
3
+ # Frontend Language APIs
3
4
  from sglang.api import (
4
5
  Engine,
5
6
  Runtime,
@@ -23,16 +24,26 @@ from sglang.api import (
23
24
  user_end,
24
25
  video,
25
26
  )
27
+ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
26
28
  from sglang.lang.choices import (
27
29
  greedy_token_selection,
28
30
  token_length_normalized,
29
31
  unconditional_likelihood_normalized,
30
32
  )
33
+ from sglang.utils import LazyImport
34
+
35
+ Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
36
+ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
37
+ OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
38
+ VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
39
+
40
+ # Other configs
41
+ from sglang.global_config import global_config
42
+ from sglang.version import __version__
31
43
 
32
- # SGLang DSL APIs
33
44
  __all__ = [
34
- "Runtime",
35
45
  "Engine",
46
+ "Runtime",
36
47
  "assistant",
37
48
  "assistant_begin",
38
49
  "assistant_end",
@@ -52,27 +63,14 @@ __all__ = [
52
63
  "user_begin",
53
64
  "user_end",
54
65
  "video",
66
+ "RuntimeEndpoint",
55
67
  "greedy_token_selection",
56
68
  "token_length_normalized",
57
69
  "unconditional_likelihood_normalized",
70
+ "Anthropic",
71
+ "LiteLLM",
72
+ "OpenAI",
73
+ "VertexAI",
74
+ "global_config",
75
+ "__version__",
58
76
  ]
59
-
60
- # Global Configurations
61
- from sglang.global_config import global_config
62
-
63
- __all__ += ["global_config"]
64
-
65
- from sglang.version import __version__
66
-
67
- __all__ += ["__version__"]
68
-
69
- # SGLang Backends
70
- from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
71
- from sglang.utils import LazyImport
72
-
73
- Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
74
- LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
75
- OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
76
- VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
77
-
78
- __all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
sglang/api.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """Public APIs of the language."""
2
2
 
3
- import os
4
3
  import re
5
4
  from typing import Callable, List, Optional, Union
6
5
 
@@ -33,19 +32,15 @@ def function(
33
32
 
34
33
 
35
34
  def Runtime(*args, **kwargs):
36
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
37
-
38
35
  # Avoid importing unnecessary dependency
39
- from sglang.srt.server import Runtime
36
+ from sglang.lang.backend.runtime_endpoint import Runtime
40
37
 
41
38
  return Runtime(*args, **kwargs)
42
39
 
43
40
 
44
41
  def Engine(*args, **kwargs):
45
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
46
-
47
42
  # Avoid importing unnecessary dependency
48
- from sglang.srt.server import Engine
43
+ from sglang.srt.entrypoints.engine import Engine
49
44
 
50
45
  return Engine(*args, **kwargs)
51
46
 
sglang/bench_offline_throughput.py CHANGED
@@ -27,7 +27,8 @@ from sglang.bench_serving import (
27
27
  sample_random_requests,
28
28
  set_ulimit,
29
29
  )
30
- from sglang.srt.server import Engine, Runtime
30
+ from sglang.lang.backend.runtime_endpoint import Runtime
31
+ from sglang.srt.entrypoints.engine import Engine
31
32
  from sglang.srt.server_args import ServerArgs
32
33
 
33
34
 
@@ -39,14 +40,15 @@ class BenchArgs:
39
40
  dataset_path: str = ""
40
41
  num_prompts: int = 1000
41
42
  sharegpt_output_len: Optional[int] = None
43
+ sharegpt_context_len: Optional[int] = None
42
44
  random_input_len: int = 1024
43
45
  random_output_len: int = 1024
44
46
  random_range_ratio: float = 0.0
45
- gen_num_groups: int = 64
46
- gen_prompts_per_group: int = 16
47
- gen_system_prompt_len: int = 2048
48
- gen_question_len: int = 128
49
- gen_output_len: int = 256
47
+ gsp_num_groups: int = 64
48
+ gsp_prompts_per_group: int = 16
49
+ gsp_system_prompt_len: int = 2048
50
+ gsp_question_len: int = 128
51
+ gsp_output_len: int = 256
50
52
  disable_ignore_eos: bool = False
51
53
  extra_request_body: Optional[str] = None
52
54
  seed: int = 1
@@ -82,6 +84,12 @@ class BenchArgs:
82
84
  default=BenchArgs.sharegpt_output_len,
83
85
  help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
84
86
  )
87
+ parser.add_argument(
88
+ "--sharegpt-context-len",
89
+ type=int,
90
+ default=BenchArgs.sharegpt_context_len,
91
+ help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
92
+ )
85
93
  parser.add_argument(
86
94
  "--random-input-len",
87
95
  type=int,
@@ -102,35 +110,35 @@ class BenchArgs:
102
110
  "used only for random dataset.",
103
111
  )
104
112
  parser.add_argument(
105
- "--gen-num-groups",
113
+ "--gsp-num-groups",
106
114
  type=int,
107
- default=BenchArgs.gen_num_groups,
115
+ default=BenchArgs.gsp_num_groups,
108
116
  help="Number of groups with shared prefix, used"
109
117
  "only for generate-shared-prefix",
110
118
  )
111
119
  parser.add_argument(
112
- "--gen-prompts-per-group",
120
+ "--gsp-prompts-per-group",
113
121
  type=int,
114
- default=BenchArgs.gen_prompts_per_group,
122
+ default=BenchArgs.gsp_prompts_per_group,
115
123
  help="Number of prompts per group of shared prefix, used"
116
124
  "only for generate-shared-prefix",
117
125
  )
118
126
  parser.add_argument(
119
- "--gen-system-prompt-len",
127
+ "--gsp-system-prompt-len",
120
128
  type=int,
121
- default=BenchArgs.gen_system_prompt_len,
129
+ default=BenchArgs.gsp_system_prompt_len,
122
130
  help="System prompt length, used" "only for generate-shared-prefix",
123
131
  )
124
132
  parser.add_argument(
125
- "--gen-question-len",
133
+ "--gsp-question-len",
126
134
  type=int,
127
- default=BenchArgs.gen_question_len,
135
+ default=BenchArgs.gsp_question_len,
128
136
  help="Question length, used" "only for generate-shared-prefix",
129
137
  )
130
138
  parser.add_argument(
131
- "--gen-output-len",
139
+ "--gsp-output-len",
132
140
  type=int,
133
- default=BenchArgs.gen_output_len,
141
+ default=BenchArgs.gsp_output_len,
134
142
  help="Target length in tokens for outputs in generated-shared-prefix dataset",
135
143
  )
136
144
  parser.add_argument(
sglang/bench_one_batch.py CHANGED
@@ -9,7 +9,8 @@ It accepts server arguments (the same as launch_server.py) and benchmark argumen
9
9
  python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
10
10
  ## sweep through multiple data points and store (append) the results in a jsonl file:
11
11
  python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
12
-
12
+ ## run with profiling:
13
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
13
14
  # Usage (correctness test):
14
15
  python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
15
16
 
@@ -56,12 +57,12 @@ import torch
56
57
  import torch.distributed as dist
57
58
 
58
59
  from sglang.srt.configs.model_config import ModelConfig
60
+ from sglang.srt.entrypoints.engine import _set_envs_and_config
59
61
  from sglang.srt.hf_transformers_utils import get_tokenizer
60
62
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
61
63
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
62
64
  from sglang.srt.model_executor.model_runner import ModelRunner
63
65
  from sglang.srt.sampling.sampling_params import SamplingParams
64
- from sglang.srt.server import _set_envs_and_config
65
66
  from sglang.srt.server_args import PortArgs, ServerArgs
66
67
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
67
68
  from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
@@ -77,6 +78,8 @@ class BenchArgs:
77
78
  correctness_test: bool = False
78
79
  # This is only used for correctness test
79
80
  cut_len: int = 4
81
+ profile: bool = False
82
+ profile_filename_prefix: str = "profile"
80
83
 
81
84
  @staticmethod
82
85
  def add_cli_args(parser: argparse.ArgumentParser):
@@ -95,6 +98,19 @@ class BenchArgs:
95
98
  )
96
99
  parser.add_argument("--correctness-test", action="store_true")
97
100
  parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
101
+ parser.add_argument(
102
+ "--profile",
103
+ action="store_true",
104
+ help="Use Torch Profiler. The endpoint must be launched with "
105
+ "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
106
+ )
107
+ parser.add_argument(
108
+ "--profile-filename-prefix",
109
+ type=str,
110
+ default=BenchArgs.profile_filename_prefix,
111
+ help="Prefix of the profiling file names. The full profiling result file(s) be "
112
+ '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
113
+ )
98
114
 
99
115
  @classmethod
100
116
  def from_cli_args(cls, args: argparse.Namespace):
@@ -216,6 +232,7 @@ def extend(reqs, model_runner):
216
232
  model_config=model_runner.model_config,
217
233
  enable_overlap=False,
218
234
  spec_algorithm=SpeculativeAlgorithm.NONE,
235
+ enable_custom_logit_processor=False,
219
236
  )
220
237
  batch.prepare_for_extend()
221
238
  model_worker_batch = batch.get_model_worker_batch()
@@ -286,7 +303,16 @@ def synchronize(device):
286
303
 
287
304
 
288
305
  def latency_test_run_once(
289
- run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
306
+ run_name,
307
+ model_runner,
308
+ rank_print,
309
+ reqs,
310
+ batch_size,
311
+ input_len,
312
+ output_len,
313
+ device,
314
+ profile,
315
+ profile_filename_prefix,
290
316
  ):
291
317
  max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
292
318
  if batch_size > max_batch_size:
@@ -308,6 +334,17 @@ def latency_test_run_once(
308
334
 
309
335
  tot_latency = 0
310
336
 
337
+ profiler = None
338
+ if profile:
339
+ profiler = torch.profiler.profile(
340
+ activities=[
341
+ torch.profiler.ProfilerActivity.CPU,
342
+ torch.profiler.ProfilerActivity.CUDA,
343
+ ],
344
+ with_stack=True,
345
+ )
346
+ profiler.start()
347
+
311
348
  # Prefill
312
349
  synchronize(device)
313
350
  tic = time.time()
@@ -338,6 +375,13 @@ def latency_test_run_once(
338
375
  f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
339
376
  )
340
377
 
378
+ if profile:
379
+ profiler.stop()
380
+ profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
381
+ parent_dir = os.path.dirname(os.path.abspath(profile_filename))
382
+ os.makedirs(parent_dir, exist_ok=True)
383
+ profiler.export_chrome_trace(profile_filename)
384
+
341
385
  # Record decode timing from 2nd output
342
386
  if output_len > 1:
343
387
  med_decode_latency = np.median(decode_latencies)
@@ -386,6 +430,8 @@ def latency_test(
386
430
  bench_args.input_len[0],
387
431
  8, # shorter decoding to speed up the warmup
388
432
  server_args.device,
433
+ profile=False,
434
+ profile_filename_prefix="", # not used
389
435
  )
390
436
 
391
437
  rank_print("Benchmark ...")
@@ -405,6 +451,8 @@ def latency_test(
405
451
  il,
406
452
  ol,
407
453
  server_args.device,
454
+ bench_args.profile,
455
+ bench_args.profile_filename_prefix,
408
456
  )
409
457
  if ret is not None:
410
458
  result_list.append(ret)
sglang/bench_one_batch_server.py CHANGED
@@ -22,7 +22,7 @@ from typing import Tuple
22
22
  import numpy as np
23
23
  import requests
24
24
 
25
- from sglang.srt.server import launch_server
25
+ from sglang.srt.entrypoints.http_server import launch_server
26
26
  from sglang.srt.server_args import ServerArgs
27
27
  from sglang.srt.utils import kill_process_tree
28
28
 
sglang/bench_serving.py CHANGED
@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
452
452
  num_requests=args.num_prompts,
453
453
  tokenizer=tokenizer,
454
454
  fixed_output_len=args.sharegpt_output_len,
455
+ context_len=args.sharegpt_context_len,
455
456
  )
456
457
  elif args.dataset_name == "random":
457
458
  input_requests = sample_random_requests(
@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
464
465
  )
465
466
  elif args.dataset_name == "generated-shared-prefix":
466
467
  input_requests = sample_generated_shared_prefix_requests(
467
- num_groups=args.gen_num_groups,
468
- prompts_per_group=args.gen_prompts_per_group,
469
- system_prompt_len=args.gen_system_prompt_len,
470
- question_len=args.gen_question_len,
471
- output_len=args.gen_output_len,
468
+ num_groups=args.gsp_num_groups,
469
+ prompts_per_group=args.gsp_prompts_per_group,
470
+ system_prompt_len=args.gsp_system_prompt_len,
471
+ question_len=args.gsp_question_len,
472
+ output_len=args.gsp_output_len,
472
473
  tokenizer=tokenizer,
473
474
  )
474
475
  else:
@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
560
561
  num_requests: int,
561
562
  tokenizer: PreTrainedTokenizerBase,
562
563
  fixed_output_len: Optional[int] = None,
564
+ context_len: Optional[int] = None,
563
565
  ) -> List[Tuple[str, int, int]]:
564
566
  if fixed_output_len is not None and fixed_output_len < 4:
565
567
  raise ValueError("output_len too small")
@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
597
599
  output_len = (
598
600
  len(completion_token_ids) if fixed_output_len is None else fixed_output_len
599
601
  )
600
- if prompt_len < 4 or output_len < 4:
602
+
603
+ if prompt_len < 1 or output_len < 1:
601
604
  # Prune too short sequences.
602
605
  continue
603
- if prompt_len > 1024 or (
604
- prompt_len + output_len > 2048 and fixed_output_len is None
605
- ):
606
+
607
+ if context_len and prompt_len + output_len > context_len:
606
608
  # Prune too long sequences.
607
609
  continue
610
+
608
611
  filtered_dataset.append((prompt, prompt_len, output_len))
609
612
 
610
613
  print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):
706
709
 
707
710
  # Create a unique cache filename based on the generation parameters
708
711
  cache_key = (
709
- f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
710
- f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
712
+ f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
713
+ f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
711
714
  f"{tokenizer.__class__.__name__}.pkl"
712
715
  )
713
716
  return cache_dir / cache_key
@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
1374
1377
  default=None,
1375
1378
  help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
1376
1379
  )
1380
+ parser.add_argument(
1381
+ "--sharegpt-context-len",
1382
+ type=int,
1383
+ default=None,
1384
+ help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
1385
+ )
1377
1386
  parser.add_argument(
1378
1387
  "--random-input-len",
1379
1388
  type=int,
@@ -1453,49 +1462,49 @@ if __name__ == "__main__":
1453
1462
  help="Append given JSON object to the request payload. You can use this to specify"
1454
1463
  "additional generate params like sampling params.",
1455
1464
  )
1465
+ parser.add_argument(
1466
+ "--profile",
1467
+ action="store_true",
1468
+ help="Use Torch Profiler. The endpoint must be launched with "
1469
+ "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
1470
+ )
1471
+ parser.add_argument(
1472
+ "--lora-name",
1473
+ type=str,
1474
+ default=None,
1475
+ help="The name of LoRA adapter",
1476
+ )
1456
1477
 
1457
1478
  group = parser.add_argument_group("generated-shared-prefix dataset arguments")
1458
1479
  group.add_argument(
1459
- "--gen-num-groups",
1480
+ "--gsp-num-groups",
1460
1481
  type=int,
1461
1482
  default=64,
1462
1483
  help="Number of system prompt groups for generated-shared-prefix dataset",
1463
1484
  )
1464
1485
  group.add_argument(
1465
- "--gen-prompts-per-group",
1486
+ "--gsp-prompts-per-group",
1466
1487
  type=int,
1467
1488
  default=16,
1468
1489
  help="Number of prompts per system prompt group for generated-shared-prefix dataset",
1469
1490
  )
1470
1491
  group.add_argument(
1471
- "--gen-system-prompt-len",
1492
+ "--gsp-system-prompt-len",
1472
1493
  type=int,
1473
1494
  default=2048,
1474
1495
  help="Target length in tokens for system prompts in generated-shared-prefix dataset",
1475
1496
  )
1476
1497
  group.add_argument(
1477
- "--gen-question-len",
1498
+ "--gsp-question-len",
1478
1499
  type=int,
1479
1500
  default=128,
1480
1501
  help="Target length in tokens for questions in generated-shared-prefix dataset",
1481
1502
  )
1482
1503
  group.add_argument(
1483
- "--gen-output-len",
1504
+ "--gsp-output-len",
1484
1505
  type=int,
1485
1506
  default=256,
1486
1507
  help="Target length in tokens for outputs in generated-shared-prefix dataset",
1487
1508
  )
1488
- parser.add_argument(
1489
- "--profile",
1490
- action="store_true",
1491
- help="Use Torch Profiler. The endpoint must be launched with "
1492
- "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
1493
- )
1494
- parser.add_argument(
1495
- "--lora-name",
1496
- type=str,
1497
- default=None,
1498
- help="The name of LoRA adapter",
1499
- )
1500
1509
  args = parser.parse_args()
1501
1510
  run_benchmark(args)
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -1,6 +1,11 @@
1
+ import atexit
1
2
  import json
3
+ import multiprocessing
2
4
  import warnings
3
- from typing import List, Optional
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ import aiohttp
8
+ import requests
4
9
 
5
10
  from sglang.global_config import global_config
6
11
  from sglang.lang.backend.base_backend import BaseBackend
@@ -251,11 +256,12 @@ class RuntimeEndpoint(BaseBackend):
251
256
  }
252
257
  obj = self._generate_http_request(s, data)
253
258
 
254
- normalized_prompt_logprobs = [
255
- r["meta_info"]["normalized_prompt_logprob"] for r in obj
256
- ]
257
259
  input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
258
260
  output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
261
+ normalized_prompt_logprobs = [
262
+ compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
263
+ for r in obj
264
+ ]
259
265
 
260
266
  # Remove extra token if no token healing occurred
261
267
  for i in range(len(input_token_logprobs)):
@@ -319,3 +325,176 @@ class RuntimeEndpoint(BaseBackend):
319
325
  def _assert_success(self, res):
320
326
  if res.status_code != 200:
321
327
  raise RuntimeError(res.json())
328
+
329
+
330
+ def compute_normalized_prompt_logprobs(input_logprobs):
331
+ values = [x[0] for x in input_logprobs if x[0]]
332
+ return sum(values) / len(values)
333
+
334
+
335
+ class Runtime:
336
+ """
337
+ A wrapper for the HTTP server.
338
+ This is used for launching the server in a python program without
339
+ using the commond line interface.
340
+
341
+ It is mainly used for the frontend language.
342
+ You should use the Engine class if you want to do normal offline processing without the frontend language.
343
+ """
344
+
345
+ def __init__(
346
+ self,
347
+ log_level: str = "error",
348
+ *args,
349
+ **kwargs,
350
+ ):
351
+ """See the arguments in server_args.py::ServerArgs"""
352
+ # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
353
+ # client code without installing SRT server and its dependency if they want.
354
+ from sglang.srt.entrypoints.http_server import launch_server
355
+ from sglang.srt.server_args import ServerArgs
356
+ from sglang.srt.utils import is_port_available
357
+
358
+ self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
359
+
360
+ # Pre-allocate ports
361
+ for port in range(self.server_args.port, 40000):
362
+ if is_port_available(port):
363
+ break
364
+ self.server_args.port = port
365
+
366
+ self.url = self.server_args.url()
367
+ self.generate_url = self.url + "/generate"
368
+
369
+ # NOTE: We store pid instead of proc to fix some issues during __delete__
370
+ self.pid = None
371
+ pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
372
+
373
+ proc = multiprocessing.Process(
374
+ target=launch_server,
375
+ args=(self.server_args, pipe_writer),
376
+ )
377
+ proc.start()
378
+ pipe_writer.close()
379
+ self.pid = proc.pid
380
+
381
+ # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
382
+ atexit.register(self.shutdown)
383
+
384
+ # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
385
+ try:
386
+ init_state = pipe_reader.recv()
387
+ except EOFError:
388
+ init_state = ""
389
+
390
+ if init_state != "ready":
391
+ self.shutdown()
392
+ raise RuntimeError(
393
+ "Initialization failed. Please see the error messages above."
394
+ )
395
+
396
+ self.endpoint = RuntimeEndpoint(self.url)
397
+
398
+ def shutdown(self):
399
+ from sglang.srt.utils import kill_process_tree
400
+
401
+ if self.pid is not None:
402
+ kill_process_tree(self.pid)
403
+ self.pid = None
404
+
405
+ def cache_prefix(self, prefix: str):
406
+ self.endpoint.cache_prefix(prefix)
407
+
408
+ def get_tokenizer(self):
409
+ from sglang.srt.hf_transformers_utils import get_tokenizer
410
+
411
+ return get_tokenizer(
412
+ self.server_args.tokenizer_path,
413
+ tokenizer_mode=self.server_args.tokenizer_mode,
414
+ trust_remote_code=self.server_args.trust_remote_code,
415
+ revision=self.server_args.revision,
416
+ )
417
+
418
+ async def async_generate(
419
+ self,
420
+ prompt: str,
421
+ sampling_params: Optional[Dict] = None,
422
+ ):
423
+ if self.server_args.skip_tokenizer_init:
424
+ json_data = {
425
+ "input_ids": prompt,
426
+ "sampling_params": sampling_params,
427
+ "stream": True,
428
+ }
429
+ else:
430
+ json_data = {
431
+ "text": prompt,
432
+ "sampling_params": sampling_params,
433
+ "stream": True,
434
+ }
435
+ pos = 0
436
+
437
+ timeout = aiohttp.ClientTimeout(total=3 * 3600)
438
+ async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
439
+ async with session.post(self.generate_url, json=json_data) as response:
440
+ async for chunk, _ in response.content.iter_chunks():
441
+ chunk = chunk.decode("utf-8")
442
+ if chunk and chunk.startswith("data:"):
443
+ if chunk == "data: [DONE]\n\n":
444
+ break
445
+ data = json.loads(chunk[5:].strip("\n"))
446
+ if "text" in data:
447
+ cur = data["text"][pos:]
448
+ if cur:
449
+ yield cur
450
+ pos += len(cur)
451
+ else:
452
+ yield data
453
+
454
+ add_request = async_generate
455
+
456
+ def generate(
457
+ self,
458
+ prompt: Union[str, List[str]],
459
+ sampling_params: Optional[Dict] = None,
460
+ return_logprob: Optional[Union[List[bool], bool]] = False,
461
+ logprob_start_len: Optional[Union[List[int], int]] = None,
462
+ top_logprobs_num: Optional[Union[List[int], int]] = None,
463
+ lora_path: Optional[List[Optional[str]]] = None,
464
+ ):
465
+ json_data = {
466
+ "text": prompt,
467
+ "sampling_params": sampling_params,
468
+ "return_logprob": return_logprob,
469
+ "logprob_start_len": logprob_start_len,
470
+ "top_logprobs_num": top_logprobs_num,
471
+ "lora_path": lora_path,
472
+ }
473
+ assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
474
+ response = requests.post(
475
+ self.url + "/generate",
476
+ json=json_data,
477
+ )
478
+ return json.dumps(response.json())
479
+
480
+ def encode(
481
+ self,
482
+ prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
483
+ ):
484
+ json_data = {"text": prompt}
485
+ response = requests.post(self.url + "/encode", json=json_data)
486
+ return json.dumps(response.json())
487
+
488
+ async def get_server_info(self):
489
+ async with aiohttp.ClientSession() as session:
490
+ async with session.get(f"{self.url}/get_server_info") as response:
491
+ if response.status == 200:
492
+ return await response.json()
493
+ else:
494
+ error_data = await response.json()
495
+ raise RuntimeError(
496
+ f"Failed to get server info. {error_data['error']['message']}"
497
+ )
498
+
499
+ def __del__(self):
500
+ self.shutdown()