sglang 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +145 -36
  4. sglang/check_env.py +24 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -29
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/layers/logits_processor.py +1 -1
  13. sglang/srt/layers/radix_attention.py +2 -5
  14. sglang/srt/managers/schedule_batch.py +95 -324
  15. sglang/srt/managers/tokenizer_manager.py +6 -3
  16. sglang/srt/managers/tp_worker.py +20 -22
  17. sglang/srt/mem_cache/memory_pool.py +9 -14
  18. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  19. sglang/srt/model_executor/forward_batch_info.py +256 -0
  20. sglang/srt/model_executor/model_runner.py +6 -10
  21. sglang/srt/models/chatglm.py +1 -1
  22. sglang/srt/models/commandr.py +1 -1
  23. sglang/srt/models/dbrx.py +1 -1
  24. sglang/srt/models/deepseek.py +1 -1
  25. sglang/srt/models/deepseek_v2.py +1 -1
  26. sglang/srt/models/gemma.py +1 -1
  27. sglang/srt/models/gemma2.py +1 -1
  28. sglang/srt/models/gpt_bigcode.py +1 -1
  29. sglang/srt/models/grok.py +1 -1
  30. sglang/srt/models/internlm2.py +1 -1
  31. sglang/srt/models/llama2.py +1 -1
  32. sglang/srt/models/llama_classification.py +1 -1
  33. sglang/srt/models/llava.py +1 -2
  34. sglang/srt/models/llavavid.py +1 -2
  35. sglang/srt/models/minicpm.py +1 -1
  36. sglang/srt/models/mixtral.py +1 -1
  37. sglang/srt/models/mixtral_quant.py +1 -1
  38. sglang/srt/models/qwen.py +1 -1
  39. sglang/srt/models/qwen2.py +1 -1
  40. sglang/srt/models/qwen2_moe.py +1 -1
  41. sglang/srt/models/stablelm.py +1 -1
  42. sglang/srt/openai_api/adapter.py +34 -12
  43. sglang/srt/openai_api/protocol.py +6 -0
  44. sglang/srt/server.py +24 -6
  45. sglang/srt/server_args.py +4 -0
  46. sglang/test/test_utils.py +1 -1
  47. sglang/version.py +1 -1
  48. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/METADATA +34 -24
  49. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/RECORD +52 -50
  50. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  51. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  52. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -22,6 +22,11 @@ from sglang.api import (
     user_end,
     video,
 )
+from sglang.lang.choices import (
+    greedy_token_selection,
+    token_length_normalized,
+    unconditional_likelihood_normalized,
+)
 
 # SGLang DSL APIs
 __all__ = [
@@ -45,6 +50,9 @@ __all__ = [
     "user_begin",
     "user_end",
     "video",
+    "greedy_token_selection",
+    "token_length_normalized",
+    "unconditional_likelihood_normalized",
 ]
 
 # Global Configurations
sglang/api.py CHANGED
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
 
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
+from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
 from sglang.lang.ir import (
     SglExpr,
     SglExprList,
@@ -73,12 +74,18 @@ def gen(
     return_text_in_logprobs: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
+    choices_method: Optional[ChoicesSamplingMethod] = None,
     regex: Optional[str] = None,
 ):
     """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
 
     if choices:
-        return SglSelect(name, choices, 0.0 if temperature is None else temperature)
+        return SglSelect(
+            name,
+            choices,
+            0.0 if temperature is None else temperature,
+            token_length_normalized if choices_method is None else choices_method,
+        )
 
     # check regex is valid
     if regex is not None:
@@ -186,9 +193,10 @@ def select(
     name: Optional[str] = None,
     choices: Optional[List[str]] = None,
     temperature: float = 0.0,
+    choices_method: ChoicesSamplingMethod = token_length_normalized,
 ):
     assert choices is not None
-    return SglSelect(name, choices, temperature)
+    return SglSelect(name, choices, temperature, choices_method)
 
 
 def _role_common(name: str, expr: Optional[SglExpr] = None):
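
The new `choices_method` argument controls how `gen(choices=...)` and `select` score the candidate continuations; the three methods added in `sglang/lang/choices.py` are re-exported at the package top level (see the `__init__.py` diff above). A minimal usage sketch follows; the backend URL and the prompt are illustrative and not part of this diff:

import sglang as sgl
from sglang import greedy_token_selection  # token_length_normalized stays the default

@sgl.function
def pick_tool(s, question):
    s += "Question: " + question + "\n"
    s += "Tool: " + sgl.gen(
        "tool",
        choices=["calculator", "web search", "none"],
        choices_method=greedy_token_selection,
    )

# Assumes a local SGLang server is already running at this address.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = pick_tool.run(question="What is 3 * 7?")
print(state["tool"])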
sglang/bench_latency.py CHANGED
@@ -1,13 +1,21 @@
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
 
-# Usage (latency test) with dummy weights:
+# Usage (latency test)
+## with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+## sweep through multiple data points and store (append) the results in a jsonl file:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
+## do some changes, and store the results under a different run_name:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
+## plot the results in series of lines:
+python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
+
 
 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
-### Reference output (of the correctness test above, can be gpu dependent):
+## Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
 [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
 [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -28,19 +36,23 @@ I'm going to the park
 
 import argparse
 import dataclasses
+import itertools
 import logging
 import multiprocessing
+import os
+import sqlite3
 import time
 from typing import Tuple
 
-import jsonlines
 import numpy as np
+import pandas as pd
 import torch
 import torch.distributed as dist
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -49,26 +61,42 @@ from sglang.srt.utils import suppress_other_loggers
 
 @dataclasses.dataclass
 class BenchArgs:
+    run_name: str = "before"
     batch_size: Tuple[int] = (1,)
-    input_len: int = 1024
-    output_len: int = 4
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (4,)
     result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    # Plotting args
+    graph_sql: str = (
+        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
+    )
+    graph_filename: str = "out.png"
 
     @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
         parser.add_argument(
             "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
         )
-        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
-        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        # graphing
+        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
+        parser.add_argument(
+            "--graph-filename", type=str, default=BenchArgs.graph_filename
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -161,7 +189,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
 
 
 def extend(reqs, model_runner):
-    batch = Batch.init_new(
+    batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -222,15 +250,21 @@ def correctness_test(
 
 @torch.inference_mode()
 def latency_test_run_once(
-    model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
 ):
+    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+    if batch_size > max_batch_size:
+        rank_print(
+            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+        )
+        return
 
     # Clear the pools.
     model_runner.req_to_token_pool.clear()
     model_runner.token_to_kv_pool.clear()
 
     measurement_results = {
-        "run_name": "before",
+        "run_name": run_name,
         "batch_size": batch_size,
         "input_len": input_len,
         "output_len": output_len,
@@ -291,49 +325,119 @@ def latency_test(
 
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    rank_print(
-        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
-    )
 
-    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
-    bench_args.batch_size = bench_args.batch_size[0]
-
-    # Prepare inputs
+    # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
-        bench_args.batch_size, bench_args.input_len
+        bench_args.batch_size[0], bench_args.input_len[0]
     )
 
     # Warm up
     latency_test_run_once(
-        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
+        bench_args.run_name,
+        model_runner,
+        rank_print,
+        reqs,
+        bench_args.batch_size[0],
+        bench_args.input_len[0],
+        4,  # shorter decoding to speed up the warmup
     )
 
-    # Run again
+    # Run the sweep
     result_list = []
-    result_list.append(
-        latency_test_run_once(
-            model_runner,
-            rank_print,
-            reqs,
-            bench_args.batch_size,
-            bench_args.input_len,
-            bench_args.output_len,
+    for bs, il, ol in itertools.product(
+        bench_args.batch_size, bench_args.input_len, bench_args.output_len
+    ):
+        req = prepare_synthetic_inputs_for_latency_test(bs, il)
+        ret = latency_test_run_once(
+            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
         )
-    )
+        if ret is not None:
+            result_list.append(ret)
+
+    # Write results in jsonlines format on rank 0.
+    if tp_rank == 0 and bench_args.result_filename:
+        import jsonlines
 
-    # Write results in jsonlines format.
-    if bench_args.result_filename:
         with jsonlines.open(bench_args.result_filename, "a") as f:
             f.write_all(result_list)
 
 
+def plot_latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    assert tp_rank == 0
+
+    # read the jsonl file and put in sqlite
+    df = pd.read_json(bench_args.result_filename, lines=True)
+    conn = sqlite3.connect(":memory:")
+    cur = conn.cursor()
+
+    # get the columns and their types
+    column_names = list(df.iloc[0].keys())
+    type_dict = {
+        str: "TEXT",
+        np.int64: "INTEGER",
+        np.float64: "FLOAT",
+    }
+    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
+
+    # create the table
+    cur.execute(
+        f"""
+        CREATE TABLE IF NOT EXISTS results (
+            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
+        )
+        """
+    )
+    conn.commit()
+
+    # write the results to DB
+    df.to_sql("results", conn, if_exists="replace", index=False)
+    conn.commit()
+
+    # read it back using sql
+    df = pd.read_sql_query(bench_args.graph_sql, conn)
+    conn.close()
+
+    # plot it and save to a file
+    import matplotlib.pyplot as plt
+
+    assert (
+        len(df.columns) == 3
+    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
+    for label in df[df.columns[0]].unique():
+        q = f"{df.columns[0]}=='{label}'"
+        series = df.query(q)
+        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
+    plt.xlabel(df.columns[1])
+    plt.ylabel(df.columns[2])
+    plt.legend()
+    plt.savefig(bench_args.graph_filename, dpi=300)
+
+    # if in kitty, just dump it to the terminal
+    if os.environ["TERM"] == "xterm-kitty":
+        os.system(
+            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
+        )
+
+
 def main(server_args, bench_args):
-    print(bench_args)
 
-    if bench_args.correctness_test:
-        work_func = correctness_test
+    if server_args.model_path:
+        if bench_args.correctness_test:
+            work_func = correctness_test
+        else:
+            work_func = latency_test
+    elif os.path.isfile(bench_args.result_filename):
+        assert bench_args.graph_filename, "please provide a filename for the graph"
+        work_func = plot_latency_test
     else:
-        work_func = latency_test
+        raise ValueError(
+            "Provide --model-path for running the tests or "
+            "provide --result-filename for plotting the results"
+        )
 
     if server_args.tp_size == 1:
         work_func(server_args, bench_args, 0)
@@ -361,6 +465,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code need to be updated"
+    parser._actions[1].required = False
     args = parser.parse_args()
 
     server_args = ServerArgs.from_cli_args(args)
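
As the assertion in `plot_latency_test` above indicates, the query passed via `--graph-sql` must return exactly three columns in `<series, x, y>` order; each jsonl record appended by the sweep becomes one row of the in-memory `results` table. A sketch of plotting a finished sweep (file names are illustrative):

python -m sglang.bench_latency --result-filename out.jsonl \
    --graph-sql="select run_name, batch_size, prefill_throughput from results" \
    --graph-filename compare.png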
sglang/check_env.py CHANGED
@@ -14,6 +14,7 @@ PACKAGE_LIST = [
     "sglang",
     "flashinfer",
     "triton",
+    "transformers",
     "requests",
     "tqdm",
     "numpy",
@@ -73,10 +74,26 @@ def _get_gpu_info():
     Get information about available GPUs.
     """
     devices = defaultdict(list)
+    capabilities = defaultdict(list)
     for k in range(torch.cuda.device_count()):
         devices[torch.cuda.get_device_name(k)].append(str(k))
+        capability = torch.cuda.get_device_capability(k)
+        capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
 
-    return {f"GPU {','.join(device_ids)}": name for name, device_ids in devices.items()}
+    gpu_info = {}
+    for name, device_ids in devices.items():
+        gpu_info[f"GPU {','.join(device_ids)}"] = name
+
+    if len(capabilities) == 1:
+        # All GPUs have the same compute capability
+        cap, gpu_ids = list(capabilities.items())[0]
+        gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+    else:
+        # GPUs have different compute capabilities
+        for cap, gpu_ids in capabilities.items():
+            gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
+
+    return gpu_info
 
 
 def _get_cuda_version_info():
@@ -118,6 +135,7 @@ def _get_cuda_driver_version():
     """
     Get CUDA driver version.
     """
+    versions = set()
     try:
         output = subprocess.check_output(
             [
@@ -126,7 +144,11 @@ def _get_cuda_driver_version():
                 "--format=csv,noheader,nounits",
             ]
         )
-        return {"CUDA Driver Version": output.decode().strip()}
+        versions = set(output.decode().strip().split("\n"))
+        if len(versions) == 1:
+            return {"CUDA Driver Version": versions.pop()}
+        else:
+            return {"CUDA Driver Versions": ", ".join(sorted(versions))}
     except subprocess.SubprocessError:
         return {"CUDA Driver Version": "Not Available"}
 
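With these changes, `python -m sglang.check_env` also reports the CUDA compute capability, grouping GPUs that share one, and collapses identical driver versions across GPUs. Purely for illustration, the relevant entries take a shape like the following (device names and values are hypothetical):

GPU 0,1,2,3: NVIDIA A100-SXM4-80GB
GPU 0,1,2,3 Compute Capability: 8.0
CUDA Driver Version: 535.104.05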
sglang/global_config.py CHANGED
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
-        self.new_token_ratio_recovery = 0.05
 
         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
        # This can improve the speed for large batch sizes during prefill.
sglang/lang/backend/base_backend.py CHANGED
@@ -1,6 +1,7 @@
 from typing import Callable, List, Optional, Union
 
 from sglang.lang.chat_template import get_chat_template
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -64,7 +65,8 @@ class BaseBackend:
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: Optional[ChoicesSamplingMethod] = None,
+    ) -> ChoicesDecision:
         raise NotImplementedError()
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
sglang/lang/backend/openai.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
 
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
+from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
@@ -296,7 +297,9 @@ class OpenAI(BaseBackend):
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
+        """Note: `choices_method` is not used by the OpenAI backend."""
         if self.is_chat_model:
             raise NotImplementedError(
                 "select/choices is not supported for chat models. "
@@ -354,8 +357,10 @@ class OpenAI(BaseBackend):
 
             prompt_tokens.append(ret_token)
 
-        decision = choices[np.argmax(scores)]
-        return decision, scores, None, None
+        return ChoicesDecision(
+            decision=choices[np.argmax(scores)],
+            meta_info={"scores": scores},
+        )
 
 
 def openai_completion(
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -1,17 +1,21 @@
 import json
 from typing import List, Optional
 
-import numpy as np
-
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.choices import (
+    ChoicesDecision,
+    ChoicesSamplingMethod,
+    token_length_normalized,
+)
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
+
     def __init__(
         self,
         base_url: str,
@@ -43,7 +47,7 @@ class RuntimeEndpoint(BaseBackend):
     def flush_cache(self):
         res = http_request(
             self.base_url + "/flush_cache",
-            auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -51,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
     def get_server_args(self):
         res = http_request(
             self.base_url + "/get_server_args",
-            auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -208,20 +212,14 @@ class RuntimeEndpoint(BaseBackend):
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
         assert temperature <= 1e-5
 
         # Cache common prefix
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        prompt_len = res.json()["meta_info"]["prompt_tokens"]
+        obj = self._generate_http_request(s, data)
+        prompt_len = obj["meta_info"]["prompt_tokens"]
 
         # Compute logprob
         data = {
@@ -230,27 +228,35 @@ class RuntimeEndpoint(BaseBackend):
             "return_logprob": True,
             "logprob_start_len": max(prompt_len - 2, 0),
         }
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        obj = res.json()
+        obj = self._generate_http_request(s, data)
+
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-        decision = choices[np.argmax(normalized_prompt_logprobs)]
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
 
-        return (
-            decision,
-            normalized_prompt_logprobs,
-            input_token_logprobs,
-            output_token_logprobs,
+        # Compute unconditional logprobs if required
+        if choices_method.requires_unconditional_logprobs:
+            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
+            data = {
+                "input_ids": input_ids,
+                "sampling_params": {"max_new_tokens": 0},
+                "return_logprob": True,
+            }
+            obj = self._generate_http_request(s, data)
+            unconditional_token_logprobs = [
+                r["meta_info"]["input_token_logprobs"] for r in obj
+            ]
+        else:
+            unconditional_token_logprobs = None
+
+        return choices_method(
+            choices=choices,
+            normalized_prompt_logprobs=normalized_prompt_logprobs,
+            input_token_logprobs=input_token_logprobs,
+            output_token_logprobs=output_token_logprobs,
+            unconditional_token_logprobs=unconditional_token_logprobs,
         )
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
@@ -262,6 +268,17 @@ class RuntimeEndpoint(BaseBackend):
         )
         self._assert_success(res)
 
+    def _generate_http_request(self, s: StreamExecutor, data):
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        return res.json()
+
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
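
`sglang/lang/choices.py` itself (+164 lines) is not expanded in this diff, so only its call surface is visible above: a `ChoicesSamplingMethod` is invoked with keyword arguments (the choices plus the logprob lists) and returns a `ChoicesDecision` carrying `decision` and `meta_info`, and it exposes a `requires_unconditional_logprobs` flag that gates the extra `/generate` call with raw `input_ids`. The sketch below re-implements the default length-normalized scoring against that inferred interface; it assumes `ChoicesSamplingMethod` can be subclassed this way, which the diff does not show:

import numpy as np

from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod


class MyTokenLengthNormalized(ChoicesSamplingMethod):
    """Illustrative custom method: argmax of the length-normalized prompt logprobs."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        # No extra unconditional /generate call is needed for this scoring.
        return False

    def __call__(
        self,
        *,
        choices,
        normalized_prompt_logprobs,
        input_token_logprobs,
        output_token_logprobs,
        unconditional_token_logprobs=None,
    ) -> ChoicesDecision:
        best = int(np.argmax(normalized_prompt_logprobs))
        return ChoicesDecision(
            decision=choices[best],
            meta_info={"normalized_prompt_logprobs": normalized_prompt_logprobs},
        )


# Usage sketch: pass an instance wherever choices_method is accepted, e.g.
# sgl.select("answer", choices=["yes", "no"], choices_method=MyTokenLengthNormalized())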