sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.16"
+__version__ = "0.1.18"
 
 # SGL API Components
 from sglang.api import (
@@ -24,6 +24,7 @@ from sglang.api import (
 
 # SGL Backends
 from sglang.backend.anthropic import Anthropic
+from sglang.backend.litellm import LiteLLM
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.backend.vertexai import VertexAI
@@ -35,6 +36,7 @@ from sglang.global_config import global_config
 __all__ = [
     "global_config",
     "Anthropic",
+    "LiteLLM",
     "OpenAI",
     "RuntimeEndpoint",
     "VertexAI",
sglang/api.py CHANGED
@@ -1,4 +1,4 @@
-"""Some Public API Definitions"""
+"""Public APIs of the language."""
 
 import os
 import re
@@ -20,13 +20,13 @@ from sglang.lang.ir import (
 
 
 def function(
-    func: Optional[Callable] = None, api_num_spec_tokens: Optional[int] = None
+    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
 ):
     if func:
-        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
+        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
 
     def decorator(func):
-        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
+        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
 
     return decorator
 
@@ -43,14 +43,14 @@ def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
 
-def flush_cache(backend: BaseBackend = None):
+def flush_cache(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return False
     return backend.flush_cache()
 
 
-def get_server_args(backend: BaseBackend = None):
+def get_server_args(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -158,7 +158,7 @@ def video(path: str, num_frames: int):
 
 def select(
     name: Optional[str] = None,
-    choices: List[str] = None,
+    choices: Optional[List[str]] = None,
     temperature: float = 0.0,
 ):
     assert choices is not None
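
Note: the decorator argument is renamed from api_num_spec_tokens to num_api_spec_tokens in this release. A minimal sketch of the new spelling, assuming an OpenAI-style chat backend is already configured; the prompt text and the 128-token budget below are illustrative only:

    import sglang as sgl

    # Enable API speculative execution for this function by reserving a
    # speculated budget of tokens per call (the value 128 is illustrative).
    @sgl.function(num_api_spec_tokens=128)
    def qa(s, question):
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer"))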
sglang/backend/anthropic.py CHANGED
@@ -74,4 +74,4 @@ class Anthropic(BaseBackend):
             **sampling_params.to_anthropic_kwargs(),
         ) as stream:
             for text in stream.text_stream:
-                yield text, {}
+                yield text, {}
sglang/backend/litellm.py ADDED
@@ -0,0 +1,90 @@
+from typing import Mapping, Optional
+
+from sglang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+
+try:
+    import litellm
+except ImportError as e:
+    litellm = e
+litellm.num_retries = 1
+
+
+class LiteLLM(BaseBackend):
+    def __init__(
+        self,
+        model_name,
+        chat_template=None,
+        api_key=None,
+        organization: Optional[str] = None,
+        base_url: Optional[str] = None,
+        timeout: Optional[float] = 600,
+        max_retries: Optional[int] = litellm.num_retries,
+        default_headers: Optional[Mapping[str, str]] = None,
+    ):
+        super().__init__()
+
+        if isinstance(litellm, Exception):
+            raise litellm
+
+        self.model_name = model_name
+
+        self.chat_template = chat_template or get_chat_template_by_model_path(
+            model_name
+        )
+
+        self.client_params = {
+            "api_key": api_key,
+            "organization": organization,
+            "base_url": base_url,
+            "timeout": timeout,
+            "max_retries": max_retries,
+            "default_headers": default_headers,
+        }
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            **self.client_params,
+            **sampling_params.to_anthropic_kwargs(),
+        )
+        comp = ret.choices[0].message.content
+
+        return comp, {}
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            stream=True,
+            **self.client_params,
+            **sampling_params.to_litellm_kwargs(),
+        )
+        for chunk in ret:
+            text = chunk.choices[0].delta.content
+            if text is not None:
+                yield text, {}
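
The new LiteLLM backend is also exported from the package root (see the sglang/__init__.py change above). A minimal sketch of wiring it into a program, assuming litellm is installed; the model string "gpt-3.5-turbo" and the prompt are placeholders for illustration:

    import sglang as sgl

    # The model name is illustrative; any model supported by litellm should work here.
    sgl.set_default_backend(sgl.LiteLLM("gpt-3.5-turbo"))

    @sgl.function
    def hello(s):
        s += sgl.user("Say hello in one short sentence.")
        s += sgl.assistant(sgl.gen("answer", max_tokens=32))

    state = hello.run()
    print(state["answer"])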
sglang/backend/openai.py CHANGED
@@ -1,5 +1,7 @@
+import dataclasses
 import logging
 import time
+import warnings
 from typing import Callable, List, Optional, Union
 
 import numpy as np
@@ -41,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [
 ]
 
 
+@dataclasses.dataclass
+class TokenUsage:
+    prompt_tokens: int
+    completion_tokens: int
+
+    def reset(self):
+        self.prompt_tokens = self.completion_tokens = 0
+
+
 class OpenAI(BaseBackend):
     def __init__(
         self,
@@ -80,40 +91,92 @@ class OpenAI(BaseBackend):
         else:
             self.is_chat_model = True
 
-        self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
+        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
+
+        # Usage
+        self.token_usage = TokenUsage(0, 0)
+
+        # API speculative execution
+        # TODO(ying): This does not support multi-threading (run_batch)
+        self.spec_kwargs = {}
+        self.spec_format = []
+        self.spec_max_num_tries = 3
 
     def get_chat_template(self):
         return self.chat_template
 
+    def _prepare_spec_execution(
+        self,
+        sampling_params: SglSamplingParams,
+        num_api_spec_tokens: int,
+        spec_var_name: str,
+    ):
+        if "max_tokens" not in self.spec_kwargs:
+            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
+        else:
+            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
+
+        params = sampling_params.to_openai_kwargs()
+        for key, value in params.items():
+            if key in ["stop"]:
+                continue
+            if key in ["max_tokens"]:
+                warnings.warn(
+                    "The parameter max_tokens will be overwritten by speculated number of tokens."
+                )
+                continue
+            if key not in self.spec_kwargs:
+                self.spec_kwargs[key] = value
+            else:
+                assert (
+                    value == self.spec_kwargs[key]
+                ), "sampling parameters should be consistent if turn on api speculative execution."
+        self.spec_format.append(
+            {"text": "", "stop": params["stop"], "name": spec_var_name}
+        )
+        return "", {}
+
     def generate(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        spec_var_name: str = None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
-                    raise RuntimeError(
-                        "This use case is not supported. "
-                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+                if s.num_api_spec_tokens is None:
+                    if not s.text_.endswith(self.chat_prefix):
+                        raise RuntimeError(
+                            "This use case is not supported if api speculative execution is off. "
+                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
+                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
+                        )
+                    prompt = s.messages_
+                else:
+                    return self._prepare_spec_execution(
+                        sampling_params, s.num_api_spec_tokens, spec_var_name
                     )
-                prompt = s.messages_
             else:
                 prompt = s.text_
 
             kwargs = sampling_params.to_openai_kwargs()
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=prompt,
                 **kwargs,
             )
         elif sampling_params.dtype in [str, "str", "string"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_ + '"',
@@ -122,10 +185,14 @@ class OpenAI(BaseBackend):
             )
             comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=s.text_,
@@ -138,6 +205,63 @@ class OpenAI(BaseBackend):
 
         return comp, {}
 
+    def spec_fill(self, value: str):
+        assert self.is_chat_model
+        self.spec_format.append({"text": value, "stop": None, "name": None})
+
+    def spec_pattern_match(self, comp):
+        for i, term in enumerate(self.spec_format):
+            text = term["text"]
+            if text != "":
+                if comp.startswith(text):
+                    comp = comp[len(text) :]
+                else:
+                    return False
+            else:
+                pos = comp.find(term["stop"])
+                if pos != -1:
+                    term["text"] = comp[:pos]
+                    comp = comp[pos:]
+                else:
+                    if i == len(self.spec_format) - 1:
+                        term["text"] = comp
+                    else:
+                        return False
+        return True
+
+    def role_end_generate(
+        self,
+        s: StreamExecutor,
+    ):
+        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+            return
+
+        comp = ""
+        if not all(x["name"] is None for x in self.spec_format):
+            # TODO(ying): throw errors or warnings
+            for i in range(self.spec_max_num_tries):
+                comp = openai_completion(
+                    client=self.client,
+                    token_usage=self.token_usage,
+                    is_chat=self.is_chat_model,
+                    model=self.model_name,
+                    prompt=s.messages_,
+                    **self.spec_kwargs,
+                )
+                if self.spec_pattern_match(comp):
+                    break
+
+        for term in self.spec_format:
+            s.text_ += term["text"]
+            name = term["name"]
+            if name is not None:
+                s.variables[name] = term["text"]
+                s.meta_info[name] = {}
+                s.variable_event[name].set()
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+
     def generate_stream(
         self,
         s: StreamExecutor,
@@ -145,7 +269,7 @@ class OpenAI(BaseBackend):
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
+                if not s.text_.endswith(self.chat_prefix):
                     raise RuntimeError(
                         "This use case is not supported. "
                         "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
@@ -157,6 +281,7 @@ class OpenAI(BaseBackend):
             kwargs = sampling_params.to_openai_kwargs()
             generator = openai_completion_stream(
                 client=self.client,
+                token_usage=self.token_usage,
                 is_chat=self.is_chat_model,
                 model=self.model_name,
                 prompt=prompt,
@@ -202,6 +327,8 @@ class OpenAI(BaseBackend):
             )
             ret_str = ret.choices[0].text
             ret_token = self.tokenizer.encode(ret_str)[0]
+            self.token_usage.prompt_tokens += ret.usage.prompt_tokens
+            self.token_usage.completion_tokens = ret.usage.completion_tokens
 
             # TODO:
             # 1. return logits as the scores
@@ -231,7 +358,9 @@ class OpenAI(BaseBackend):
         return decision, scores, None, None
 
 
-def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
@@ -245,6 +374,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
                 comp = [c.text for c in ret.choices]
             else:
                 comp = ret.choices[0].text
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
@@ -258,16 +390,23 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
     return comp
 
 
-def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion_stream(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
                 if "stop" in kwargs and kwargs["stop"] is None:
                     kwargs.pop("stop")
                 generator = client.chat.completions.create(
-                    messages=prompt, stream=True, **kwargs
+                    messages=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     try:
                         content = ret.choices[0].delta.content
                     except IndexError:
@@ -275,11 +414,19 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
                     yield content or "", {}
             else:
                 generator = client.completions.create(
-                    prompt=prompt, stream=True, **kwargs
+                    prompt=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     content = ret.choices[0].text
                     yield content or "", {}
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
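
The OpenAI backend now accumulates prompt and completion token counts across calls in the TokenUsage dataclass added above. A minimal sketch of reading the counters, assuming a configured OpenAI backend; the model name is illustrative:

    import sglang as sgl

    backend = sgl.OpenAI("gpt-3.5-turbo")  # illustrative model name
    sgl.set_default_backend(backend)

    # ... run some sglang programs against this backend ...

    # prompt_tokens and completion_tokens are summed over completed API calls.
    print(backend.token_usage.prompt_tokens, backend.token_usage.completion_tokens)
    backend.token_usage.reset()  # zero both counters before the next measurement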
sglang/backend/runtime_endpoint.py CHANGED
@@ -34,7 +34,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
@@ -50,7 +50,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
-        return res.status_code == 200
+        self._assert_success(res)
 
     def get_server_args(self):
         res = http_request(
@@ -58,6 +58,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
+        self._assert_success(res)
         return res.json()
 
     def get_chat_template(self):
@@ -71,7 +72,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def commit_lazy_operations(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -83,7 +84,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -95,7 +96,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def generate(
         self,
@@ -133,6 +134,8 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
+
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
@@ -167,7 +170,7 @@ class RuntimeEndpoint(BaseBackend):
         data["stream"] = True
         self._add_images(s, data)
 
-        response = http_request(
+        res = http_request(
             self.base_url + "/generate",
             json=data,
             stream=True,
@@ -175,10 +178,11 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
         pos = 0
 
         incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False):
+        for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":
@@ -211,7 +215,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         prompt_len = res.json()["meta_info"]["prompt_tokens"]
 
         # Compute logprob
@@ -229,7 +233,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
        )
-        assert res.status_code == 200
+        self._assert_success(res)
         obj = res.json()
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
@@ -253,9 +257,13 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
             data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
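
With the new _assert_success helper, failed HTTP calls to the runtime now raise a RuntimeError carrying the server's JSON error body instead of failing a bare status-code assertion. A minimal sketch of handling it, assuming a local SRT server; the URL is illustrative:

    import sglang as sgl

    try:
        backend = sgl.RuntimeEndpoint("http://localhost:30000")  # illustrative URL
    except RuntimeError as e:
        # The exception message contains the JSON body returned by the server.
        print("endpoint error:", e)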