sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +151 -40
  4. sglang/bench_serving.py +46 -22
  5. sglang/check_env.py +24 -2
  6. sglang/global_config.py +0 -1
  7. sglang/lang/backend/base_backend.py +3 -1
  8. sglang/lang/backend/openai.py +8 -3
  9. sglang/lang/backend/runtime_endpoint.py +46 -29
  10. sglang/lang/choices.py +164 -0
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +6 -13
  13. sglang/lang/ir.py +14 -5
  14. sglang/srt/constrained/base_tool_cache.py +1 -1
  15. sglang/srt/constrained/fsm_cache.py +12 -2
  16. sglang/srt/layers/activation.py +33 -0
  17. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  18. sglang/srt/layers/extend_attention.py +6 -1
  19. sglang/srt/layers/layernorm.py +65 -0
  20. sglang/srt/layers/logits_processor.py +6 -1
  21. sglang/srt/layers/pooler.py +50 -0
  22. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  23. sglang/srt/layers/radix_attention.py +4 -7
  24. sglang/srt/managers/detokenizer_manager.py +31 -9
  25. sglang/srt/managers/io_struct.py +63 -0
  26. sglang/srt/managers/policy_scheduler.py +173 -25
  27. sglang/srt/managers/schedule_batch.py +174 -380
  28. sglang/srt/managers/tokenizer_manager.py +197 -112
  29. sglang/srt/managers/tp_worker.py +299 -364
  30. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  31. sglang/srt/mem_cache/chunk_cache.py +43 -20
  32. sglang/srt/mem_cache/memory_pool.py +10 -15
  33. sglang/srt/mem_cache/radix_cache.py +74 -40
  34. sglang/srt/model_executor/cuda_graph_runner.py +27 -12
  35. sglang/srt/model_executor/forward_batch_info.py +319 -0
  36. sglang/srt/model_executor/model_runner.py +30 -47
  37. sglang/srt/models/chatglm.py +1 -1
  38. sglang/srt/models/commandr.py +1 -1
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/deepseek.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +1 -1
  42. sglang/srt/models/gemma.py +1 -1
  43. sglang/srt/models/gemma2.py +1 -2
  44. sglang/srt/models/gpt_bigcode.py +1 -1
  45. sglang/srt/models/grok.py +1 -1
  46. sglang/srt/models/internlm2.py +3 -8
  47. sglang/srt/models/llama2.py +5 -5
  48. sglang/srt/models/llama_classification.py +1 -1
  49. sglang/srt/models/llama_embedding.py +88 -0
  50. sglang/srt/models/llava.py +1 -2
  51. sglang/srt/models/llavavid.py +1 -2
  52. sglang/srt/models/minicpm.py +1 -1
  53. sglang/srt/models/mixtral.py +1 -1
  54. sglang/srt/models/mixtral_quant.py +1 -1
  55. sglang/srt/models/qwen.py +1 -1
  56. sglang/srt/models/qwen2.py +1 -1
  57. sglang/srt/models/qwen2_moe.py +1 -12
  58. sglang/srt/models/stablelm.py +1 -1
  59. sglang/srt/openai_api/adapter.py +189 -39
  60. sglang/srt/openai_api/protocol.py +43 -1
  61. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  62. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  63. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  64. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  65. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  66. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  67. sglang/srt/sampling_params.py +31 -4
  68. sglang/srt/server.py +93 -21
  69. sglang/srt/server_args.py +30 -19
  70. sglang/srt/utils.py +31 -13
  71. sglang/test/run_eval.py +10 -1
  72. sglang/test/runners.py +63 -63
  73. sglang/test/simple_eval_humaneval.py +2 -8
  74. sglang/test/simple_eval_mgsm.py +203 -0
  75. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  76. sglang/test/test_layernorm.py +60 -0
  77. sglang/test/test_programs.py +4 -2
  78. sglang/test/test_utils.py +21 -3
  79. sglang/utils.py +0 -1
  80. sglang/version.py +1 -1
  81. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
  82. sglang-0.2.12.dist-info/RECORD +112 -0
  83. sglang/srt/layers/linear.py +0 -884
  84. sglang/srt/layers/quantization/__init__.py +0 -64
  85. sglang/srt/layers/quantization/fp8.py +0 -677
  86. sglang-0.2.10.dist-info/RECORD +0 -100
  87. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  88. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  89. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -22,6 +22,11 @@ from sglang.api import (
     user_end,
     video,
 )
+from sglang.lang.choices import (
+    greedy_token_selection,
+    token_length_normalized,
+    unconditional_likelihood_normalized,
+)
 
 # SGLang DSL APIs
 __all__ = [
@@ -45,6 +50,9 @@ __all__ = [
     "user_begin",
     "user_end",
     "video",
+    "greedy_token_selection",
+    "token_length_normalized",
+    "unconditional_likelihood_normalized",
 ]
 
 # Global Configurations
sglang/api.py CHANGED
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
 
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
+from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
 from sglang.lang.ir import (
     SglExpr,
     SglExprList,
@@ -73,12 +74,18 @@ def gen(
     return_text_in_logprobs: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
+    choices_method: Optional[ChoicesSamplingMethod] = None,
     regex: Optional[str] = None,
 ):
     """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
 
     if choices:
-        return SglSelect(name, choices, 0.0 if temperature is None else temperature)
+        return SglSelect(
+            name,
+            choices,
+            0.0 if temperature is None else temperature,
+            token_length_normalized if choices_method is None else choices_method,
+        )
 
     # check regex is valid
     if regex is not None:
@@ -186,9 +193,10 @@ def select(
     name: Optional[str] = None,
     choices: Optional[List[str]] = None,
     temperature: float = 0.0,
+    choices_method: ChoicesSamplingMethod = token_length_normalized,
 ):
     assert choices is not None
-    return SglSelect(name, choices, temperature)
+    return SglSelect(name, choices, temperature, choices_method)
 
 
 def _role_common(name: str, expr: Optional[SglExpr] = None):
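The `choices_method` argument threaded through `gen()` and `select()` above picks one of the sampling methods newly exported from `sglang.lang.choices`. A minimal usage sketch, assuming a default backend has already been configured; the program body below is illustrative and not taken from this release:

```python
import sglang as sgl
from sglang import greedy_token_selection


@sgl.function
def pick_tool(s, question):
    s += question
    # choices_method defaults to token_length_normalized; greedy_token_selection and
    # unconditional_likelihood_normalized are the other methods exported in __init__.py above.
    s += sgl.select(
        "tool",
        choices=["calculator", "web search"],
        choices_method=greedy_token_selection,
    )
```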
sglang/bench_latency.py CHANGED
@@ -1,13 +1,21 @@
1
1
  """
2
2
  Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
3
3
 
4
- # Usage (latency test) with dummy weights:
4
+ # Usage (latency test)
5
+ ## with dummy weights:
5
6
  python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
7
+ ## sweep through multiple data points and store (append) the results in a jsonl file:
8
+ python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
9
+ ## do some changes, and store the results under a different run_name:
10
+ python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
11
+ ## plot the results in series of lines:
12
+ python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
13
+
6
14
 
7
15
  # Usage (correctness test):
8
16
  python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
9
17
 
10
- ### Reference output (of the correctness test above, can be gpu dependent):
18
+ ## Reference output (of the correctness test above, can be gpu dependent):
11
19
  prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
12
20
  [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
13
21
  [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -28,19 +36,23 @@ I'm going to the park
28
36
 
29
37
  import argparse
30
38
  import dataclasses
39
+ import itertools
31
40
  import logging
32
41
  import multiprocessing
42
+ import os
43
+ import sqlite3
33
44
  import time
34
45
  from typing import Tuple
35
46
 
36
- import jsonlines
37
47
  import numpy as np
48
+ import pandas as pd
38
49
  import torch
39
50
  import torch.distributed as dist
40
51
 
41
52
  from sglang.srt.hf_transformers_utils import get_tokenizer
42
- from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
53
+ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
43
54
  from sglang.srt.model_config import ModelConfig
55
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
44
56
  from sglang.srt.model_executor.model_runner import ModelRunner
45
57
  from sglang.srt.sampling_params import SamplingParams
46
58
  from sglang.srt.server_args import ServerArgs
@@ -49,26 +61,42 @@ from sglang.srt.utils import suppress_other_loggers
49
61
 
50
62
  @dataclasses.dataclass
51
63
  class BenchArgs:
64
+ run_name: str = "before"
52
65
  batch_size: Tuple[int] = (1,)
53
- input_len: int = 1024
54
- output_len: int = 4
66
+ input_len: Tuple[int] = (1024,)
67
+ output_len: Tuple[int] = (4,)
55
68
  result_filename: str = ""
56
69
  correctness_test: bool = False
57
70
  # This is only used for correctness test
58
71
  cut_len: int = 4
72
+ # Plotting args
73
+ graph_sql: str = (
74
+ "select run_name, batch_size, prefill_throughput from results where run_name='before'"
75
+ )
76
+ graph_filename: str = "out.png"
59
77
 
60
78
  @staticmethod
61
79
  def add_cli_args(parser: argparse.ArgumentParser):
80
+ parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
62
81
  parser.add_argument(
63
82
  "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
64
83
  )
65
- parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
66
- parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
84
+ parser.add_argument(
85
+ "--input-len", type=int, nargs="+", default=BenchArgs.input_len
86
+ )
87
+ parser.add_argument(
88
+ "--output-len", type=int, nargs="+", default=BenchArgs.output_len
89
+ )
67
90
  parser.add_argument(
68
91
  "--result-filename", type=str, default=BenchArgs.result_filename
69
92
  )
70
93
  parser.add_argument("--correctness-test", action="store_true")
71
94
  parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
95
+ # graphing
96
+ parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
97
+ parser.add_argument(
98
+ "--graph-filename", type=str, default=BenchArgs.graph_filename
99
+ )
72
100
 
73
101
  @classmethod
74
102
  def from_cli_args(cls, args: argparse.Namespace):
@@ -124,7 +152,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
124
152
  req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
125
153
  req.prefix_indices = []
126
154
  req.sampling_params = sampling_params
127
- req.input_ids = req.origin_input_ids
155
+ req.fill_ids = req.origin_input_ids
128
156
  reqs.append(req)
129
157
 
130
158
  return input_ids, reqs
@@ -135,7 +163,7 @@ def prepare_extend_inputs_for_correctness_test(
135
163
  ):
136
164
  for i in range(len(reqs)):
137
165
  req = reqs[i]
138
- req.input_ids += input_ids[i][bench_args.cut_len :]
166
+ req.fill_ids += input_ids[i][bench_args.cut_len :]
139
167
  req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
140
168
  i, : bench_args.cut_len
141
169
  ]
@@ -154,14 +182,14 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
154
182
  req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
155
183
  req.prefix_indices = []
156
184
  req.sampling_params = sampling_params
157
- req.input_ids = req.origin_input_ids
185
+ req.fill_ids = req.origin_input_ids
158
186
  reqs.append(req)
159
187
 
160
188
  return reqs
161
189
 
162
190
 
163
191
  def extend(reqs, model_runner):
164
- batch = Batch.init_new(
192
+ batch = ScheduleBatch.init_new(
165
193
  reqs=reqs,
166
194
  req_to_token_pool=model_runner.req_to_token_pool,
167
195
  token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -210,7 +238,7 @@ def correctness_test(
210
238
 
211
239
  # Decode
212
240
  output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
213
- for _ in range(bench_args.output_len):
241
+ for _ in range(bench_args.output_len[0]):
214
242
  next_token_ids, _ = decode(next_token_ids, batch, model_runner)
215
243
  for i in range(len(reqs)):
216
244
  output_ids[i].append(next_token_ids[i])
@@ -222,15 +250,21 @@ def correctness_test(
222
250
 
223
251
  @torch.inference_mode()
224
252
  def latency_test_run_once(
225
- model_runner, rank_print, reqs, batch_size, input_len, output_len
253
+ run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
226
254
  ):
255
+ max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
256
+ if batch_size > max_batch_size:
257
+ rank_print(
258
+ f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
259
+ )
260
+ return
227
261
 
228
262
  # Clear the pools.
229
263
  model_runner.req_to_token_pool.clear()
230
264
  model_runner.token_to_kv_pool.clear()
231
265
 
232
266
  measurement_results = {
233
- "run_name": "before",
267
+ "run_name": run_name,
234
268
  "batch_size": batch_size,
235
269
  "input_len": input_len,
236
270
  "output_len": output_len,
@@ -291,49 +325,121 @@ def latency_test(
291
325
 
292
326
  # Load the model
293
327
  model_runner, tokenizer = load_model(server_args, tp_rank)
294
- rank_print(
295
- f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
296
- )
297
328
 
298
- # To make this PR easier to review, for now, only do the first element in batch_size tuple.
299
- bench_args.batch_size = bench_args.batch_size[0]
300
-
301
- # Prepare inputs
329
+ # Prepare inputs for warm up
302
330
  reqs = prepare_synthetic_inputs_for_latency_test(
303
- bench_args.batch_size, bench_args.input_len
331
+ bench_args.batch_size[0], bench_args.input_len[0]
304
332
  )
305
333
 
306
334
  # Warm up
335
+ rank_print("Warmup ...")
307
336
  latency_test_run_once(
308
- model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
337
+ bench_args.run_name,
338
+ model_runner,
339
+ rank_print,
340
+ reqs,
341
+ bench_args.batch_size[0],
342
+ bench_args.input_len[0],
343
+ 4, # shorter decoding to speed up the warmup
309
344
  )
345
+ rank_print("Benchmark ...")
310
346
 
311
- # Run again
347
+ # Run the sweep
312
348
  result_list = []
313
- result_list.append(
314
- latency_test_run_once(
315
- model_runner,
316
- rank_print,
317
- reqs,
318
- bench_args.batch_size,
319
- bench_args.input_len,
320
- bench_args.output_len,
349
+ for bs, il, ol in itertools.product(
350
+ bench_args.batch_size, bench_args.input_len, bench_args.output_len
351
+ ):
352
+ req = prepare_synthetic_inputs_for_latency_test(bs, il)
353
+ ret = latency_test_run_once(
354
+ bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
321
355
  )
322
- )
356
+ if ret is not None:
357
+ result_list.append(ret)
358
+
359
+ # Write results in jsonlines format on rank 0.
360
+ if tp_rank == 0 and bench_args.result_filename:
361
+ import jsonlines
323
362
 
324
- # Write results in jsonlines format.
325
- if bench_args.result_filename:
326
363
  with jsonlines.open(bench_args.result_filename, "a") as f:
327
364
  f.write_all(result_list)
328
365
 
329
366
 
367
+ def plot_latency_test(
368
+ server_args,
369
+ bench_args,
370
+ tp_rank,
371
+ ):
372
+ assert tp_rank == 0
373
+
374
+ # read the jsonl file and put in sqlite
375
+ df = pd.read_json(bench_args.result_filename, lines=True)
376
+ conn = sqlite3.connect(":memory:")
377
+ cur = conn.cursor()
378
+
379
+ # get the columns and their types
380
+ column_names = list(df.iloc[0].keys())
381
+ type_dict = {
382
+ str: "TEXT",
383
+ np.int64: "INTEGER",
384
+ np.float64: "FLOAT",
385
+ }
386
+ column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
387
+
388
+ # create the table
389
+ cur.execute(
390
+ f"""
391
+ CREATE TABLE IF NOT EXISTS results (
392
+ {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
393
+ )
394
+ """
395
+ )
396
+ conn.commit()
397
+
398
+ # write the results to DB
399
+ df.to_sql("results", conn, if_exists="replace", index=False)
400
+ conn.commit()
401
+
402
+ # read it back using sql
403
+ df = pd.read_sql_query(bench_args.graph_sql, conn)
404
+ conn.close()
405
+
406
+ # plot it and save to a file
407
+ import matplotlib.pyplot as plt
408
+
409
+ assert (
410
+ len(df.columns) == 3
411
+ ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
412
+ for label in df[df.columns[0]].unique():
413
+ q = f"{df.columns[0]}=='{label}'"
414
+ series = df.query(q)
415
+ plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
416
+ plt.xlabel(df.columns[1])
417
+ plt.ylabel(df.columns[2])
418
+ plt.legend()
419
+ plt.savefig(bench_args.graph_filename, dpi=300)
420
+
421
+ # if in kitty, just dump it to the terminal
422
+ if os.environ["TERM"] == "xterm-kitty":
423
+ os.system(
424
+ f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
425
+ )
426
+
427
+
330
428
  def main(server_args, bench_args):
331
- print(bench_args)
332
429
 
333
- if bench_args.correctness_test:
334
- work_func = correctness_test
430
+ if server_args.model_path:
431
+ if bench_args.correctness_test:
432
+ work_func = correctness_test
433
+ else:
434
+ work_func = latency_test
435
+ elif os.path.isfile(bench_args.result_filename):
436
+ assert bench_args.graph_filename, "please provide a filename for the graph"
437
+ work_func = plot_latency_test
335
438
  else:
336
- work_func = latency_test
439
+ raise ValueError(
440
+ "Provide --model-path for running the tests or "
441
+ "provide --result-filename for plotting the results"
442
+ )
337
443
 
338
444
  if server_args.tp_size == 1:
339
445
  work_func(server_args, bench_args, 0)
@@ -361,6 +467,11 @@ if __name__ == "__main__":
361
467
  parser = argparse.ArgumentParser()
362
468
  ServerArgs.add_cli_args(parser)
363
469
  BenchArgs.add_cli_args(parser)
470
+ # For this script, model-path is not required
471
+ assert (
472
+ parser._actions[1].option_strings[0] == "--model-path"
473
+ ), "options changed, this code need to be updated"
474
+ parser._actions[1].required = False
364
475
  args = parser.parse_args()
365
476
 
366
477
  server_args = ServerArgs.from_cli_args(args)
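The new `plot_latency_test` path above is a jsonl → in-memory SQLite → matplotlib round trip. A standalone sketch of that round trip (not part of the package; the jsonl columns are whatever `latency_test_run_once` wrote, with `run_name`, `batch_size`, and `prefill_throughput` taken from the default `--graph-sql` query):

```python
import sqlite3

import pandas as pd

# Each line of out.jsonl is one measurement dict appended by the benchmark sweep.
df = pd.read_json("out.jsonl", lines=True)

# Load the rows into an in-memory SQLite table named "results".
conn = sqlite3.connect(":memory:")
df.to_sql("results", conn, if_exists="replace", index=False)

# The --graph-sql query must return exactly three columns: <series, x, y>.
query = "select run_name, batch_size, prefill_throughput from results"
print(pd.read_sql_query(query, conn))
conn.close()
```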
sglang/bench_serving.py CHANGED
@@ -24,7 +24,7 @@ import warnings
 from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import AsyncGenerator, List, Optional, Tuple, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
 
 import aiohttp
 import numpy as np
@@ -39,6 +39,8 @@ from transformers import (
 
 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
 
+global args
+
 
 @dataclass
 class RequestFuncInput:
@@ -47,6 +49,7 @@
     prompt_len: int
     output_len: int
     model: str
+    extra_request_body: Dict[str, Any]
 
 
 @dataclass
@@ -84,6 +87,7 @@ async def async_request_trt_llm(
         "stream": True,
         "min_length": request_func_input.output_len,
         "end_id": 1048576,
+        **request_func_input.extra_request_body,
     }
     if args.disable_ignore_eos:
         del payload["min_length"]
@@ -154,6 +158,7 @@ async def async_request_openai_completions(
         "max_tokens": request_func_input.output_len,
         "stream": not args.disable_stream,
         "ignore_eos": not args.disable_ignore_eos,
+        **request_func_input.extra_request_body,
     }
     headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
 
@@ -192,7 +197,8 @@
                             output.ttft = ttft
 
                         # Decoding phase
-                        output.itl.append(timestamp - most_recent_timestamp)
+                        else:
+                            output.itl.append(timestamp - most_recent_timestamp)
 
                         most_recent_timestamp = timestamp
                         generated_text += data["choices"][0]["text"]
@@ -542,6 +548,7 @@ async def benchmark(
     request_rate: float,
     disable_tqdm: bool,
     enable_multi: bool,
+    extra_request_body: Dict[str, Any],
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -556,6 +563,7 @@
         api_url=api_url,
         prompt_len=test_prompt_len,
         output_len=test_output_len,
+        extra_request_body=extra_request_body,
     )
     test_output = await request_func(request_func_input=test_input)
     if not test_output.success:
@@ -578,6 +586,7 @@
             api_url=api_url,
             prompt_len=prompt_len,
             output_len=output_len,
+            extra_request_body=extra_request_body,
         )
         tasks.append(
            asyncio.create_task(
@@ -660,19 +669,20 @@
             "backend": args.backend,
             "dataset_name": args.dataset_name,
             "request_rate": request_rate,
-            "total_input": metrics.total_input,
-            "total_output": metrics.total_output,
-            "total_output_retokenized": metrics.total_output_retokenized,
-            "mean_e2e_latency": metrics.mean_e2e_latency_ms,
-            "median_e2e_latency": metrics.median_e2e_latency_ms,
-            "median_ttft": metrics.median_ttft_ms,
-            "median_itl": metrics.median_itl_ms,
-            "output_token_throughput": metrics.output_throughput,
+            "total_input_tokens": metrics.total_input,
+            "total_output_tokens": metrics.total_output,
+            "total_output_tokens_retokenized": metrics.total_output_retokenized,
+            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
+            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+            "median_ttft_ms": metrics.median_ttft_ms,
+            "median_itl_ms": metrics.median_itl_ms,
+            "output_throughput": metrics.output_throughput,
             "sharegpt_output_len": args.sharegpt_output_len,
             "random_input_len": args.random_input_len,
             "random_output_len": args.random_output_len,
             "random_range_ratio": args.random_range_ratio,
-            "benchmark_duration": benchmark_duration,
+            "duration": benchmark_duration,
+            "completed": metrics.completed,
         }
     else:
         print(f"Error running benchmark for request rate: {request_rate}")
@@ -742,10 +752,18 @@ def check_chat_template(model_path):
     return False
 
 
-def fire(args: argparse.Namespace):
+def run_benchmark(args_: argparse.Namespace):
+    global args
+    args = args_
+
+    set_ulimit()
     random.seed(args.seed)
     np.random.seed(args.seed)
 
+    extra_request_body = {}
+    if args.extra_request_body:
+        extra_request_body = json.loads(args.extra_request_body)
+
     if args.port is None:
         args.port = {
             "sglang": 30000,
@@ -838,10 +856,11 @@
                 request_rate=rate,
                 disable_tqdm=args.disable_tqdm,
                 enable_multi=args.multi,
+                extra_request_body=extra_request_body,
             )
         )
     else:
-        asyncio.run(
+        return asyncio.run(
             benchmark(
                 backend=backend,
                 api_url=api_url,
@@ -851,6 +870,7 @@
                 request_rate=args.request_rate,
                 disable_tqdm=args.disable_tqdm,
                 enable_multi=args.multi,
+                extra_request_body=extra_request_body,
             )
         )
 
@@ -949,11 +969,6 @@
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
     )
     parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
-    parser.add_argument(
-        "--disable-tqdm",
-        action="store_true",
-        help="Specify to disable tqdm progress bar.",
-    )
     parser.add_argument(
         "--multi",
         action="store_true",
@@ -966,6 +981,11 @@
         help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
     )
     parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+    parser.add_argument(
+        "--disable-tqdm",
+        action="store_true",
+        help="Specify to disable tqdm progress bar.",
+    )
     parser.add_argument(
         "--disable-stream",
         action="store_true",
@@ -976,8 +996,12 @@
         action="store_true",
         help="Disable ignoring EOS.",
     )
-
-    set_ulimit()
-
+    parser.add_argument(
+        "--extra-request-body",
+        metavar='{"key1": "value1", "key2": "value2"}',
+        type=str,
+        help="Append given JSON object to the request payload. You can use this to specify"
+        "additional generate params like sampling params.",
+    )
     args = parser.parse_args()
-    fire(args)
+    run_benchmark(args)
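The new `--extra-request-body` flag is parsed with `json.loads` and unpacked into every request payload, as the hunks above show. A minimal sketch of that merge (the payload keys below are illustrative, not the benchmark's exact payload):

```python
import json

# e.g. --extra-request-body '{"temperature": 0.0, "top_p": 1.0}'
extra_request_body = json.loads('{"temperature": 0.0, "top_p": 1.0}')

payload = {
    "model": "default",
    "prompt": "Hello",
    "max_tokens": 32,
    # Keys from --extra-request-body are merged in last, so they win on conflict.
    **extra_request_body,
}
print(payload)
```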
sglang/check_env.py CHANGED
@@ -14,6 +14,7 @@ PACKAGE_LIST = [
     "sglang",
     "flashinfer",
     "triton",
+    "transformers",
     "requests",
     "tqdm",
     "numpy",
@@ -73,10 +74,26 @@ def _get_gpu_info():
     Get information about available GPUs.
     """
     devices = defaultdict(list)
+    capabilities = defaultdict(list)
     for k in range(torch.cuda.device_count()):
         devices[torch.cuda.get_device_name(k)].append(str(k))
+        capability = torch.cuda.get_device_capability(k)
+        capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
 
-    return {f"GPU {','.join(device_ids)}": name for name, device_ids in devices.items()}
+    gpu_info = {}
+    for name, device_ids in devices.items():
+        gpu_info[f"GPU {','.join(device_ids)}"] = name
+
+    if len(capabilities) == 1:
+        # All GPUs have the same compute capability
+        cap, gpu_ids = list(capabilities.items())[0]
+        gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+    else:
+        # GPUs have different compute capabilities
+        for cap, gpu_ids in capabilities.items():
+            gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+
+    return gpu_info
 
 
 def _get_cuda_version_info():
@@ -118,6 +135,7 @@ def _get_cuda_driver_version():
     """
     Get CUDA driver version.
     """
+    versions = set()
     try:
         output = subprocess.check_output(
             [
@@ -126,7 +144,11 @@
                 "--format=csv,noheader,nounits",
             ]
         )
-        return {"CUDA Driver Version": output.decode().strip()}
+        versions = set(output.decode().strip().split("\n"))
+        if len(versions) == 1:
+            return {"CUDA Driver Version": versions.pop()}
+        else:
+            return {"CUDA Driver Versions": ", ".join(sorted(versions))}
     except subprocess.SubprocessError:
         return {"CUDA Driver Version": "Not Available"}
 
sglang/global_config.py CHANGED
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
-        self.new_token_ratio_recovery = 0.05
 
         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
         # This can improve the speed for large batch sizes during prefill.
sglang/lang/backend/base_backend.py CHANGED
@@ -1,6 +1,7 @@
 from typing import Callable, List, Optional, Union
 
 from sglang.lang.chat_template import get_chat_template
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -64,7 +65,8 @@ class BaseBackend:
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: Optional[ChoicesSamplingMethod] = None,
+    ) -> ChoicesDecision:
         raise NotImplementedError()
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
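With this change, a backend's `select()` receives the sampling method and returns a `ChoicesDecision` rather than a tuple. A hypothetical third-party backend following the new contract might look like the sketch below; the scoring logic is a placeholder, while the return shape mirrors what the OpenAI backend in the next section constructs:

```python
from typing import List, Optional

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod


class EchoBackend(BaseBackend):
    """Toy backend: always picks the first choice and reports dummy scores."""

    def select(
        self,
        s,
        choices: List[str],
        temperature: float,
        choices_method: Optional[ChoicesSamplingMethod] = None,
    ) -> ChoicesDecision:
        scores = [1.0] + [0.0] * (len(choices) - 1)
        return ChoicesDecision(decision=choices[0], meta_info={"scores": scores})
```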
sglang/lang/backend/openai.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
 
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -296,7 +297,9 @@ class OpenAI(BaseBackend):
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
+        """Note: `choices_method` is not used by the OpenAI backend."""
         if self.is_chat_model:
             raise NotImplementedError(
                 "select/choices is not supported for chat models. "
@@ -354,8 +357,10 @@
 
         prompt_tokens.append(ret_token)
 
-        decision = choices[np.argmax(scores)]
-        return decision, scores, None, None
+        return ChoicesDecision(
+            decision=choices[np.argmax(scores)],
+            meta_info={"scores": scores},
+        )
 
 
 def openai_completion(