sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +48 -20
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +71 -1
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/outlines_backend.py +15 -2
  8. sglang/srt/constrained/xgrammar_backend.py +22 -14
  9. sglang/srt/layers/activation.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  11. sglang/srt/layers/attention/triton_backend.py +9 -7
  12. sglang/srt/layers/custom_op_util.py +26 -0
  13. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  14. sglang/srt/layers/layernorm.py +4 -0
  15. sglang/srt/layers/logits_processor.py +10 -10
  16. sglang/srt/layers/sampler.py +4 -8
  17. sglang/srt/layers/torchao_utils.py +2 -0
  18. sglang/srt/managers/data_parallel_controller.py +74 -9
  19. sglang/srt/managers/detokenizer_manager.py +1 -0
  20. sglang/srt/managers/io_struct.py +27 -0
  21. sglang/srt/managers/schedule_batch.py +104 -38
  22. sglang/srt/managers/schedule_policy.py +5 -1
  23. sglang/srt/managers/scheduler.py +204 -54
  24. sglang/srt/managers/session_controller.py +62 -0
  25. sglang/srt/managers/tokenizer_manager.py +38 -0
  26. sglang/srt/managers/tp_worker.py +12 -1
  27. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  28. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  29. sglang/srt/model_executor/forward_batch_info.py +109 -15
  30. sglang/srt/model_executor/model_runner.py +99 -43
  31. sglang/srt/model_parallel.py +98 -0
  32. sglang/srt/models/deepseek_v2.py +147 -44
  33. sglang/srt/models/gemma2.py +9 -8
  34. sglang/srt/models/llava.py +1 -1
  35. sglang/srt/models/llavavid.py +1 -1
  36. sglang/srt/models/olmo.py +3 -3
  37. sglang/srt/models/phi3_small.py +447 -0
  38. sglang/srt/models/qwen2_vl.py +13 -6
  39. sglang/srt/models/torch_native_llama.py +94 -78
  40. sglang/srt/openai_api/adapter.py +6 -2
  41. sglang/srt/openai_api/protocol.py +1 -1
  42. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  43. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  44. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  45. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  47. sglang/srt/sampling/sampling_batch_info.py +58 -57
  48. sglang/srt/sampling/sampling_params.py +1 -1
  49. sglang/srt/server.py +27 -1
  50. sglang/srt/server_args.py +78 -62
  51. sglang/srt/utils.py +71 -52
  52. sglang/test/runners.py +25 -6
  53. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  54. sglang/test/test_utils.py +30 -19
  55. sglang/version.py +1 -1
  56. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  57. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
  58. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  59. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  60. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py ADDED
@@ -0,0 +1,474 @@
+ """
+ Benchmark the latency of running a single static batch without a server.
+
+ This script does not launch a server and uses the low-level APIs.
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
+
+ # Usage (latency test)
+ ## with dummy weights:
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+ ## sweep through multiple data points and store (append) the results in a jsonl file:
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
+
+ # Usage (correctness test):
+ python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+ ## Reference output (of the correctness test above, can be gpu dependent):
+ input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
+
+ prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
+         [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
+         [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
+        device='cuda:0')
+
+ prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
+         [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
+         [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
+        device='cuda:0')
+
+ ========== Prompt 0 ==========
+ <s> The capital of France is Paris.
+ The capital of the United States is Washington, D.C.
+
+
+ ========== Prompt 1 ==========
+ <s> The capital of the United Kindom is London.
+ The capital of the United Kingdom is London.
+ The capital of the
+
+ ========== Prompt 2 ==========
+ <s> Today is a sunny day and I like to go for a walk in the park.
+ I'm going to the park
+ """
+
+ import argparse
+ import dataclasses
+ import itertools
+ import json
+ import logging
+ import multiprocessing
+ import time
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+
+ from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.hf_transformers_utils import get_tokenizer
+ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_executor.model_runner import ModelRunner
+ from sglang.srt.sampling.sampling_params import SamplingParams
+ from sglang.srt.server import _set_envs_and_config
+ from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.utils import (
+     configure_logger,
+     kill_child_process,
+     suppress_other_loggers,
+ )
+
+
+ @dataclasses.dataclass
+ class BenchArgs:
+     run_name: str = "default"
+     batch_size: Tuple[int] = (1,)
+     input_len: Tuple[int] = (1024,)
+     output_len: Tuple[int] = (16,)
+     result_filename: str = "result.jsonl"
+     correctness_test: bool = False
+     # This is only used for correctness test
+     cut_len: int = 4
+
+     @staticmethod
+     def add_cli_args(parser: argparse.ArgumentParser):
+         parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
+         parser.add_argument(
+             "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+         )
+         parser.add_argument(
+             "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+         )
+         parser.add_argument(
+             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+         )
+         parser.add_argument(
+             "--result-filename", type=str, default=BenchArgs.result_filename
+         )
+         parser.add_argument("--correctness-test", action="store_true")
+         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+     @classmethod
+     def from_cli_args(cls, args: argparse.Namespace):
+         # use the default value's type to case the args into correct types.
+         attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+         return cls(
+             **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+         )
+
+
+ def load_model(server_args, port_args, tp_rank):
+     suppress_other_loggers()
+     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+     model_config = ModelConfig(
+         server_args.model_path,
+         trust_remote_code=server_args.trust_remote_code,
+         context_length=server_args.context_length,
+         model_override_args=server_args.json_model_override_args,
+     )
+     model_runner = ModelRunner(
+         model_config=model_config,
+         mem_fraction_static=server_args.mem_fraction_static,
+         gpu_id=tp_rank,
+         tp_rank=tp_rank,
+         tp_size=server_args.tp_size,
+         nccl_port=port_args.nccl_port,
+         server_args=server_args,
+     )
+     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+     tokenizer = get_tokenizer(
+         server_args.tokenizer_path,
+         tokenizer_mode=server_args.tokenizer_mode,
+         trust_remote_code=server_args.trust_remote_code,
+     )
+     if server_args.tp_size > 1:
+         dist.barrier()
+     return model_runner, tokenizer
+
+
+ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
+     prompts = [
+         "The capital of France is",
+         "The capital of the United Kindom is",
+         "Today is a sunny day and I like",
+     ]
+     input_ids = [tokenizer.encode(p) for p in prompts]
+     sampling_params = SamplingParams(
+         temperature=0,
+         max_new_tokens=BenchArgs.output_len,
+     )
+
+     reqs = []
+     for i in range(len(prompts)):
+         assert len(input_ids[i]) > bench_args.cut_len
+
+         tmp_input_ids = input_ids[i][: bench_args.cut_len]
+         req = Req(
+             rid=i,
+             origin_input_text=prompts[i],
+             origin_input_ids=tmp_input_ids,
+             sampling_params=sampling_params,
+         )
+         req.prefix_indices = []
+         req.fill_ids = req.origin_input_ids
+         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+         reqs.append(req)
+
+     return input_ids, reqs
+
+
+ def prepare_extend_inputs_for_correctness_test(
+     bench_args, input_ids, reqs, model_runner
+ ):
+     for i in range(len(reqs)):
+         req = reqs[i]
+         req.fill_ids += input_ids[i][bench_args.cut_len :]
+         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+             i, : bench_args.cut_len
+         ]
+         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+     return reqs
+
+
+ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+     input_ids = np.ones((batch_size, input_len), dtype=np.int32)
+     sampling_params = SamplingParams(
+         temperature=0,
+         max_new_tokens=BenchArgs.output_len,
+     )
+
+     reqs = []
+     for i in range(len(input_ids)):
+         req = Req(
+             rid=i,
+             origin_input_text="",
+             origin_input_ids=list(input_ids[i]),
+             sampling_params=sampling_params,
+         )
+         req.prefix_indices = []
+         req.fill_ids = req.origin_input_ids
+         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+         reqs.append(req)
+
+     return reqs
+
+
+ @torch.no_grad
+ def extend(reqs, model_runner):
+     batch = ScheduleBatch.init_new(
+         reqs=reqs,
+         req_to_token_pool=model_runner.req_to_token_pool,
+         token_to_kv_pool=model_runner.token_to_kv_pool,
+         tree_cache=None,
+         model_config=model_runner.model_config,
+     )
+     batch.prepare_for_extend()
+     model_worker_batch = batch.get_model_worker_batch()
+     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+     logits_output = model_runner.forward(forward_batch)
+     next_token_ids = model_runner.sample(logits_output, forward_batch)
+     return next_token_ids, logits_output.next_token_logits, batch
+
+
+ @torch.no_grad
+ def decode(input_token_ids, batch, model_runner):
+     batch.output_ids = input_token_ids
+     batch.prepare_for_decode()
+     model_worker_batch = batch.get_model_worker_batch()
+     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+     logits_output = model_runner.forward(forward_batch)
+     next_token_ids = model_runner.sample(logits_output, forward_batch)
+     return next_token_ids, logits_output.next_token_logits
+
+
+ def correctness_test(
+     server_args,
+     port_args,
+     bench_args,
+     tp_rank,
+ ):
+     # Configure the logger
+     configure_logger(server_args, prefix=f" TP{tp_rank}")
+     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+     # Load the model
+     model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
+
+     # Prepare inputs
+     input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+     rank_print(f"\n{input_ids=}\n")
+
+     if bench_args.cut_len > 0:
+         # Prefill
+         next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+         rank_print(f"prefill logits (first half): {next_token_logits} \n")
+
+     # Prepare extend inputs
+     reqs = prepare_extend_inputs_for_correctness_test(
+         bench_args, input_ids, reqs, model_runner
+     )
+
+     # Extend (prefill w/ KV cache)
+     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+     rank_print(f"prefill logits (final): {next_token_logits} \n")
+
+     # Decode
+     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
+     for _ in range(bench_args.output_len[0] - 1):
+         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+         next_token_ids_list = next_token_ids.tolist()
+         for i in range(len(reqs)):
+             output_ids[i].append(next_token_ids_list[i])
+
+     # Print output texts
+     for i in range(len(reqs)):
+         rank_print(f"========== Prompt {i} ==========")
+         rank_print(tokenizer.decode(output_ids[i]), "\n")
+
+
+ def synchronize(device):
+     if device == "cuda":
+         torch.cuda.synchronize()
+     elif device == "xpu":
+         torch.xpu.synchronize()
+
+
+ def latency_test_run_once(
+     run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
+ ):
+     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+     if batch_size > max_batch_size:
+         rank_print(
+             f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+         )
+         return
+
+     # Clear the pools.
+     model_runner.req_to_token_pool.clear()
+     model_runner.token_to_kv_pool.clear()
+
+     measurement_results = {
+         "run_name": run_name,
+         "batch_size": batch_size,
+         "input_len": input_len,
+         "output_len": output_len,
+     }
+
+     tot_latency = 0
+
+     # Prefill
+     synchronize(device)
+     tic = time.time()
+     next_token_ids, _, batch = extend(reqs, model_runner)
+     synchronize(device)
+     prefill_latency = time.time() - tic
+     tot_latency += prefill_latency
+     throughput = input_len * batch_size / prefill_latency
+     rank_print(
+         f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+     )
+     measurement_results["prefill_latency"] = prefill_latency
+     measurement_results["prefill_throughput"] = throughput
+
+     # Decode
+     decode_latencies = []
+     for i in range(output_len - 1):
+         synchronize(device)
+         tic = time.time()
+         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+         synchronize(device)
+         latency = time.time() - tic
+         tot_latency += latency
+         throughput = batch_size / latency
+         decode_latencies.append(latency)
+         if i < 5:
+             rank_print(
+                 f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+             )
+
+     # Record decode timing from 2nd output
+     if output_len > 1:
+         med_decode_latency = np.median(decode_latencies)
+         med_decode_throughput = batch_size / med_decode_latency
+         rank_print(
+             f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
+         )
+         measurement_results["median_decode_latency"] = med_decode_latency
+         measurement_results["median_decode_throughput"] = med_decode_throughput
+
+     throughput = (input_len + output_len) * batch_size / tot_latency
+     rank_print(
+         f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+     )
+     measurement_results["total_latency"] = tot_latency
+     measurement_results["overall_throughput"] = throughput
+     return measurement_results
+
+
+ def latency_test(
+     server_args,
+     port_args,
+     bench_args,
+     tp_rank,
+ ):
+     # Configure the logger
+     configure_logger(server_args, prefix=f" TP{tp_rank}")
+     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+     # Load the model
+     model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
+
+     # Prepare inputs for warm up
+     reqs = prepare_synthetic_inputs_for_latency_test(
+         bench_args.batch_size[0], bench_args.input_len[0]
+     )
+
+     # Warm up
+     rank_print("Warmup ...")
+     latency_test_run_once(
+         bench_args.run_name,
+         model_runner,
+         rank_print,
+         reqs,
+         bench_args.batch_size[0],
+         bench_args.input_len[0],
+         8,  # shorter decoding to speed up the warmup
+         server_args.device,
+     )
+     rank_print("Benchmark ...")
+
+     # Run the sweep
+     result_list = []
+     for bs, il, ol in itertools.product(
+         bench_args.batch_size, bench_args.input_len, bench_args.output_len
+     ):
+         reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+         ret = latency_test_run_once(
+             bench_args.run_name,
+             model_runner,
+             rank_print,
+             reqs,
+             bs,
+             il,
+             ol,
+             server_args.device,
+         )
+         if ret is not None:
+             result_list.append(ret)
+
+     # Write results in jsonlines format on rank 0.
+     if tp_rank == 0 and bench_args.result_filename:
+         with open(bench_args.result_filename, "a") as fout:
+             for result in result_list:
+                 fout.write(json.dumps(result) + "\n")
+
+
+ def main(server_args, bench_args):
+     _set_envs_and_config(server_args)
+
+     if server_args.model_path:
+         if bench_args.correctness_test:
+             work_func = correctness_test
+         else:
+             work_func = latency_test
+     else:
+         raise ValueError(
+             "Provide --model-path for running the tests or "
+             "provide --result-filename for plotting the results"
+         )
+
+     port_args = PortArgs.init_new(server_args)
+
+     if server_args.tp_size == 1:
+         work_func(server_args, port_args, bench_args, 0)
+     else:
+         workers = []
+         for tp_rank in range(server_args.tp_size):
+             proc = multiprocessing.Process(
+                 target=work_func,
+                 args=(
+                     server_args,
+                     port_args,
+                     bench_args,
+                     tp_rank,
+                 ),
+             )
+             proc.start()
+             workers.append(proc)
+
+         for proc in workers:
+             proc.join()
+
+         proc.terminate()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     ServerArgs.add_cli_args(parser)
+     BenchArgs.add_cli_args(parser)
+     args = parser.parse_args()
+     server_args = ServerArgs.from_cli_args(args)
+     bench_args = BenchArgs.from_cli_args(args)
+
+     logging.basicConfig(
+         level=getattr(logging, server_args.log_level.upper()),
+         format="%(message)s",
+     )
+
+     try:
+         main(server_args, bench_args)
+     except Exception as e:
+         raise e
+     finally:
+         kill_child_process()
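
Note: each sweep point that completes is appended as one JSON object per line to the --result-filename file (default result.jsonl), using the keys built up in latency_test_run_once above. A minimal post-processing sketch, assuming the default filename and at least one recorded sweep point:

import json

# Summarize a bench_one_batch sweep. Each line of result.jsonl is one
# (batch_size, input_len, output_len) measurement written by latency_test().
with open("result.jsonl") as f:
    results = [json.loads(line) for line in f if line.strip()]

for r in results:
    print(
        f"run={r['run_name']} bs={r['batch_size']} in={r['input_len']} out={r['output_len']}: "
        f"prefill {r['prefill_throughput']:.1f} tok/s, "
        f"median decode {r.get('median_decode_throughput', 0.0):.1f} tok/s, "
        f"overall {r['overall_throughput']:.1f} tok/s"
    )
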
sglang/{bench_server_latency.py → bench_one_batch_server.py} RENAMED
@@ -1,10 +1,10 @@
  """
- Benchmark the latency of serving a single batch with a real server.
+ Benchmark the latency of running a single batch with a server.
+
  This script launches a server and uses the HTTP interface.
- It accepts arguments similar to those of launch_server.py.
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
 
  Usage:
-
  python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
  python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 
sglang/bench_serving.py CHANGED
@@ -15,6 +15,7 @@ import argparse
  import asyncio
  import json
  import os
+ import pickle
  import random
  import resource
  import sys
@@ -387,6 +388,24 @@ async def async_request_gserver(
      raise NotImplementedError()
 
 
+ async def async_request_profile(api_url: str) -> RequestFuncOutput:
+     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+         output = RequestFuncOutput()
+         try:
+             async with session.post(url=api_url) as response:
+                 if response.status == 200:
+                     output.success = True
+                 else:
+                     output.error = response.reason or ""
+                     output.success = False
+         except Exception:
+             output.success = False
+             exc_info = sys.exc_info()
+             output.error = "".join(traceback.format_exception(*exc_info))
+
+     return output
+
+
  def get_model(pretrained_model_name_or_path: str) -> str:
      if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
          import huggingface_hub.constants
@@ -682,6 +701,11 @@ def sample_generated_shared_prefix_requests(
      output_len: int,
      tokenizer: PreTrainedTokenizerBase,
  ) -> List[Tuple[str, int, int]]:
+     if args.generated_input_path and os.path.exists(args.generated_input_path):
+         print(f"\nloading generated input data from {args.generated_input_path}")
+         with open(args.generated_input_path, "rb") as f:
+             return pickle.load(f)
+
      """Generate benchmark requests with shared system prompts using random tokens."""
      # Generate system prompts for each group
      system_prompts = []
@@ -695,6 +719,9 @@ def sample_generated_shared_prefix_requests(
          question = gen_prompt(tokenizer, question_len)
          questions.append(question)
 
+     # Shuffle questions
+     random.shuffle(questions)
+
      # Combine system prompts with questions
      input_requests = []
      total_input_tokens = 0
@@ -723,6 +750,11 @@ def sample_generated_shared_prefix_requests(
      print(
          f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
      )
+     if args.generated_input_save_path:
+         print(f"Saving generated input data to {args.generated_input_save_path}")
+         os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True)
+         with open(args.generated_input_save_path, "wb") as f:
+             pickle.dump(input_requests, f)
 
      return input_requests
 
@@ -822,12 +854,14 @@ def calculate_metrics(
  async def benchmark(
      backend: str,
      api_url: str,
+     base_url: str,
      model_id: str,
      tokenizer: PreTrainedTokenizerBase,
      input_requests: List[Tuple[str, int, int]],
      request_rate: float,
      disable_tqdm: bool,
      extra_request_body: Dict[str, Any],
+     profile: bool,
  ):
      if backend in ASYNC_REQUEST_FUNCS:
          request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -855,6 +889,14 @@ async def benchmark(
 
      time.sleep(1.5)
 
+     if profile:
+         print("Starting profiler...")
+         profile_output = await async_request_profile(
+             api_url=base_url + "/start_profile"
+         )
+         if profile_output.success:
+             print("Profiler started")
+
      pbar = None if disable_tqdm else tqdm(total=len(input_requests))
 
      benchmark_start_time = time.perf_counter()
@@ -876,6 +918,12 @@ async def benchmark(
          )
      outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
+     if profile:
+         print("Stopping profiler...")
+         profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
+         if profile_output.success:
+             print("Profiler stopped")
+
      if pbar is not None:
          pbar.close()
 
@@ -1100,6 +1148,9 @@ def run_benchmark(args_: argparse.Namespace):
              if args.base_url
              else f"http://{args.host}:{args.port}/v1/models/model:predict"
          )
+     base_url = (
+         f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
+     )
 
      # Get model name
      if args.model is None:
@@ -1145,12 +1196,14 @@ def run_benchmark(args_: argparse.Namespace):
              benchmark(
                  backend=backend,
                  api_url=api_url,
+                 base_url=base_url,
                  model_id=model_id,
                  tokenizer=tokenizer,
                  input_requests=input_requests,
                  request_rate=args.request_rate,
                  disable_tqdm=args.disable_tqdm,
                  extra_request_body=extra_request_body,
+                 profile=args.profile,
              )
          )
      else:
@@ -1162,12 +1215,14 @@ def run_benchmark(args_: argparse.Namespace):
                  benchmark(
                      backend=backend,
                      api_url=api_url,
+                     base_url=base_url,
                      model_id=model_id,
                      tokenizer=tokenizer,
                      input_requests=input_requests,
                      request_rate=rate,
                      disable_tqdm=args.disable_tqdm,
                      extra_request_body=extra_request_body,
+                     profile=args.profile,
                  )
              )
 
@@ -1331,6 +1386,21 @@ if __name__ == "__main__":
          default=256,
          help="Target length in tokens for outputs in generated-shared-prefix dataset",
      )
-
+     parser.add_argument(
+         "--generated-input-save-path",
+         type=str,
+         help="Path to save generated input data",
+     )
+     parser.add_argument(
+         "--generated-input-path",
+         type=str,
+         help="Path to load previously generated input data",
+     )
+     parser.add_argument(
+         "--profile",
+         action="store_true",
+         help="Use Torch Profiler. The endpoint must be launched with "
+         "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+     )
  args = parser.parse_args()
  run_benchmark(args)
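
Note: the new --profile flag above simply wraps the request loop with POSTs to the server's /start_profile and /stop_profile endpoints (via async_request_profile), and the help text states the endpoint must be launched with SGLANG_TORCH_PROFILER_DIR set. A rough sketch of driving the same endpoints by hand, assuming the requests package is installed and a server is running at localhost:30000 with that environment variable set:

import requests

# Hypothetical manual profiling session mirroring what bench_serving.py does
# when --profile is passed: start the profiler, run a workload, then stop it.
base_url = "http://localhost:30000"

print("start_profile:", requests.post(f"{base_url}/start_profile").status_code)

# ... send the requests you want captured in the torch profiler trace here ...

print("stop_profile:", requests.post(f"{base_url}/stop_profile").status_code)
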
sglang/check_env.py CHANGED
@@ -15,24 +15,21 @@ PACKAGE_LIST = [
      "flashinfer",
      "triton",
      "transformers",
-     "requests",
-     "tqdm",
+     "torchao",
      "numpy",
      "aiohttp",
      "fastapi",
      "hf_transfer",
      "huggingface_hub",
      "interegular",
-     "packaging",
-     "PIL",
      "psutil",
      "pydantic",
+     "multipart",
+     "zmq",
      "uvicorn",
      "uvloop",
-     "zmq",
      "vllm",
      "outlines",
-     "multipart",
      "openai",
      "tiktoken",
      "anthropic",