sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py
@@ -1,20 +1,13 @@
 """
-Benchmark the throughput of using the offline LLM engine.
-This script does not launch a server.
+Benchmark the throughput in the offline mode.
 It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
 
 # Usage
 ## Sharegpt dataset with default args
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
 
 ## Random dataset with default args
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random
-
-## Shared prefix dataset with default args
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix
-
-## Sharegpt dataset on runtime backend
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --backend runtime
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
 """
 
 import argparse
@@ -23,7 +16,7 @@ import json
 import logging
 import random
 import time
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 
@@ -55,7 +48,10 @@ class BenchArgs:
     gen_question_len: int = 128
     gen_output_len: int = 256
     disable_ignore_eos: bool = False
+    extra_request_body: Optional[str] = None
     seed: int = 1
+    skip_warmup: bool = False
+    do_not_exit: bool = False
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -142,7 +138,24 @@ class BenchArgs:
             default=BenchArgs.disable_ignore_eos,
             help="Disable ignore EOS token",
         )
+        parser.add_argument(
+            "--extra-request-body",
+            metavar='{"key1": "value1", "key2": "value2"}',
+            type=str,
+            help="Append given JSON object to the request payload. You can use this to specify"
+            "additional generate params like sampling params.",
+        )
         parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+        parser.add_argument(
+            "--skip-warmup",
+            action="store_true",
+            help="Skip the warmup batches.",
+        )
+        parser.add_argument(
+            "--do-not-exit",
+            action="store_true",
+            help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -155,6 +168,7 @@ def throughput_test_once(
     backend,
     reqs: List[Tuple[str, int, int]],
     ignore_eos: bool,
+    extra_request_body: Dict,
 ):
     measurement_results = {
         "backend": backend_name,
@@ -174,6 +188,7 @@
             "temperature": 0,
             "max_new_tokens": r[2],
             "ignore_eos": ignore_eos,
+            **extra_request_body,
         }
         for r in reqs
     ]
@@ -227,31 +242,41 @@ def throughput_test(
     random.seed(bench_args.seed)
     np.random.seed(bench_args.seed)
 
+    # Parse args
+    extra_request_body = {}
+    if bench_args.extra_request_body:
+        extra_request_body = json.loads(args.extra_request_body)
+
     # Read dataset
     input_requests = get_dataset(bench_args, tokenizer)
 
     warmup_requests = sample_random_requests(
-        input_len=20,
-        output_len=4,
-        num_prompts=2,
+        input_len=256,
+        output_len=16,
+        num_prompts=16,
         range_ratio=0.8,
         tokenizer=tokenizer,
        dataset_path=bench_args.dataset_path,
     )
 
     # Warm up
-    throughput_test_once(
-        backend_name=bench_args.backend,
-        backend=backend,
-        reqs=warmup_requests,
-        ignore_eos=not bench_args.disable_ignore_eos,
-    )
+    if not bench_args.skip_warmup:
+        logging.info("\nWarmup...")
+        throughput_test_once(
+            backend_name=bench_args.backend,
+            backend=backend,
+            reqs=warmup_requests,
+            ignore_eos=not bench_args.disable_ignore_eos,
+            extra_request_body=extra_request_body,
+        )
 
+    logging.info("\nBenchmark...")
     result = throughput_test_once(
         backend_name=bench_args.backend,
         backend=backend,
         reqs=input_requests,
         ignore_eos=not bench_args.disable_ignore_eos,
+        extra_request_body=extra_request_body,
    )
 
     if bench_args.result_filename:
@@ -307,3 +332,6 @@ if __name__ == "__main__":
     )
 
     throughput_test(server_args, bench_args)
+
+    while bench_args.do_not_exit:
+        pass
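
The hunks above add three benchmark options to bench_offline_throughput.py: --extra-request-body, --skip-warmup, and --do-not-exit. The sketch below is a minimal illustration (not part of the package) of how the --extra-request-body value is consumed: the JSON string is parsed once and then unpacked into every request's sampling parameters, mirroring throughput_test_once above. The flag value and the (prompt, input_len, output_len) tuple are made-up examples.

import json

# Hypothetical CLI value, e.g. --extra-request-body '{"top_p": 0.9, "frequency_penalty": 0.2}'
extra_request_body_arg = '{"top_p": 0.9, "frequency_penalty": 0.2}'

extra_request_body = {}
if extra_request_body_arg:
    extra_request_body = json.loads(extra_request_body_arg)

# Mirror of the payload construction in throughput_test_once:
# defaults first, then the user-supplied keys layered on top.
reqs = [("Hello, my name is", 5, 32)]  # made-up (prompt, input_len, output_len) tuples
sampling_params = [
    {
        "temperature": 0,
        "max_new_tokens": r[2],
        "ignore_eos": True,
        **extra_request_body,
    }
    for r in reqs
]
print(sampling_params)
# [{'temperature': 0, 'max_new_tokens': 32, 'ignore_eos': True, 'top_p': 0.9, 'frequency_penalty': 0.2}]
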
sglang/bench_one_batch.py (new file)
@@ -0,0 +1,472 @@
+"""
+Benchmark the latency of running a single static batch without a server.
+
+This script does not launch a server and uses the low-level APIs.
+It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
+
+# Usage (latency test)
+## with dummy weights:
+python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+## sweep through multiple data points and store (append) the results in a jsonl file:
+python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
+
+# Usage (correctness test):
+python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+## Reference output (of the correctness test above, can be gpu dependent):
+input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
+
+prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
+[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
+[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
+device='cuda:0')
+
+prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
+[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
+[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
+device='cuda:0')
+
+========== Prompt 0 ==========
+<s> The capital of France is Paris.
+The capital of the United States is Washington, D.C.
+
+
+========== Prompt 1 ==========
+<s> The capital of the United Kindom is London.
+The capital of the United Kingdom is London.
+The capital of the
+
+========== Prompt 2 ==========
+<s> Today is a sunny day and I like to go for a walk in the park.
+I'm going to the park
+"""
+
+import argparse
+import dataclasses
+import itertools
+import json
+import logging
+import multiprocessing
+import time
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.server import _set_envs_and_config
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import (
+    configure_logger,
+    kill_child_process,
+    suppress_other_loggers,
+)
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    run_name: str = "default"
+    batch_size: Tuple[int] = (1,)
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (16,)
+    result_filename: str = "result.jsonl"
+    correctness_test: bool = False
+    # This is only used for correctness test
+    cut_len: int = 4
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
+        parser.add_argument("--correctness-test", action="store_true")
+        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+def load_model(server_args, port_args, tp_rank):
+    suppress_other_loggers()
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    model_config = ModelConfig(
+        server_args.model_path,
+        trust_remote_code=server_args.trust_remote_code,
+        context_length=server_args.context_length,
+        model_override_args=server_args.json_model_override_args,
+    )
+    model_runner = ModelRunner(
+        model_config=model_config,
+        mem_fraction_static=server_args.mem_fraction_static,
+        gpu_id=tp_rank,
+        tp_rank=tp_rank,
+        tp_size=server_args.tp_size,
+        nccl_port=port_args.nccl_port,
+        server_args=server_args,
+    )
+    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    tokenizer = get_tokenizer(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+    if server_args.tp_size > 1:
+        dist.barrier()
+    return model_runner, tokenizer
+
+
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
+    prompts = [
+        "The capital of France is",
+        "The capital of the United Kindom is",
+        "Today is a sunny day and I like",
+    ]
+    input_ids = [tokenizer.encode(p) for p in prompts]
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(prompts)):
+        assert len(input_ids[i]) > bench_args.cut_len
+
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
+        req.prefix_indices = []
+        req.fill_ids = req.origin_input_ids
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        reqs.append(req)
+
+    return input_ids, reqs
+
+
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
+    for i in range(len(reqs)):
+        req = reqs[i]
+        req.fill_ids += input_ids[i][bench_args.cut_len :]
+        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+            i, : bench_args.cut_len
+        ]
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+    return reqs
+
+
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(input_ids)):
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
+        req.prefix_indices = []
+        req.fill_ids = req.origin_input_ids
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        reqs.append(req)
+
+    return reqs
+
+
+@torch.no_grad
+def extend(reqs, model_runner):
+    batch = ScheduleBatch.init_new(
+        reqs=reqs,
+        req_to_token_pool=model_runner.req_to_token_pool,
+        token_to_kv_pool=model_runner.token_to_kv_pool,
+        tree_cache=None,
+        model_config=model_runner.model_config,
+        enable_overlap=False,
+    )
+    batch.prepare_for_extend()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+    logits_output = model_runner.forward(forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    return next_token_ids, logits_output.next_token_logits, batch
+
+
+@torch.no_grad
+def decode(input_token_ids, batch, model_runner):
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+    logits_output = model_runner.forward(forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    return next_token_ids, logits_output.next_token_logits
+
+
+def correctness_test(
+    server_args,
+    port_args,
+    bench_args,
+    tp_rank,
+):
+    # Configure the logger
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
+
+    # Prepare inputs
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+    rank_print(f"\n{input_ids=}\n")
+
+    if bench_args.cut_len > 0:
+        # Prefill
+        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+        rank_print(f"prefill logits (first half): {next_token_logits} \n")
+
+    # Prepare extend inputs
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )
+
+    # Extend (prefill w/ KV cache)
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print(f"prefill logits (final): {next_token_logits} \n")
+
+    # Decode
+    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
+    for _ in range(bench_args.output_len[0] - 1):
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
+        for i in range(len(reqs)):
+            output_ids[i].append(next_token_ids_list[i])
+
+    # Print output texts
+    for i in range(len(reqs)):
+        rank_print(f"========== Prompt {i} ==========")
+        rank_print(tokenizer.decode(output_ids[i]), "\n")
+
+
+def synchronize(device):
+    torch.get_device_module(device).synchronize()
+
+
+def latency_test_run_once(
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
+):
+    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+    if batch_size > max_batch_size:
+        rank_print(
+            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+        )
+        return
+
+    # Clear the pools.
+    model_runner.req_to_token_pool.clear()
+    model_runner.token_to_kv_pool.clear()
+
+    measurement_results = {
+        "run_name": run_name,
+        "batch_size": batch_size,
+        "input_len": input_len,
+        "output_len": output_len,
+    }
+
+    tot_latency = 0
+
+    # Prefill
+    synchronize(device)
+    tic = time.time()
+    next_token_ids, _, batch = extend(reqs, model_runner)
+    synchronize(device)
+    prefill_latency = time.time() - tic
+    tot_latency += prefill_latency
+    throughput = input_len * batch_size / prefill_latency
+    rank_print(
+        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["prefill_latency"] = prefill_latency
+    measurement_results["prefill_throughput"] = throughput
+
+    # Decode
+    decode_latencies = []
+    for i in range(output_len - 1):
+        synchronize(device)
+        tic = time.time()
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        synchronize(device)
+        latency = time.time() - tic
+        tot_latency += latency
+        throughput = batch_size / latency
+        decode_latencies.append(latency)
+        if i < 5:
+            rank_print(
+                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+            )
+
+    # Record decode timing from 2nd output
+    if output_len > 1:
+        med_decode_latency = np.median(decode_latencies)
+        med_decode_throughput = batch_size / med_decode_latency
+        rank_print(
+            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
+        )
+        measurement_results["median_decode_latency"] = med_decode_latency
+        measurement_results["median_decode_throughput"] = med_decode_throughput
+
+    throughput = (input_len + output_len) * batch_size / tot_latency
+    rank_print(
+        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["total_latency"] = tot_latency
+    measurement_results["overall_throughput"] = throughput
+    return measurement_results
+
+
+def latency_test(
+    server_args,
+    port_args,
+    bench_args,
+    tp_rank,
+):
+    # Configure the logger
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
+
+    # Prepare inputs for warm up
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size[0], bench_args.input_len[0]
+    )
+
+    # Warm up
+    rank_print("Warmup ...")
+    latency_test_run_once(
+        bench_args.run_name,
+        model_runner,
+        rank_print,
+        reqs,
+        bench_args.batch_size[0],
+        bench_args.input_len[0],
+        8,  # shorter decoding to speed up the warmup
+        server_args.device,
+    )
+    rank_print("Benchmark ...")
+
+    # Run the sweep
+    result_list = []
+    for bs, il, ol in itertools.product(
+        bench_args.batch_size, bench_args.input_len, bench_args.output_len
+    ):
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        ret = latency_test_run_once(
+            bench_args.run_name,
+            model_runner,
+            rank_print,
+            reqs,
+            bs,
+            il,
+            ol,
+            server_args.device,
+        )
+        if ret is not None:
+            result_list.append(ret)
+
+    # Write results in jsonlines format on rank 0.
+    if tp_rank == 0 and bench_args.result_filename:
+        with open(bench_args.result_filename, "a") as fout:
+            for result in result_list:
+                fout.write(json.dumps(result) + "\n")
+
+
+def main(server_args, bench_args):
+    _set_envs_and_config(server_args)
+
+    if server_args.model_path:
+        if bench_args.correctness_test:
+            work_func = correctness_test
+        else:
+            work_func = latency_test
+    else:
+        raise ValueError(
+            "Provide --model-path for running the tests or "
+            "provide --result-filename for plotting the results"
+        )
+
+    port_args = PortArgs.init_new(server_args)
+
+    if server_args.tp_size == 1:
+        work_func(server_args, port_args, bench_args, 0)
+    else:
+        workers = []
+        for tp_rank in range(server_args.tp_size):
+            proc = multiprocessing.Process(
+                target=work_func,
+                args=(
+                    server_args,
+                    port_args,
+                    bench_args,
+                    tp_rank,
+                ),
+            )
+            proc.start()
+            workers.append(proc)
+
+        for proc in workers:
+            proc.join()
+
+        proc.terminate()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        main(server_args, bench_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process()
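
bench_one_batch.py appends one JSON object per (batch_size, input_len, output_len) point to the file given by --result-filename (default result.jsonl), using the measurement_results keys shown above. The snippet below is a small post-processing sketch, not shipped with the package; the file name and the fact that the median-decode keys are written only when output_len > 1 are taken from the code above.

import json

# Load the jsonlines records appended by sglang.bench_one_batch.
with open("result.jsonl") as fin:
    results = [json.loads(line) for line in fin]

for r in results:
    # median_decode_* keys exist only for runs with output_len > 1.
    med = r.get("median_decode_latency")
    print(
        f"{r['run_name']}: bs={r['batch_size']} in={r['input_len']} out={r['output_len']} "
        f"prefill={r['prefill_latency']:.4f}s "
        + (f"median_decode={med:.5f}s " if med is not None else "")
        + f"overall={r['overall_throughput']:.2f} token/s"
    )
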
sglang/{bench_server_latency.py → bench_one_batch_server.py}
@@ -1,10 +1,10 @@
 """
-Benchmark the latency of serving a single batch with a real server.
+Benchmark the latency of running a single batch with a server.
+
 This script launches a server and uses the HTTP interface.
-It accepts arguments similar to those of launch_server.py.
+It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
 
 Usage:
-
 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
 
 python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8