sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +234 -74
  4. sglang/check_env.py +25 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -40
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/hf_transformers_utils.py +2 -2
  13. sglang/srt/layers/extend_attention.py +59 -7
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/radix_attention.py +24 -14
  16. sglang/srt/layers/token_attention.py +28 -2
  17. sglang/srt/managers/io_struct.py +9 -4
  18. sglang/srt/managers/schedule_batch.py +98 -323
  19. sglang/srt/managers/tokenizer_manager.py +34 -16
  20. sglang/srt/managers/tp_worker.py +20 -22
  21. sglang/srt/mem_cache/memory_pool.py +74 -38
  22. sglang/srt/model_config.py +11 -0
  23. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  24. sglang/srt/model_executor/forward_batch_info.py +256 -0
  25. sglang/srt/model_executor/model_runner.py +51 -26
  26. sglang/srt/models/chatglm.py +1 -1
  27. sglang/srt/models/commandr.py +1 -1
  28. sglang/srt/models/dbrx.py +1 -1
  29. sglang/srt/models/deepseek.py +1 -1
  30. sglang/srt/models/deepseek_v2.py +199 -17
  31. sglang/srt/models/gemma.py +1 -1
  32. sglang/srt/models/gemma2.py +1 -1
  33. sglang/srt/models/gpt_bigcode.py +1 -1
  34. sglang/srt/models/grok.py +1 -1
  35. sglang/srt/models/internlm2.py +1 -1
  36. sglang/srt/models/llama2.py +1 -1
  37. sglang/srt/models/llama_classification.py +1 -1
  38. sglang/srt/models/llava.py +1 -2
  39. sglang/srt/models/llavavid.py +1 -2
  40. sglang/srt/models/minicpm.py +1 -1
  41. sglang/srt/models/mixtral.py +1 -1
  42. sglang/srt/models/mixtral_quant.py +1 -1
  43. sglang/srt/models/qwen.py +1 -1
  44. sglang/srt/models/qwen2.py +1 -1
  45. sglang/srt/models/qwen2_moe.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/openai_api/adapter.py +151 -29
  48. sglang/srt/openai_api/protocol.py +7 -1
  49. sglang/srt/server.py +111 -84
  50. sglang/srt/server_args.py +12 -2
  51. sglang/srt/utils.py +25 -20
  52. sglang/test/run_eval.py +21 -10
  53. sglang/test/runners.py +237 -0
  54. sglang/test/simple_eval_common.py +12 -12
  55. sglang/test/simple_eval_gpqa.py +92 -0
  56. sglang/test/simple_eval_humaneval.py +5 -5
  57. sglang/test/simple_eval_math.py +72 -0
  58. sglang/test/test_utils.py +95 -14
  59. sglang/utils.py +15 -37
  60. sglang/version.py +1 -1
  61. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
  62. sglang-0.2.11.dist-info/RECORD +102 -0
  63. sglang-0.2.9.post1.dist-info/RECORD +0 -97
  64. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  65. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  66. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -22,6 +22,11 @@ from sglang.api import (
     user_end,
     video,
 )
+from sglang.lang.choices import (
+    greedy_token_selection,
+    token_length_normalized,
+    unconditional_likelihood_normalized,
+)
 
 # SGLang DSL APIs
 __all__ = [
@@ -45,6 +50,9 @@ __all__ = [
     "user_begin",
     "user_end",
     "video",
+    "greedy_token_selection",
+    "token_length_normalized",
+    "unconditional_likelihood_normalized",
 ]
 
 # Global Configurations
sglang/api.py CHANGED
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
 
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
+from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
 from sglang.lang.ir import (
     SglExpr,
     SglExprList,
@@ -73,12 +74,18 @@ def gen(
     return_text_in_logprobs: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
+    choices_method: Optional[ChoicesSamplingMethod] = None,
     regex: Optional[str] = None,
 ):
     """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
 
     if choices:
-        return SglSelect(name, choices, 0.0 if temperature is None else temperature)
+        return SglSelect(
+            name,
+            choices,
+            0.0 if temperature is None else temperature,
+            token_length_normalized if choices_method is None else choices_method,
+        )
 
     # check regex is valid
     if regex is not None:
@@ -186,9 +193,10 @@ def select(
     name: Optional[str] = None,
     choices: Optional[List[str]] = None,
     temperature: float = 0.0,
+    choices_method: ChoicesSamplingMethod = token_length_normalized,
 ):
     assert choices is not None
-    return SglSelect(name, choices, temperature)
+    return SglSelect(name, choices, temperature, choices_method)
 
 
 def _role_common(name: str, expr: Optional[SglExpr] = None):
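
The new choices_method parameter on gen and select plumbs through the sampling strategies now exported from sglang/__init__.py. A minimal usage sketch, not taken from the package, assuming an SGLang server is already running at the default http://localhost:30000:

import sglang as sgl
from sglang import greedy_token_selection


@sgl.function
def yes_or_no(s, question):
    s += "Question: " + question + "\nAnswer: "
    # token_length_normalized stays the default; greedy_token_selection and
    # unconditional_likelihood_normalized are the newly exported alternatives.
    s += sgl.gen(
        "answer", choices=["yes", "no"], choices_method=greedy_token_selection
    )


sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = yes_or_no.run(question="Is Paris the capital of France?")
print(state["answer"])
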
sglang/bench_latency.py CHANGED
@@ -1,13 +1,21 @@
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
 
-# Usage (latency test):
+# Usage (latency test)
+## with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+## sweep through multiple data points and store (append) the results in a jsonl file:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
+## do some changes, and store the results under a different run_name:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
+## plot the results in series of lines:
+python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
+
 
 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
-### Reference output:
+## Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -28,17 +36,23 @@ I'm going to the park
 
 import argparse
 import dataclasses
+import itertools
 import logging
 import multiprocessing
+import os
+import sqlite3
 import time
+from typing import Tuple
 
 import numpy as np
+import pandas as pd
 import torch
 import torch.distributed as dist
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -47,25 +61,50 @@ from sglang.srt.utils import suppress_other_loggers
 
 @dataclasses.dataclass
 class BenchArgs:
-    batch_size: int = 1
-    input_len: int = 1024
-    output_len: int = 4
+    run_name: str = "before"
+    batch_size: Tuple[int] = (1,)
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (4,)
+    result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
    cut_len: int = 4
+    # Plotting args
+    graph_sql: str = (
+        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
+    )
+    graph_filename: str = "out.png"
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
-        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
-        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        # graphing
+        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
+        parser.add_argument(
+            "--graph-filename", type=str, default=BenchArgs.graph_filename
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        return cls(**{attr: getattr(args, attr) for attr in attrs})
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
 
 
 def load_model(server_args, tp_rank):
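
Note on the reworked BenchArgs.from_cli_args above: the dataclass defaults double as type information, so the lists produced by the nargs="+" flags are cast back to tuples while scalar fields pass through str unchanged. A standalone sketch of the same pattern (the MiniArgs class below is hypothetical, not part of sglang):

import argparse
import dataclasses
from typing import Tuple


@dataclasses.dataclass
class MiniArgs:  # hypothetical stand-in for BenchArgs
    run_name: str = "before"
    batch_size: Tuple[int] = (1,)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # type(field.default) is str for run_name and tuple for batch_size,
        # so the list produced by nargs="+" is cast back to a tuple.
        attrs = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
        return cls(**{name: cast(getattr(args, name)) for name, cast in attrs})


parser = argparse.ArgumentParser()
parser.add_argument("--run-name", type=str, default=MiniArgs.run_name)
parser.add_argument("--batch-size", type=int, nargs="+", default=MiniArgs.batch_size)
args = parser.parse_args(["--batch-size", "1", "12", "14"])
print(MiniArgs.from_cli_args(args))  # MiniArgs(run_name='before', batch_size=(1, 12, 14))
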
@@ -93,7 +132,7 @@ def load_model(server_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs(bench_args, tokenizer):
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
     prompts = [
         "The capital of France is",
         "The capital of the United Kindom is",
@@ -119,7 +158,9 @@ def prepare_inputs(bench_args, tokenizer):
     return input_ids, reqs
 
 
-def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
     for i in range(len(reqs)):
         req = reqs[i]
         req.input_ids += input_ids[i][bench_args.cut_len :]
@@ -129,8 +170,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     return reqs
 
 
-def prepare_synthetic_inputs(bench_args, tokenizer):
-    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -148,7 +189,7 @@
 
 
 def extend(reqs, model_runner):
-    batch = Batch.init_new(
+    batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -179,7 +220,7 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
 
     if bench_args.cut_len > 0:
         # Prefill
@@ -187,7 +228,9 @@
         rank_print("prefill logits (first half)", next_token_logits)
 
     # Prepare extend inputs
-    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )
 
     # Extend
     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
@@ -205,6 +248,74 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]))
 
 
+@torch.inference_mode()
+def latency_test_run_once(
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
+):
+    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+    if batch_size > max_batch_size:
+        rank_print(
+            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+        )
+        return
+
+    # Clear the pools.
+    model_runner.req_to_token_pool.clear()
+    model_runner.token_to_kv_pool.clear()
+
+    measurement_results = {
+        "run_name": run_name,
+        "batch_size": batch_size,
+        "input_len": input_len,
+        "output_len": output_len,
+    }
+
+    tot_latency = 0
+
+    # Prefill
+    torch.cuda.synchronize()
+    tic = time.time()
+    next_token_ids, _, batch = extend(reqs, model_runner)
+    torch.cuda.synchronize()
+    prefill_latency = time.time() - tic
+    tot_latency += prefill_latency
+    throughput = input_len * batch_size / prefill_latency
+    rank_print(
+        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["prefill_latency"] = prefill_latency
+    measurement_results["prefill_throughput"] = throughput
+
+    # Decode
+    for i in range(output_len):
+        torch.cuda.synchronize()
+        tic = time.time()
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        torch.cuda.synchronize()
+        latency = time.time() - tic
+        tot_latency += latency
+        throughput = batch_size / latency
+        if i < 5:
+            rank_print(
+                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+            )
+    avg_decode_latency = (tot_latency - prefill_latency) / output_len
+    avg_decode_throughput = batch_size / avg_decode_latency
+    rank_print(
+        f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+    )
+    measurement_results["avg_decode_latency"] = avg_decode_latency
+    measurement_results["avg_decode_throughput"] = avg_decode_throughput
+
+    throughput = (input_len + output_len) * batch_size / tot_latency
+    rank_print(
+        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["total_latency"] = tot_latency
+    measurement_results["total_throughput"] = throughput
+    return measurement_results
+
+
 def latency_test(
     server_args,
     bench_args,
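
Each latency_test_run_once call above returns one flat record built from measurement_results; when --result-filename is set, rank 0 appends these records in jsonlines format (next hunk). An illustrative line with fabricated timing values, field names taken from the code above:

{"run_name": "before", "batch_size": 12, "input_len": 256, "output_len": 32, "prefill_latency": 0.21, "prefill_throughput": 14630.48, "avg_decode_latency": 0.011, "avg_decode_throughput": 1090.91, "total_latency": 0.57, "total_throughput": 6063.16}
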
@@ -214,75 +325,119 @@ def latency_test(
 
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    rank_print(
-        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
-    )
 
-    # Prepare inputs
-    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+    # Prepare inputs for warm up
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size[0], bench_args.input_len[0]
+    )
 
-    def clear():
-        model_runner.req_to_token_pool.clear()
-        model_runner.token_to_kv_pool.clear()
+    # Warm up
+    latency_test_run_once(
+        bench_args.run_name,
+        model_runner,
+        rank_print,
+        reqs,
+        bench_args.batch_size[0],
+        bench_args.input_len[0],
+        4,  # shorter decoding to speed up the warmup
+    )
 
-    @torch.inference_mode()
-    def run_once(output_len):
-        # Prefill
-        torch.cuda.synchronize()
-        tot_latency = 0
-        tic = time.time()
-        next_token_ids, _, batch = extend(reqs, model_runner)
-        torch.cuda.synchronize()
-        prefill_latency = time.time() - tic
-        tot_latency += prefill_latency
-        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(
-            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+    # Run the sweep
+    result_list = []
+    for bs, il, ol in itertools.product(
+        bench_args.batch_size, bench_args.input_len, bench_args.output_len
+    ):
+        req = prepare_synthetic_inputs_for_latency_test(bs, il)
+        ret = latency_test_run_once(
+            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
        )
+        if ret is not None:
+            result_list.append(ret)
 
-        # Decode
-        for i in range(output_len):
-            torch.cuda.synchronize()
-            tic = time.time()
-            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-            torch.cuda.synchronize()
-            latency = time.time() - tic
-            tot_latency += latency
-            throughput = bench_args.batch_size / latency
-            if i < 5:
-                rank_print(
-                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-                )
-        avg_decode_latency = (tot_latency - prefill_latency) / output_len
-        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(
-            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
-        )
+    # Write results in jsonlines format on rank 0.
+    if tp_rank == 0 and bench_args.result_filename:
+        import jsonlines
 
-        throughput = (
-            (bench_args.input_len + bench_args.output_len)
-            * bench_args.batch_size
-            / tot_latency
-        )
-        rank_print(
-            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
-        )
+        with jsonlines.open(bench_args.result_filename, "a") as f:
+            f.write_all(result_list)
 
-    # Warm up
-    run_once(4)
-    clear()
 
-    # Run again
-    run_once(bench_args.output_len)
+def plot_latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    assert tp_rank == 0
+
+    # read the jsonl file and put in sqlite
+    df = pd.read_json(bench_args.result_filename, lines=True)
+    conn = sqlite3.connect(":memory:")
+    cur = conn.cursor()
+
+    # get the columns and their types
+    column_names = list(df.iloc[0].keys())
+    type_dict = {
+        str: "TEXT",
+        np.int64: "INTEGER",
+        np.float64: "FLOAT",
+    }
+    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
+
+    # create the table
+    cur.execute(
+        f"""
+        CREATE TABLE IF NOT EXISTS results (
+            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
+        )
+        """
+    )
+    conn.commit()
+
+    # write the results to DB
+    df.to_sql("results", conn, if_exists="replace", index=False)
+    conn.commit()
+
+    # read it back using sql
+    df = pd.read_sql_query(bench_args.graph_sql, conn)
+    conn.close()
+
+    # plot it and save to a file
+    import matplotlib.pyplot as plt
+
+    assert (
+        len(df.columns) == 3
+    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
+    for label in df[df.columns[0]].unique():
+        q = f"{df.columns[0]}=='{label}'"
+        series = df.query(q)
+        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
+    plt.xlabel(df.columns[1])
+    plt.ylabel(df.columns[2])
+    plt.legend()
+    plt.savefig(bench_args.graph_filename, dpi=300)
+
+    # if in kitty, just dump it to the terminal
+    if os.environ["TERM"] == "xterm-kitty":
+        os.system(
+            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
+        )
 
 
 def main(server_args, bench_args):
-    print(bench_args)
 
-    if bench_args.correctness_test:
-        work_func = correctness_test
+    if server_args.model_path:
+        if bench_args.correctness_test:
+            work_func = correctness_test
+        else:
+            work_func = latency_test
+    elif os.path.isfile(bench_args.result_filename):
+        assert bench_args.graph_filename, "please provide a filename for the graph"
+        work_func = plot_latency_test
     else:
-        work_func = latency_test
+        raise ValueError(
+            "Provide --model-path for running the tests or "
+            "provide --result-filename for plotting the results"
+        )
 
     if server_args.tp_size == 1:
         work_func(server_args, bench_args, 0)
@@ -310,6 +465,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code need to be updated"
+    parser._actions[1].required = False
    args = parser.parse_args()
 
     server_args = ServerArgs.from_cli_args(args)
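
Besides the built-in --graph-sql/matplotlib path in plot_latency_test, the appended jsonl file can be inspected directly. A small sketch (not part of the package) that compares the "before" and "after" sweeps recorded by the commands in the module docstring, assuming they were written to out.jsonl:

import pandas as pd

df = pd.read_json("out.jsonl", lines=True)  # one record per latency_test_run_once call
print(
    df.pivot_table(
        index=["batch_size", "input_len", "output_len"],
        columns="run_name",
        values="total_throughput",
    )
)
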
sglang/check_env.py CHANGED
@@ -13,6 +13,8 @@ import torch
 PACKAGE_LIST = [
     "sglang",
     "flashinfer",
+    "triton",
+    "transformers",
     "requests",
     "tqdm",
     "numpy",
@@ -72,10 +74,26 @@ def _get_gpu_info():
     Get information about available GPUs.
     """
     devices = defaultdict(list)
+    capabilities = defaultdict(list)
     for k in range(torch.cuda.device_count()):
         devices[torch.cuda.get_device_name(k)].append(str(k))
+        capability = torch.cuda.get_device_capability(k)
+        capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
 
-    return {f"GPU {','.join(device_ids)}": name for name, device_ids in devices.items()}
+    gpu_info = {}
+    for name, device_ids in devices.items():
+        gpu_info[f"GPU {','.join(device_ids)}"] = name
+
+    if len(capabilities) == 1:
+        # All GPUs have the same compute capability
+        cap, gpu_ids = list(capabilities.items())[0]
+        gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+    else:
+        # GPUs have different compute capabilities
+        for cap, gpu_ids in capabilities.items():
+            gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+
+    return gpu_info
 
 
 def _get_cuda_version_info():
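
With the capability tracking added above, _get_gpu_info collapses identical GPUs into a single "Compute Capability" entry and lists mixed capabilities separately. A hypothetical return value for a host with two identical GPUs (device name and values are illustrative only):

{
    "GPU 0,1": "NVIDIA A100-SXM4-80GB",
    "GPU 0,1 Compute Capability": "8.0",
}
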
@@ -117,6 +135,7 @@ def _get_cuda_driver_version():
     """
     Get CUDA driver version.
     """
+    versions = set()
     try:
         output = subprocess.check_output(
             [
@@ -125,7 +144,11 @@
                 "--format=csv,noheader,nounits",
             ]
         )
-        return {"CUDA Driver Version": output.decode().strip()}
+        versions = set(output.decode().strip().split("\n"))
+        if len(versions) == 1:
+            return {"CUDA Driver Version": versions.pop()}
+        else:
+            return {"CUDA Driver Versions": ", ".join(sorted(versions))}
     except subprocess.SubprocessError:
         return {"CUDA Driver Version": "Not Available"}
 
sglang/global_config.py CHANGED
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
-        self.new_token_ratio_recovery = 0.05
 
         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
         # This can improve the speed for large batch sizes during prefill.
sglang/lang/backend/base_backend.py CHANGED
@@ -1,6 +1,7 @@
 from typing import Callable, List, Optional, Union
 
 from sglang.lang.chat_template import get_chat_template
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -64,7 +65,8 @@ class BaseBackend:
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: Optional[ChoicesSamplingMethod] = None,
+    ) -> ChoicesDecision:
         raise NotImplementedError()
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
sglang/lang/backend/openai.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
 
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -296,7 +297,9 @@ class OpenAI(BaseBackend):
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
+        """Note: `choices_method` is not used by the OpenAI backend."""
         if self.is_chat_model:
             raise NotImplementedError(
                 "select/choices is not supported for chat models. "
@@ -354,8 +357,10 @@
 
             prompt_tokens.append(ret_token)
 
-        decision = choices[np.argmax(scores)]
-        return decision, scores, None, None
+        return ChoicesDecision(
+            decision=choices[np.argmax(scores)],
+            meta_info={"scores": scores},
+        )
 
 
 def openai_completion(