sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,4 +1,61 @@
-__version__ = "0.1.14"
+__version__ = "0.1.21"
 
-from sglang.api import *
+# SGL API Components
+from sglang.api import (
+    Runtime,
+    assistant,
+    assistant_begin,
+    assistant_end,
+    flush_cache,
+    function,
+    gen,
+    gen_int,
+    gen_string,
+    get_server_args,
+    image,
+    select,
+    set_default_backend,
+    system,
+    user,
+    user_begin,
+    user_end,
+    video,
+)
+
+# SGL Backends
+from sglang.backend.anthropic import Anthropic
+from sglang.backend.litellm import LiteLLM
+from sglang.backend.openai import OpenAI
+from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.backend.vertexai import VertexAI
+
+# Global Configurations
 from sglang.global_config import global_config
+
+# public APIs management
+__all__ = [
+    "global_config",
+    "Anthropic",
+    "LiteLLM",
+    "OpenAI",
+    "RuntimeEndpoint",
+    "VertexAI",
+    "function",
+    "Runtime",
+    "set_default_backend",
+    "flush_cache",
+    "get_server_args",
+    "gen",
+    "gen_int",
+    "gen_string",
+    "image",
+    "video",
+    "select",
+    "system",
+    "user",
+    "assistant",
+    "user_begin",
+    "user_end",
+    "assistant_begin",
+    "assistant_end",
+]
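The wildcard import is replaced by explicit re-exports and a pinned __all__, so the package now exposes a fixed public surface. A minimal sketch of this top-level API in use (the endpoint URL and prompt are placeholders; assumes a locally running sglang server):

import sglang as sgl

# Define a program with the re-exported decorator and primitives.
@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

# Point the default backend at a running server (URL is a placeholder).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

state = qa.run(question="What is the capital of France?")
print(state["answer"])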
sglang/api.py CHANGED
@@ -1,13 +1,10 @@
-"""Public API"""
+"""Public APIs of the language."""
 
+import os
 import re
 from typing import Callable, List, Optional, Union
 
-from sglang.backend.anthropic import Anthropic
 from sglang.backend.base_backend import BaseBackend
-from sglang.backend.openai import OpenAI
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.backend.vertexai import VertexAI
 from sglang.global_config import global_config
 from sglang.lang.ir import (
     SglExpr,
@@ -18,23 +15,25 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglVideo,
 )
 
 
 def function(
-    func: Optional[Callable] = None, api_num_spec_tokens: Optional[int] = None
+    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
 ):
     if func:
-        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
+        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
 
     def decorator(func):
-        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
+        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
 
     return decorator
 
 
 def Runtime(*args, **kwargs):
     # Avoid importing unnecessary dependency
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     from sglang.srt.server import Runtime
 
     return Runtime(*args, **kwargs)
@@ -44,14 +43,14 @@ def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
 
-def flush_cache(backend: BaseBackend = None):
+def flush_cache(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return False
     return backend.flush_cache()
 
 
-def get_server_args(backend: BaseBackend = None):
+def get_server_args(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -68,10 +67,16 @@ def gen(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
     regex: Optional[str] = None,
 ):
+    """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
+
     if choices:
         return SglSelect(name, choices, 0.0 if temperature is None else temperature)
@@ -92,6 +97,10 @@ def gen(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         dtype,
         regex,
     )
@@ -107,6 +116,10 @@ def gen_int(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -118,6 +131,10 @@ def gen_int(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         int,
         None,
     )
@@ -133,6 +150,10 @@ def gen_string(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -144,6 +165,10 @@ def gen_string(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         str,
         None,
     )
@@ -153,9 +178,13 @@ def image(expr: SglExpr):
     return SglImage(expr)
 
 
+def video(path: str, num_frames: int):
+    return SglVideo(path, num_frames)
+
+
 def select(
     name: Optional[str] = None,
-    choices: List[str] = None,
+    choices: Optional[List[str]] = None,
     temperature: float = 0.0,
 ):
     assert choices is not None
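gen() and its typed variants now accept logprob options that are forwarded to the runtime. A small sketch of requesting top logprobs (the server URL and prefix are placeholders; assumes a running sglang server):

import sglang as sgl

@sgl.function
def complete(s, prefix):
    s += prefix
    s += sgl.gen(
        "completion",
        max_tokens=16,
        return_logprob=True,           # per-token logprobs of the generated text
        top_logprobs_num=5,            # also return the 5 most likely alternatives
        return_text_in_logprobs=True,  # decode token ids to text in the output
    )

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = complete.run(prefix="The capital of France is")
print(state.get_meta_info("completion"))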
sglang/backend/anthropic.py CHANGED
@@ -1,6 +1,7 @@
 from typing import List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
@@ -13,7 +14,7 @@ except ImportError as e:
 
 
 class Anthropic(BaseBackend):
-    def __init__(self, model_name):
+    def __init__(self, model_name, *args, **kwargs):
         super().__init__()
 
         if isinstance(anthropic, Exception):
@@ -21,6 +22,7 @@ class Anthropic(BaseBackend):
 
         self.model_name = model_name
         self.chat_template = get_chat_template("claude")
+        self.client = anthropic.Anthropic(*args, **kwargs)
 
     def get_chat_template(self):
         return self.chat_template
@@ -35,8 +37,14 @@ class Anthropic(BaseBackend):
         else:
             messages = [{"role": "user", "content": s.text_}]
 
-        ret = anthropic.Anthropic().messages.create(
+        if messages and messages[0]["role"] == "system":
+            system = messages.pop(0)["content"]
+        else:
+            system = ""
+
+        ret = self.client.messages.create(
             model=self.model_name,
+            system=system,
             messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         )
@@ -54,8 +62,14 @@ class Anthropic(BaseBackend):
         else:
             messages = [{"role": "user", "content": s.text_}]
 
-        with anthropic.Anthropic().messages.stream(
+        if messages and messages[0]["role"] == "system":
+            system = messages.pop(0)["content"]
+        else:
+            system = ""
+
+        with self.client.messages.stream(
             model=self.model_name,
+            system=system,
             messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         ) as stream:
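With the system-message handling above, a leading sgl.system() turn now maps onto Anthropic's dedicated `system` field instead of being sent as a chat message. A hedged sketch (the model name is a placeholder):

import sglang as sgl

@sgl.function
def advise(s, topic):
    s += sgl.system("You are a terse assistant.")   # popped into the `system` field
    s += sgl.user(f"Give one tip about {topic}.")
    s += sgl.assistant(sgl.gen("tip", max_tokens=48))

sgl.set_default_backend(sgl.Anthropic("claude-3-haiku-20240307"))
state = advise.run(topic="debugging")
print(state["tip"])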
sglang/backend/litellm.py ADDED
@@ -0,0 +1,90 @@
+from typing import Mapping, Optional
+
+from sglang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+
+try:
+    import litellm
+except ImportError as e:
+    litellm = e
+litellm.num_retries = 1
+
+
+class LiteLLM(BaseBackend):
+    def __init__(
+        self,
+        model_name,
+        chat_template=None,
+        api_key=None,
+        organization: Optional[str] = None,
+        base_url: Optional[str] = None,
+        timeout: Optional[float] = 600,
+        max_retries: Optional[int] = litellm.num_retries,
+        default_headers: Optional[Mapping[str, str]] = None,
+    ):
+        super().__init__()
+
+        if isinstance(litellm, Exception):
+            raise litellm
+
+        self.model_name = model_name
+
+        self.chat_template = chat_template or get_chat_template_by_model_path(
+            model_name
+        )
+
+        self.client_params = {
+            "api_key": api_key,
+            "organization": organization,
+            "base_url": base_url,
+            "timeout": timeout,
+            "max_retries": max_retries,
+            "default_headers": default_headers,
+        }
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            **self.client_params,
+            **sampling_params.to_anthropic_kwargs(),
+        )
+        comp = ret.choices[0].message.content
+
+        return comp, {}
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            stream=True,
+            **self.client_params,
+            **sampling_params.to_litellm_kwargs(),
+        )
+        for chunk in ret:
+            text = chunk.choices[0].delta.content
+            if text is not None:
+                yield text, {}
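A short sketch of driving a program through the new LiteLLM backend (the model name is a placeholder; API keys are resolved by litellm from the environment):

import sglang as sgl

@sgl.function
def greet(s):
    s += sgl.user("Say hello in one sentence.")
    s += sgl.assistant(sgl.gen("reply", max_tokens=32))

# Any model string that litellm can route works here (placeholder shown).
sgl.set_default_backend(sgl.LiteLLM("gpt-3.5-turbo"))
state = greet.run()
print(state["reply"])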
sglang/backend/openai.py CHANGED
@@ -1,8 +1,11 @@
+import dataclasses
 import logging
 import time
+import warnings
 from typing import Callable, List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
@@ -40,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [
 ]
 
 
+@dataclasses.dataclass
+class TokenUsage:
+    prompt_tokens: int
+    completion_tokens: int
+
+    def reset(self):
+        self.prompt_tokens = self.completion_tokens = 0
+
+
 class OpenAI(BaseBackend):
     def __init__(
         self,
@@ -79,40 +91,92 @@ class OpenAI(BaseBackend):
         else:
             self.is_chat_model = True
 
-        self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
+        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
+
+        # Usage
+        self.token_usage = TokenUsage(0, 0)
+
+        # API speculative execution
+        # TODO(ying): This does not support multi-threading (run_batch)
+        self.spec_kwargs = {}
+        self.spec_format = []
+        self.spec_max_num_tries = 3
 
     def get_chat_template(self):
         return self.chat_template
 
+    def _prepare_spec_execution(
+        self,
+        sampling_params: SglSamplingParams,
+        num_api_spec_tokens: int,
+        spec_var_name: str,
+    ):
+        if "max_tokens" not in self.spec_kwargs:
+            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
+        else:
+            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
+
+        params = sampling_params.to_openai_kwargs()
+        for key, value in params.items():
+            if key in ["stop"]:
+                continue
+            if key in ["max_tokens"]:
+                warnings.warn(
+                    "The parameter max_tokens will be overwritten by speculated number of tokens."
+                )
+                continue
+            if key not in self.spec_kwargs:
+                self.spec_kwargs[key] = value
+            else:
+                assert (
+                    value == self.spec_kwargs[key]
+                ), "sampling parameters should be consistent if turn on api speculative execution."
+        self.spec_format.append(
+            {"text": "", "stop": params["stop"], "name": spec_var_name}
+        )
+        return "", {}
+
     def generate(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        spec_var_name: str = None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
-                    raise RuntimeError(
-                        "This use case is not supported. "
-                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+                if s.num_api_spec_tokens is None:
+                    if not s.text_.endswith(self.chat_prefix):
+                        raise RuntimeError(
+                            "This use case is not supported if api speculative execution is off. "
+                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
+                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
+                        )
+                    prompt = s.messages_
+                else:
+                    return self._prepare_spec_execution(
+                        sampling_params, s.num_api_spec_tokens, spec_var_name
                     )
-                prompt = s.messages_
             else:
                 prompt = s.text_
 
             kwargs = sampling_params.to_openai_kwargs()
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=prompt,
                 **kwargs,
             )
         elif sampling_params.dtype in [str, "str", "string"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_ + '"',
@@ -121,10 +185,14 @@ class OpenAI(BaseBackend):
             )
             comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_,
@@ -137,6 +205,63 @@ class OpenAI(BaseBackend):
 
         return comp, {}
 
+    def spec_fill(self, value: str):
+        assert self.is_chat_model
+        self.spec_format.append({"text": value, "stop": None, "name": None})
+
+    def spec_pattern_match(self, comp):
+        for i, term in enumerate(self.spec_format):
+            text = term["text"]
+            if text != "":
+                if comp.startswith(text):
+                    comp = comp[len(text) :]
+                else:
+                    return False
+            else:
+                pos = comp.find(term["stop"])
+                if pos != -1:
+                    term["text"] = comp[:pos]
+                    comp = comp[pos:]
+                else:
+                    if i == len(self.spec_format) - 1:
+                        term["text"] = comp
+                    else:
+                        return False
+        return True
+
+    def role_end_generate(
+        self,
+        s: StreamExecutor,
+    ):
+        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+            return
+
+        comp = ""
+        if not all(x["name"] is None for x in self.spec_format):
+            # TODO(ying): throw errors or warnings
+            for i in range(self.spec_max_num_tries):
+                comp = openai_completion(
+                    client=self.client,
+                    token_usage=self.token_usage,
+                    is_chat=self.is_chat_model,
+                    model=self.model_name,
+                    prompt=s.messages_,
+                    **self.spec_kwargs,
+                )
+                if self.spec_pattern_match(comp):
+                    break
+
+        for term in self.spec_format:
+            s.text_ += term["text"]
+            name = term["name"]
+            if name is not None:
+                s.variables[name] = term["text"]
+                s.meta_info[name] = {}
+                s.variable_event[name].set()
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+
     def generate_stream(
         self,
         s: StreamExecutor,
@@ -144,7 +269,7 @@ class OpenAI(BaseBackend):
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
+                if not s.text_.endswith(self.chat_prefix):
                     raise RuntimeError(
                         "This use case is not supported. "
                         "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
@@ -156,6 +281,7 @@ class OpenAI(BaseBackend):
         kwargs = sampling_params.to_openai_kwargs()
         generator = openai_completion_stream(
             client=self.client,
+            token_usage=self.token_usage,
             is_chat=self.is_chat_model,
             model=self.model_name,
             prompt=prompt,
@@ -201,6 +327,8 @@ class OpenAI(BaseBackend):
             )
             ret_str = ret.choices[0].text
             ret_token = self.tokenizer.encode(ret_str)[0]
+            self.token_usage.prompt_tokens += ret.usage.prompt_tokens
+            self.token_usage.completion_tokens = ret.usage.completion_tokens
 
             # TODO:
             # 1. return logits as the scores
@@ -227,10 +355,12 @@ class OpenAI(BaseBackend):
             prompt_tokens.append(ret_token)
 
         decision = choices[np.argmax(scores)]
-        return decision, scores, scores
+        return decision, scores, None, None
 
 
-def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
@@ -244,6 +374,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
                 comp = [c.text for c in ret.choices]
             else:
                 comp = ret.choices[0].text
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
@@ -257,16 +390,23 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
     return comp
 
 
-def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion_stream(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
                 if "stop" in kwargs and kwargs["stop"] is None:
                     kwargs.pop("stop")
                 generator = client.chat.completions.create(
-                    messages=prompt, stream=True, **kwargs
+                    messages=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     try:
                         content = ret.choices[0].delta.content
                     except IndexError:
@@ -274,11 +414,19 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
                     yield content or "", {}
             else:
                 generator = client.completions.create(
-                    prompt=prompt, stream=True, **kwargs
+                    prompt=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                    content = ret.choices[0].text
                     yield content or "", {}
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")