sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff shows the content of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,4 +1,59 @@
-__version__ = "0.1.14"
+__version__ = "0.1.16"
 
-from sglang.api import *
+# SGL API Components
+from sglang.api import (
+    Runtime,
+    assistant,
+    assistant_begin,
+    assistant_end,
+    flush_cache,
+    function,
+    gen,
+    gen_int,
+    gen_string,
+    get_server_args,
+    image,
+    select,
+    set_default_backend,
+    system,
+    user,
+    user_begin,
+    user_end,
+    video,
+)
+
+# SGL Backends
+from sglang.backend.anthropic import Anthropic
+from sglang.backend.openai import OpenAI
+from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.backend.vertexai import VertexAI
+
+# Global Configurations
 from sglang.global_config import global_config
+
+# public APIs management
+__all__ = [
+    "global_config",
+    "Anthropic",
+    "OpenAI",
+    "RuntimeEndpoint",
+    "VertexAI",
+    "function",
+    "Runtime",
+    "set_default_backend",
+    "flush_cache",
+    "get_server_args",
+    "gen",
+    "gen_int",
+    "gen_string",
+    "image",
+    "video",
+    "select",
+    "system",
+    "user",
+    "assistant",
+    "user_begin",
+    "user_end",
+    "assistant_begin",
+    "assistant_end",
+]
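With the wildcard import replaced by explicit exports and `__all__`, the top-level namespace is now self-documenting. A minimal usage sketch against the 0.1.16 surface (the model name and prompt are illustrative, and an OpenAI API key is assumed to be configured):

    import sglang as sgl

    @sgl.function
    def qa(s, question):
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer", max_tokens=64))

    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
    state = qa.run(question="What does flush_cache do?")
    print(state["answer"])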
sglang/api.py CHANGED
@@ -1,13 +1,10 @@
-"""Public API"""
+"""Some Public API Definitions"""
 
+import os
 import re
 from typing import Callable, List, Optional, Union
 
-from sglang.backend.anthropic import Anthropic
 from sglang.backend.base_backend import BaseBackend
-from sglang.backend.openai import OpenAI
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.backend.vertexai import VertexAI
 from sglang.global_config import global_config
 from sglang.lang.ir import (
     SglExpr,
@@ -18,6 +15,7 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglVideo,
 )
 
 
@@ -35,6 +33,7 @@ def function(
 
 def Runtime(*args, **kwargs):
     # Avoid importing unnecessary dependency
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     from sglang.srt.server import Runtime
 
     return Runtime(*args, **kwargs)
@@ -153,6 +152,10 @@ def image(expr: SglExpr):
     return SglImage(expr)
 
 
+def video(path: str, num_frames: int):
+    return SglVideo(path, num_frames)
+
+
 def select(
     name: Optional[str] = None,
     choices: List[str] = None,
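The new `video` primitive mirrors `image`. A hedged sketch of how it composes inside a program, in the style of the existing image API (the path and frame count are illustrative, and executing it requires one of the video-capable LLaVA endpoints this release adds):

    import sglang as sgl

    @sgl.function
    def describe_clip(s, path):
        # video(path, num_frames) expands to an SglVideo node,
        # just as image() expands to SglImage
        s += sgl.user(sgl.video(path, num_frames=16) + "Describe this clip.")
        s += sgl.assistant(sgl.gen("description", max_tokens=128))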
sglang/backend/anthropic.py CHANGED
@@ -1,6 +1,7 @@
 from typing import List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
@@ -13,7 +14,7 @@ except ImportError as e:
 
 
 class Anthropic(BaseBackend):
-    def __init__(self, model_name):
+    def __init__(self, model_name, *args, **kwargs):
         super().__init__()
 
         if isinstance(anthropic, Exception):
@@ -21,6 +22,7 @@ class Anthropic(BaseBackend):
 
         self.model_name = model_name
         self.chat_template = get_chat_template("claude")
+        self.client = anthropic.Anthropic(*args, **kwargs)
 
     def get_chat_template(self):
         return self.chat_template
@@ -35,8 +37,14 @@ class Anthropic(BaseBackend):
         else:
             messages = [{"role": "user", "content": s.text_}]
 
-        ret = anthropic.Anthropic().messages.create(
+        if messages and messages[0]["role"] == "system":
+            system = messages.pop(0)["content"]
+        else:
+            system = ""
+
+        ret = self.client.messages.create(
             model=self.model_name,
+            system=system,
             messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         )
@@ -54,10 +62,16 @@ class Anthropic(BaseBackend):
         else:
             messages = [{"role": "user", "content": s.text_}]
 
-        with anthropic.Anthropic().messages.stream(
+        if messages and messages[0]["role"] == "system":
+            system = messages.pop(0)["content"]
+        else:
+            system = ""
+
+        with self.client.messages.stream(
             model=self.model_name,
+            system=system,
             messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         ) as stream:
             for text in stream.text_stream:
-                yield text, {}
+                yield text, {}
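Anthropic's Messages API takes the system prompt as a top-level `system` argument rather than as a message with a `system` role, which is what the `messages.pop(0)` above implements. A standalone sketch of that transformation (`split_system` is a hypothetical helper, not part of the package):

    def split_system(messages):
        # Pull a leading system message out of an OpenAI-style list and
        # return (system_text, remaining_messages) for the Messages API.
        if messages and messages[0]["role"] == "system":
            return messages[0]["content"], messages[1:]
        return "", messages

    system, rest = split_system([
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi"},
    ])
    assert system == "You are terse."
    assert rest == [{"role": "user", "content": "Hi"}]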
sglang/backend/openai.py CHANGED
@@ -3,6 +3,7 @@ import time
 from typing import Callable, List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
@@ -227,7 +228,7 @@ class OpenAI(BaseBackend):
             prompt_tokens.append(ret_token)
 
         decision = choices[np.argmax(scores)]
-        return decision, scores, scores
+        return decision, scores, None, None
 
 
 def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
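Backend `select` now returns a 4-tuple: the chosen string, the normalized prompt logprobs, and per-token prefill/decode logprobs; the OpenAI backend fills the two token-level slots with `None`. A minimal stand-in illustrating the new contract (the function below is illustrative, not the package's code):

    import numpy as np

    def select(choices, scores):
        # New 4-tuple shape; token-level logprobs are None here,
        # as in the OpenAI backend above.
        decision = choices[int(np.argmax(scores))]
        return decision, scores, None, None

    decision, norm_lp, prefill_lp, decode_lp = select(["yes", "no"], [-0.2, -1.5])
    assert decision == "yes" and prefill_lp is None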
sglang/backend/runtime_endpoint.py CHANGED
@@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Union
 
 import numpy as np
 import requests
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
@@ -73,9 +74,11 @@ class RuntimeEndpoint(BaseBackend):
         assert res.status_code == 200
 
     def commit_lazy_operations(self, s: StreamExecutor):
+        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+        self._add_images(s, data)
         res = http_request(
             self.base_url + "/generate",
-            json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
+            json=data,
             auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
@@ -104,6 +107,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 **sampling_params.to_srt_kwargs(),
             },
         }
@@ -112,6 +116,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 "dtype": "int",
                 **sampling_params.to_srt_kwargs(),
             },
@@ -142,6 +147,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 **sampling_params.to_srt_kwargs(),
             },
         }
@@ -150,6 +156,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 "dtype": "int",
                 **sampling_params.to_srt_kwargs(),
             },
@@ -224,13 +231,19 @@ class RuntimeEndpoint(BaseBackend):
         )
         assert res.status_code == 200
         obj = res.json()
-        normalized_prompt_logprob = [
+        normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-        prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]
+        decision = choices[np.argmax(normalized_prompt_logprobs)]
+        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
+        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
 
-        decision = choices[np.argmax(normalized_prompt_logprob)]
-        return decision, normalized_prompt_logprob, prompt_logprob
+        return (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        )
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
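For reference, the requests above all target the runtime's /generate route. A hedged sketch of the payload shape this backend now sends, including the new sampling parameter (the URL and prompt are illustrative, and a server must be running):

    import requests

    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "max_new_tokens": 16,
            "skip_special_tokens": True,
            # New in 0.1.16, forwarded from global_config:
            "spaces_between_special_tokens": True,
        },
    }
    res = requests.post("http://localhost:30000/generate", json=payload)
    print(res.json()["text"])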
sglang/backend/vertexai.py CHANGED
@@ -3,6 +3,7 @@ import warnings
 from typing import List, Optional, Union
 
 import numpy as np
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
sglang/global_config.py CHANGED
@@ -12,10 +12,11 @@ class GlobalConfig:
 
         # Output configs
         self.skip_special_tokens_in_output = True
+        self.spaces_between_special_tokens_in_out = True
 
         # Optimization configs
         self.eager_fill_image = False
-        self.enable_prefix_sharing = True
+        self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
 
@@ -24,5 +25,8 @@ class GlobalConfig:
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"
 
+        # Request dependency time due to network delay
+        self.request_dependency_time = 0.03
+
 
 global_config = GlobalConfig()
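Both new knobs are plain attributes on the singleton, so callers can toggle them before running programs. For example:

    from sglang.global_config import global_config

    # Renamed in 0.1.16: enable_prefix_sharing -> enable_precache_with_tracing
    global_config.enable_precache_with_tracing = False

    # New output flag, forwarded to the runtime as
    # "spaces_between_special_tokens"
    global_config.spaces_between_special_tokens_in_out = False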
sglang/lang/chat_template.py CHANGED
@@ -162,6 +162,28 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="llama-3-instruct",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|start_header_id|>system<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "user": (
+                "<|start_header_id|>user<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "assistant": (
+                "<|start_header_id|>assistant<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+        },
+        stop_str=("<|eot_id|>",),
+    )
+)
+
 # Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
 register_chat_template(
     ChatTemplate(
@@ -192,6 +214,44 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="dbrx-instruct",
+        default_system_prompt="You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "<|im_end|>"),
+            "user": ("\n<|im_start|>user\n", "<|im_end|>"),
+            "assistant": ("\n<|im_start|>assistant\n", "<|im_end|>"),
+        },
+        stop_str=("<|im_end|>",),
+    )
+)
+
+register_chat_template(
+    ChatTemplate(
+        name="c4ai-command-r",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
+                "<|END_OF_TURN_TOKEN|>",
+            ),
+            "user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
+            "assistant": (
+                "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+                "<|END_OF_TURN_TOKEN|>",
+            ),
+        },
+        style=ChatTemplateStyle.PLAIN,
+    )
+)
+
+
+@register_chat_template_matching_function
+def match_dbrx(model_path: str):
+    if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
+        return get_chat_template("dbrx-instruct")
+
 
 @register_chat_template_matching_function
 def match_vicuna(model_path: str):
@@ -199,6 +259,8 @@ def match_vicuna(model_path: str):
         return get_chat_template("vicuna_v1.1")
     if "llava-v1.5" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
+    if "llava-next-video-7b" in model_path.lower():
+        return get_chat_template("vicuna_v1.1")
 
 
 @register_chat_template_matching_function
@@ -214,21 +276,33 @@ def match_llama2_chat(model_path: str):
         return get_chat_template("llama-2-chat")
 
 
+@register_chat_template_matching_function
+def match_llama3_instruct(model_path: str):
+    model_path = model_path.lower()
+    if "llama-3" in model_path and "instruct" in model_path:
+        return get_chat_template("llama-3-instruct")
+
+
 @register_chat_template_matching_function
 def match_chat_ml(model_path: str):
+    # import pdb;pdb.set_trace()
     model_path = model_path.lower()
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
-    if "llava-v1.6-34b" in model_path:
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+    ):
         return get_chat_template("chatml-llava")
 
 
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
-    if "yi" in model_path:
+    if "yi" in model_path and "llava" not in model_path:
         return get_chat_template("yi")
 
 
@@ -239,6 +313,13 @@ def match_gemma_it(model_path: str):
         return get_chat_template("gemma-it")
 
 
+@register_chat_template_matching_function
+def match_c4ai_command_r(model_path: str):
+    model_path = model_path.lower()
+    if "c4ai-command-r" in model_path:
+        return get_chat_template("c4ai-command-r")
+
+
 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
sglang/lang/interpreter.py CHANGED
@@ -1,6 +1,7 @@
 """The interpreter that executes SGL programs"""
 
 import asyncio
+import contextvars
 import multiprocessing
 import queue
 import threading
@@ -10,6 +11,7 @@ from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import tqdm
+
 from sglang.global_config import global_config
 from sglang.lang.ir import (
     SglCommitLazy,
@@ -26,8 +28,9 @@ from sglang.lang.ir import (
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
+    SglVideo,
 )
-from sglang.utils import encode_image_base64
+from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
 
 
 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -84,9 +87,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
 
-    # Extract prefix by tracing and cache it
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)
 
     # Run all programs
     if num_threads == "auto":
@@ -152,21 +155,12 @@ def run_program_batch(
     return rets
 
 
-def pin_program(program, backend):
-    if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
-        # TODO: handle multiple backends
-        from sglang.lang.tracer import extract_prefix_by_tracing
-
-        prefix = extract_prefix_by_tracing(program, backend)
-        if prefix and len(prefix) > 64:
-            prefix_rid = backend.cache_prefix(prefix)
-            program.pin_prefix_rid = prefix_rid
-        return prefix_rid
-    return None
-
-
-def unpin_program(program, backend):
-    pass
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing
+
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)
 
 
 class StreamExecutor:
@@ -193,6 +187,7 @@ class StreamExecutor:
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
+        self.error = None
 
         # For completion
         self.text_ = ""  # The full text
@@ -217,7 +212,13 @@ class StreamExecutor:
         self.use_thread = use_thread
         if self.use_thread:
             self.queue = queue.Queue()
-            self.worker = threading.Thread(target=self._thread_worker_func)
+
+            def _run_worker_in_context():
+                self._thread_worker_func()
+
+            self.worker = threading.Thread(
+                target=contextvars.copy_context().run, args=(_run_worker_in_context,)
+            )
             self.worker.start()
 
         # For streaming
@@ -248,17 +249,24 @@ class StreamExecutor:
     def set_var(self, name, value):
         self.variables[name] = value
 
-    def get_meta_info(self, name):
+    def get_meta_info(self, name, timeout=None):
         if name in self.variable_event:
-            self.variable_event[name].wait()
+            got = self.variable_event[name].wait(timeout)
+            if not got:
+                raise TimeoutError(f"Timeout while waiting for event '{name}'")
         ret = self.meta_info.get(name, None)
         return ret
 
-    def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
-        self.submit(SglCommitLazy())
-        self.sync()
+    def fork(
+        self,
+        size: int = 1,
+        position_ids_offset: Optional[List[int]] = None,
+    ):
+        if size > 1:
+            self.submit(SglCommitLazy())
 
-        number = int(number)
+        self.sync()
+        size = int(size)
 
         exes = [
             StreamExecutor(
@@ -268,14 +276,15 @@ class StreamExecutor:
                 self.chat_template,
                 self.stream,
             )
-            for _ in range(number)
+            for _ in range(size)
         ]
-        for i in range(number):
+        for i in range(size):
             exes[i].variables = dict(self.variables)
             exes[i].text_ = str(self.text_)
             exes[i].messages_ = list(self.messages_)
             exes[i].cur_role = self.cur_role
             exes[i].fork_start_text_pos = len(self.text_)
+            exes[i].images_ = list(self.images_)
 
         return exes
 
@@ -294,17 +303,39 @@ class StreamExecutor:
         self.backend.end_program(self)
 
     def _thread_worker_func(self):
+        error = None
+
         while True:
             expr = self.queue.get()
             if expr is None:
                 self.queue.task_done()
                 break
 
-            self._execute(expr)
+            try:
+                self._execute(expr)
+            except Exception as e:
+                # print(f"Error in stream_executor: {get_exception_traceback()}")
+                error = e
+                break
             self.queue.task_done()
             if self.stream_text_event:
                 self.stream_text_event.set()
 
+        # Clean the queue and events
+        if error is not None:
+            try:
+                while True:
+                    self.queue.task_done()
+                    self.queue.get_nowait()
+            except queue.Empty:
+                pass
+            for name in self.variable_event:
+                self.variable_event[name].set()
+            if self.stream_var_event:
+                for name in self.stream_var_event:
+                    self.stream_var_event[name].set()
+            self.error = error
+
         if self.stream_text_event:
             self.stream_text_event.set()
 
@@ -331,6 +362,8 @@ class StreamExecutor:
             self._execute_role_end(other)
         elif isinstance(other, SglImage):
             self._execute_image(other)
+        elif isinstance(other, SglVideo):
+            self._execute_video(other)
         elif isinstance(other, SglVariable):
             self._execute_variable(other)
         elif isinstance(other, SglVarScopeBegin):
@@ -367,6 +400,16 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token
 
+    def _execute_video(self, expr: SglVideo):
+        path = expr.path
+        num_frames = expr.num_frames
+
+        base64_data = encode_video_base64(path, num_frames)
+
+        self.images_.append((path, base64_data))
+        self.cur_images.append((path, base64_data))
+        self.text_ += self.chat_template.image_token
+
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)
 
@@ -454,15 +497,19 @@ class StreamExecutor:
             self.stream_var_event[name].set()
 
     def _execute_select(self, expr: SglSelect):
-        decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
-            self, expr.choices, expr.temperature
-        )
+        (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        ) = self.backend.select(self, expr.choices, expr.temperature)
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
             self.meta_info[name] = {
-                "normalized_prompt_logprob": normalized_prompt_logprob,
-                "prompt_logprob": prompt_logprob,
+                "normalized_prompt_logprobs": normalized_prompt_logprobs,
+                "prefill_token_logprobs": prefill_token_logprobs,
+                "decode_token_logprobs": decode_token_logprobs,
             }
             self.variable_event[name].set()
         self.text_ += decision
@@ -634,8 +681,12 @@ class ProgramState:
         yield
         self.stream_executor.submit(SglVarScopeEnd(name))
 
-    def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None):
-        stream_executors = self.stream_executor.fork(number, position_ids_offset)
+    def fork(
+        self,
+        size: int = 1,
+        position_ids_offset: Optional[List[int]] = None,
+    ):
+        stream_executors = self.stream_executor.fork(size, position_ids_offset)
         states = [ProgramState(x) for x in stream_executors]
         state_group = ProgramStateGroup(states, self)
         return state_group
@@ -657,6 +708,9 @@ class ProgramState:
     def sync(self):
         return self.stream_executor.sync()
 
+    def error(self):
+        return self.stream_executor.error
+
     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
             prev = 0
@@ -745,6 +799,9 @@ class ProgramState:
     def __setitem__(self, name, value):
         self.set_var(name, value)
 
+    def __contains__(self, name):
+        return name in self.stream_executor.variables
+
     def __del__(self):
         self.stream_executor.end()
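The fork-related rename (`number` to `size`) plus the new `error()` accessor and `__contains__` support change how callers interact with program states. A hedged sketch, assuming a runtime is available (the endpoint URL is illustrative):

    import sglang as sgl

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    @sgl.function
    def tips(s, topic):
        s += "Here are two tips about " + topic + ":\n"
        forks = s.fork(size=2)  # keyword renamed from `number`
        for i, f in enumerate(forks):
            f += f"Tip {i + 1}:" + sgl.gen("tip", max_tokens=32)
        s += forks[0]["tip"] + "\n" + forks[1]["tip"] + "\n"
        s += "In short: " + sgl.gen("summary", max_tokens=32)

    state = tips.run(topic="KV-cache reuse")
    if "summary" in state:    # new __contains__ support
        print(state["summary"])
    print(state.error())      # new accessor; None if no worker error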