sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +337 -0
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +115 -31
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/base_grammar_backend.py +4 -3
  8. sglang/srt/constrained/outlines_backend.py +39 -26
  9. sglang/srt/constrained/xgrammar_backend.py +58 -14
  10. sglang/srt/layers/activation.py +3 -0
  11. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  12. sglang/srt/layers/attention/triton_backend.py +9 -7
  13. sglang/srt/layers/custom_op_util.py +26 -0
  14. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  15. sglang/srt/layers/fused_moe/patch.py +4 -2
  16. sglang/srt/layers/layernorm.py +4 -0
  17. sglang/srt/layers/logits_processor.py +10 -10
  18. sglang/srt/layers/sampler.py +4 -8
  19. sglang/srt/layers/torchao_utils.py +2 -0
  20. sglang/srt/managers/data_parallel_controller.py +74 -9
  21. sglang/srt/managers/detokenizer_manager.py +1 -14
  22. sglang/srt/managers/io_struct.py +27 -0
  23. sglang/srt/managers/schedule_batch.py +104 -38
  24. sglang/srt/managers/schedule_policy.py +5 -1
  25. sglang/srt/managers/scheduler.py +210 -56
  26. sglang/srt/managers/session_controller.py +62 -0
  27. sglang/srt/managers/tokenizer_manager.py +38 -0
  28. sglang/srt/managers/tp_worker.py +12 -1
  29. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  30. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  31. sglang/srt/model_executor/forward_batch_info.py +109 -15
  32. sglang/srt/model_executor/model_runner.py +102 -43
  33. sglang/srt/model_parallel.py +98 -0
  34. sglang/srt/models/deepseek_v2.py +147 -44
  35. sglang/srt/models/gemma2.py +9 -8
  36. sglang/srt/models/llava.py +1 -1
  37. sglang/srt/models/llavavid.py +1 -1
  38. sglang/srt/models/olmo.py +3 -3
  39. sglang/srt/models/phi3_small.py +447 -0
  40. sglang/srt/models/qwen2_vl.py +13 -6
  41. sglang/srt/models/torch_native_llama.py +94 -78
  42. sglang/srt/openai_api/adapter.py +11 -4
  43. sglang/srt/openai_api/protocol.py +30 -27
  44. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  45. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  47. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  48. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  49. sglang/srt/sampling/sampling_batch_info.py +58 -57
  50. sglang/srt/sampling/sampling_params.py +3 -3
  51. sglang/srt/server.py +29 -2
  52. sglang/srt/server_args.py +97 -60
  53. sglang/srt/utils.py +103 -51
  54. sglang/test/runners.py +25 -6
  55. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  56. sglang/test/test_utils.py +33 -22
  57. sglang/version.py +1 -1
  58. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  59. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
  60. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  61. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  62. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,337 @@
+ """
+ Benchmark the throughput in the offline mode.
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
+
+ # Usage
+ ## Sharegpt dataset with default args
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
+
+ ## Random dataset with default args
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
+ """
+
+ import argparse
+ import dataclasses
+ import json
+ import logging
+ import random
+ import time
+ from typing import Dict, List, Optional, Tuple
+
+ import numpy as np
+
+ from sglang.api import Engine
+ from sglang.bench_serving import (
+     get_dataset,
+     get_tokenizer,
+     sample_random_requests,
+     set_ulimit,
+ )
+ from sglang.srt.server import Runtime
+ from sglang.srt.server_args import ServerArgs
+
+
+ @dataclasses.dataclass
+ class BenchArgs:
+     backend: str = "engine"
+     result_filename: str = ""
+     dataset_name: str = "sharegpt"
+     dataset_path: str = ""
+     num_prompts: int = 1000
+     sharegpt_output_len: Optional[int] = None
+     random_input_len: int = 1024
+     random_output_len: int = 1024
+     random_range_ratio: float = 0.0
+     gen_num_groups: int = 64
+     gen_prompts_per_group: int = 16
+     gen_system_prompt_len: int = 2048
+     gen_question_len: int = 128
+     gen_output_len: int = 256
+     disable_ignore_eos: bool = False
+     extra_request_body: Optional[str] = None
+     seed: int = 1
+     skip_warmup: bool = False
+     do_not_exit: bool = False
+
+     @staticmethod
+     def add_cli_args(parser: argparse.ArgumentParser):
+         parser.add_argument("--backend", type=str, default=BenchArgs.backend)
+         parser.add_argument(
+             "--result-filename", type=str, default=BenchArgs.result_filename
+         )
+         parser.add_argument(
+             "--dataset-name",
+             type=str,
+             default="sharegpt",
+             choices=["sharegpt", "random", "generated-shared-prefix"],
+             help="Name of the dataset to benchmark on.",
+         )
+         parser.add_argument(
+             "--dataset-path", type=str, default="", help="Path to the dataset."
+         )
+         parser.add_argument(
+             "--num-prompts",
+             type=int,
+             default=BenchArgs.num_prompts,
+             help="Number of prompts to process. Default is 1000.",
+         )
+         parser.add_argument(
+             "--sharegpt-output-len",
+             type=int,
+             default=BenchArgs.sharegpt_output_len,
+             help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
+         )
+         parser.add_argument(
+             "--random-input-len",
+             type=int,
+             default=BenchArgs.random_input_len,
+             help="Number of input tokens per request, used only for random dataset.",
+         )
+         parser.add_argument(
+             "--random-output-len",
+             type=int,
+             default=BenchArgs.random_output_len,
+             help="Number of output tokens per request, used only for random dataset.",
+         )
+         parser.add_argument(
+             "--random-range-ratio",
+             type=float,
+             default=BenchArgs.random_range_ratio,
+             help="Range of sampled ratio of input/output length, "
+             "used only for random dataset.",
+         )
+         parser.add_argument(
+             "--gen-num-groups",
+             type=int,
+             default=BenchArgs.gen_num_groups,
+             help="Number of groups with shared prefix, used "
+             "only for the generated-shared-prefix dataset.",
+         )
+         parser.add_argument(
+             "--gen-prompts-per-group",
+             type=int,
+             default=BenchArgs.gen_prompts_per_group,
+             help="Number of prompts per group of shared prefix, used "
+             "only for the generated-shared-prefix dataset.",
+         )
+         parser.add_argument(
+             "--gen-system-prompt-len",
+             type=int,
+             default=BenchArgs.gen_system_prompt_len,
+             help="System prompt length, used only for the generated-shared-prefix dataset.",
+         )
+         parser.add_argument(
+             "--gen-question-len",
+             type=int,
+             default=BenchArgs.gen_question_len,
+             help="Question length, used only for the generated-shared-prefix dataset.",
+         )
+         parser.add_argument(
+             "--gen-output-len",
+             type=int,
+             default=BenchArgs.gen_output_len,
+             help="Target length in tokens for outputs in the generated-shared-prefix dataset.",
+         )
+         parser.add_argument(
+             "--disable-ignore-eos",
+             type=bool,
+             default=BenchArgs.disable_ignore_eos,
+             help="Disable ignoring the EOS token.",
+         )
+         parser.add_argument(
+             "--extra-request-body",
+             metavar='{"key1": "value1", "key2": "value2"}',
+             type=str,
+             help="Append the given JSON object to the request payload. You can use this to specify "
+             "additional generate params like sampling params.",
+         )
+ parser.add_argument("--seed", type=int, default=1, help="The random seed.")
149
+ parser.add_argument(
150
+ "--skip-warmup",
151
+ action="store_true",
152
+ help="Skip the warmup batches.",
153
+ )
154
+ parser.add_argument(
155
+ "--do-not-exit",
156
+ action="store_true",
157
+ help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
158
+ )
159
+
160
+ @classmethod
161
+ def from_cli_args(cls, args: argparse.Namespace):
162
+ attrs = [attr.name for attr in dataclasses.fields(cls)]
163
+ return cls(**{attr: getattr(args, attr) for attr in attrs})
164
+
165
+
166
+ def throughput_test_once(
167
+ backend_name: str,
168
+ backend,
169
+ reqs: List[Tuple[str, int, int]],
170
+ ignore_eos: bool,
171
+ extra_request_body: Dict,
172
+ ):
173
+ measurement_results = {
174
+ "backend": backend_name,
175
+ "successful_requests": len(reqs),
176
+ "total_latency": -1,
177
+ "total_input_tokens": sum(r[1] for r in reqs),
178
+ "total_output_tokens": -1,
179
+ "request_throughput": -1,
180
+ "input_throughput": -1,
181
+ "output_throughput": -1,
182
+ "total_throughput": -1,
183
+ }
184
+
185
+ prompt = [r[0] for r in reqs]
186
+ sampling_params = [
187
+ {
188
+ "temperature": 0,
189
+ "max_new_tokens": r[2],
190
+ "ignore_eos": ignore_eos,
191
+ **extra_request_body,
192
+ }
193
+ for r in reqs
194
+ ]
195
+
196
+ st = time.perf_counter()
197
+ gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
198
+ latency = time.perf_counter() - st
199
+
200
+ if backend_name == "runtime":
201
+ gen_out = json.loads(gen_out)
202
+
203
+ measurement_results["total_latency"] = latency
204
+ measurement_results["total_output_tokens"] = sum(
205
+ o["meta_info"]["completion_tokens"] for o in gen_out
206
+ )
207
+ measurement_results["request_throughput"] = (
208
+ measurement_results["successful_requests"] / latency
209
+ )
210
+ measurement_results["input_throughput"] = (
211
+ measurement_results["total_input_tokens"] / latency
212
+ )
213
+ measurement_results["output_throughput"] = (
214
+ measurement_results["total_output_tokens"] / latency
215
+ )
216
+ measurement_results["total_throughput"] = (
217
+ measurement_results["total_input_tokens"]
218
+ + measurement_results["total_output_tokens"]
219
+ ) / latency
220
+
221
+ return measurement_results
222
+
223
+
224
+ def throughput_test(
225
+ server_args: ServerArgs,
226
+ bench_args: BenchArgs,
227
+ ):
228
+ if bench_args.backend == "engine":
229
+ backend = Engine(**dataclasses.asdict(server_args))
230
+ if not backend:
231
+ raise ValueError("Please provide valid engine arguments")
232
+ elif bench_args.backend == "runtime":
233
+ backend = Runtime(**dataclasses.asdict(server_args))
234
+ else:
235
+ raise ValueError('Please set backend to either "engine" or "runtime"')
236
+
237
+ tokenizer_id = server_args.model_path
238
+ tokenizer = get_tokenizer(tokenizer_id)
239
+
240
+     # Set global environments
+     set_ulimit()
+     random.seed(bench_args.seed)
+     np.random.seed(bench_args.seed)
+
+     # Parse args
+     extra_request_body = {}
+     if bench_args.extra_request_body:
+         extra_request_body = json.loads(bench_args.extra_request_body)
+
+     # Read dataset
+     input_requests = get_dataset(bench_args, tokenizer)
+
+     warmup_requests = sample_random_requests(
+         input_len=256,
+         output_len=16,
+         num_prompts=16,
+         range_ratio=0.8,
+         tokenizer=tokenizer,
+         dataset_path=bench_args.dataset_path,
+     )
+
+     # Warm up
+     if not bench_args.skip_warmup:
+         logging.info("\nWarmup...")
+         throughput_test_once(
+             backend_name=bench_args.backend,
+             backend=backend,
+             reqs=warmup_requests,
+             ignore_eos=not bench_args.disable_ignore_eos,
+             extra_request_body=extra_request_body,
+         )
+
+     logging.info("\nBenchmark...")
+     result = throughput_test_once(
+         backend_name=bench_args.backend,
+         backend=backend,
+         reqs=input_requests,
+         ignore_eos=not bench_args.disable_ignore_eos,
+         extra_request_body=extra_request_body,
+     )
+
+     if bench_args.result_filename:
+         with open(bench_args.result_filename, "a") as fout:
+             fout.write(json.dumps(result) + "\n")
+
+     print(
+         "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
+     )
+     print("{:<40} {:<10}".format("Backend:", result["backend"]))
+     print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
+     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
+     print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
+     print(
+         "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Request throughput (req/s):", result["request_throughput"]
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Input token throughput (tok/s):", result["input_throughput"]
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Output token throughput (tok/s):", result["output_throughput"]
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Total token throughput (tok/s):", result["total_throughput"]
+         )
+     )
+     print("=" * 50)
+
+     return result
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     ServerArgs.add_cli_args(parser)
+     BenchArgs.add_cli_args(parser)
+     args = parser.parse_args()
+     server_args = ServerArgs.from_cli_args(args)
+     bench_args = BenchArgs.from_cli_args(args)
+
+     logging.basicConfig(
+         level=getattr(logging, server_args.log_level.upper()),
+         format="%(message)s",
+     )
+
+     throughput_test(server_args, bench_args)
+
+     while bench_args.do_not_exit:
+         pass
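
For readers who want to reuse the new benchmark outside its CLI, the snippet below is a minimal sketch of programmatic usage, based only on names that appear in this diff (ServerArgs, BenchArgs, throughput_test, and the BenchArgs fields). It is illustrative rather than part of the package; in particular, constructing ServerArgs with only model_path is an assumption about its defaults.

    from sglang.bench_offline_throughput import BenchArgs, throughput_test
    from sglang.srt.server_args import ServerArgs

    # Server-side configuration; model_path is assumed to be the only required field.
    server_args = ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")

    # Benchmark 10 synthetic requests with 1024 input and 1024 output tokens each.
    bench_args = BenchArgs(
        backend="engine",  # "runtime" selects the Runtime wrapper instead of the in-process Engine
        dataset_name="random",
        num_prompts=10,
        random_input_len=1024,
        random_output_len=1024,
    )

    # throughput_test builds the backend, optionally warms it up, runs the batch,
    # and returns the same metrics dictionary that it prints.
    result = throughput_test(server_args, bench_args)
    print(result["output_throughput"])  # generated tokens per second

Because throughput_test returns the metrics it prints (latency, input/output token counts, and the derived throughputs), results can be aggregated directly in Python or appended to a file via result_filename.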