sglang 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (42)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +5 -0
  3. sglang/global_config.py +4 -1
  4. sglang/lang/chat_template.py +9 -2
  5. sglang/lang/interpreter.py +52 -19
  6. sglang/lang/ir.py +12 -9
  7. sglang/lang/tracer.py +1 -1
  8. sglang/launch_server.py +1 -2
  9. sglang/launch_server_llavavid.py +31 -0
  10. sglang/srt/flush_cache.py +16 -0
  11. sglang/srt/hf_transformers_utils.py +8 -1
  12. sglang/srt/managers/io_struct.py +15 -3
  13. sglang/srt/managers/router/infer_batch.py +31 -19
  14. sglang/srt/managers/router/manager.py +6 -8
  15. sglang/srt/managers/router/model_rpc.py +59 -23
  16. sglang/srt/managers/router/model_runner.py +6 -6
  17. sglang/srt/managers/router/radix_cache.py +47 -17
  18. sglang/srt/managers/router/scheduler.py +17 -28
  19. sglang/srt/managers/tokenizer_manager.py +54 -22
  20. sglang/srt/model_config.py +4 -0
  21. sglang/srt/models/commandr.py +6 -10
  22. sglang/srt/models/dbrx.py +14 -15
  23. sglang/srt/models/gemma.py +7 -10
  24. sglang/srt/models/llama2.py +7 -10
  25. sglang/srt/models/llava.py +2 -6
  26. sglang/srt/models/llavavid.py +307 -0
  27. sglang/srt/models/mixtral.py +7 -13
  28. sglang/srt/models/qwen.py +20 -13
  29. sglang/srt/models/qwen2.py +7 -10
  30. sglang/srt/models/stablelm.py +13 -12
  31. sglang/srt/models/yivl.py +1 -4
  32. sglang/srt/server.py +32 -18
  33. sglang/srt/server_args.py +9 -6
  34. sglang/srt/utils.py +126 -17
  35. sglang/srt/weight_utils.py +66 -51
  36. sglang/utils.py +77 -26
  37. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/METADATA +9 -5
  38. sglang-0.1.16.dist-info/RECORD +72 -0
  39. sglang-0.1.15.dist-info/RECORD +0 -69
  40. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  41. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  42. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.15"
+ __version__ = "0.1.16"

  # SGL API Components
  from sglang.api import (
@@ -19,6 +19,7 @@ from sglang.api import (
      user,
      user_begin,
      user_end,
+     video,
  )

  # SGL Backends
@@ -46,6 +47,7 @@ __all__ = [
      "gen_int",
      "gen_string",
      "image",
+     "video",
      "select",
      "system",
      "user",
sglang/api.py CHANGED
@@ -15,6 +15,7 @@ from sglang.lang.ir import (
      SglRoleBegin,
      SglRoleEnd,
      SglSelect,
+     SglVideo,
  )


@@ -151,6 +152,10 @@ def image(expr: SglExpr):
      return SglImage(expr)


+ def video(path: str, num_frames: int):
+     return SglVideo(path, num_frames)
+
+
  def select(
      name: Optional[str] = None,
      choices: List[str] = None,
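
The new `video(path, num_frames)` primitive builds an `SglVideo` expression and composes with the rest of the DSL the same way `image` does. A minimal usage sketch, assuming a running multimodal backend; the `video_qa` function, file path, and question are illustrative, not part of this diff:

    import sglang as sgl

    @sgl.function
    def video_qa(s, video_path, question):
        # Attach the video (frames are sampled and base64-encoded client-side),
        # then ask a question about it.
        s += sgl.user(sgl.video(video_path, num_frames=16) + question)
        s += sgl.assistant(sgl.gen("answer"))
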
sglang/global_config.py CHANGED
@@ -16,7 +16,7 @@ class GlobalConfig:

          # Optimization configs
          self.eager_fill_image = False
-         self.enable_prefix_sharing = True
+         self.enable_precache_with_tracing = True
          self.enable_parallel_encoding = True
          self.enable_parallel_decoding = True

@@ -25,5 +25,8 @@ class GlobalConfig:
          # adjust_cache: Adjust the position embedding of KV cache.
          self.concate_and_append_mode = "no_adjust"

+         # Request dependency time due to network delay
+         self.request_dependency_time = 0.03
+

  global_config = GlobalConfig()
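
Both new settings are plain attributes on the module-level `global_config` singleton. `enable_precache_with_tracing` is consulted by the interpreter on each batch run (see the interpreter.py changes below), while `request_dependency_time` is read by the router process at startup. A sketch of toggling the former; the value is illustrative:

    from sglang.global_config import global_config

    global_config.enable_precache_with_tracing = False  # opt out of batch prefix pre-caching
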
sglang/lang/chat_template.py CHANGED
@@ -259,6 +259,8 @@ def match_vicuna(model_path: str):
          return get_chat_template("vicuna_v1.1")
      if "llava-v1.5" in model_path.lower():
          return get_chat_template("vicuna_v1.1")
+     if "llava-next-video-7b" in model_path.lower():
+         return get_chat_template("vicuna_v1.1")


  @register_chat_template_matching_function
@@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str):

  @register_chat_template_matching_function
  def match_chat_ml(model_path: str):
+     # import pdb;pdb.set_trace()
      model_path = model_path.lower()
      if "tinyllama" in model_path:
          return get_chat_template("chatml")
      if "qwen" in model_path and "chat" in model_path:
          return get_chat_template("chatml")
-     if "llava-v1.6-34b" in model_path:
+     if (
+         "llava-v1.6-34b" in model_path
+         or "llava-v1.6-yi-34b" in model_path
+         or "llava-next-video-34b" in model_path
+     ):
          return get_chat_template("chatml-llava")


  @register_chat_template_matching_function
  def match_chat_yi(model_path: str):
      model_path = model_path.lower()
-     if "yi" in model_path:
+     if "yi" in model_path and "llava" not in model_path:
          return get_chat_template("yi")

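Net effect of the matcher changes: LLaVA-NeXT-Video checkpoints are routed to existing templates, and Yi-based LLaVA variants no longer fall through to the plain "yi" template. A sketch of the new routing; the model paths are illustrative:

    match_vicuna("lmms-lab/LLaVA-NeXT-Video-7B")   # -> vicuna_v1.1 template
    match_chat_ml("llava-next-video-34b")          # -> chatml-llava template
    match_chat_yi("01-ai/Yi-34B-Chat")             # -> yi template (non-LLaVA "yi" models only)
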
sglang/lang/interpreter.py CHANGED
@@ -28,8 +28,9 @@ from sglang.lang.ir import (
      SglVariable,
      SglVarScopeBegin,
      SglVarScopeEnd,
+     SglVideo,
  )
- from sglang.utils import encode_image_base64
+ from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback


  def run_internal(state, program, func_args, func_kwargs, sync):
@@ -86,9 +87,9 @@ def run_program_batch(
      if hasattr(backend, "endpoint"):
          backend = backend.endpoint

-     # Extract prefix by tracing and cache it
-     if len(batch_arguments) > 1:
-         pin_program(program, backend)
+     # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+     if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+         cache_program(program, backend)

      # Run all programs
      if num_threads == "auto":
@@ -154,21 +155,12 @@ def run_program_batch(
      return rets


- def pin_program(program, backend):
-     if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
-         # TODO: handle multiple backends
-         from sglang.lang.tracer import extract_prefix_by_tracing
+ def cache_program(program, backend):
+     from sglang.lang.tracer import extract_prefix_by_tracing

-         prefix = extract_prefix_by_tracing(program, backend)
-         if prefix and len(prefix) > 64:
-             prefix_rid = backend.cache_prefix(prefix)
-             program.pin_prefix_rid = prefix_rid
-         return prefix_rid
-     return None
-
-
- def unpin_program(program, backend):
-     pass
+     prefix = extract_prefix_by_tracing(program, backend)
+     if prefix and len(prefix) > 64:
+         backend.cache_prefix(prefix)


  class StreamExecutor:
@@ -195,6 +187,7 @@ class StreamExecutor:
          self.variable_event = {}  # Dict[name: str -> event: threading.Event]
          self.meta_info = {}  # Dict[name: str -> info: str]
          self.is_finished = False
+         self.error = None

          # For completion
          self.text_ = ""  # The full text
@@ -310,17 +303,39 @@ class StreamExecutor:
          self.backend.end_program(self)

      def _thread_worker_func(self):
+         error = None
+
          while True:
              expr = self.queue.get()
              if expr is None:
                  self.queue.task_done()
                  break

-             self._execute(expr)
+             try:
+                 self._execute(expr)
+             except Exception as e:
+                 # print(f"Error in stream_executor: {get_exception_traceback()}")
+                 error = e
+                 break
              self.queue.task_done()
              if self.stream_text_event:
                  self.stream_text_event.set()

+         # Clean the queue and events
+         if error is not None:
+             try:
+                 while True:
+                     self.queue.task_done()
+                     self.queue.get_nowait()
+             except queue.Empty:
+                 pass
+             for name in self.variable_event:
+                 self.variable_event[name].set()
+             if self.stream_var_event:
+                 for name in self.stream_var_event:
+                     self.stream_var_event[name].set()
+             self.error = error
+
          if self.stream_text_event:
              self.stream_text_event.set()

@@ -347,6 +362,8 @@ class StreamExecutor:
              self._execute_role_end(other)
          elif isinstance(other, SglImage):
              self._execute_image(other)
+         elif isinstance(other, SglVideo):
+             self._execute_video(other)
          elif isinstance(other, SglVariable):
              self._execute_variable(other)
          elif isinstance(other, SglVarScopeBegin):
@@ -383,6 +400,16 @@ class StreamExecutor:
          self.cur_images.append((path, base64_data))
          self.text_ += self.chat_template.image_token

+     def _execute_video(self, expr: SglVideo):
+         path = expr.path
+         num_frames = expr.num_frames
+
+         base64_data = encode_video_base64(path, num_frames)
+
+         self.images_.append((path, base64_data))
+         self.cur_images.append((path, base64_data))
+         self.text_ += self.chat_template.image_token
+
          # if global_config.eager_fill_image:
          #     self.backend.fill_image(self)

@@ -681,6 +708,9 @@ class ProgramState:
      def sync(self):
          return self.stream_executor.sync()

+     def error(self):
+         return self.stream_executor.error
+
      def text_iter(self, var_name: Optional[str] = None):
          if self.stream_executor.stream:
              prev = 0
@@ -769,6 +799,9 @@ class ProgramState:
      def __setitem__(self, name, value):
          self.set_var(name, value)

+     def __contains__(self, name):
+         return name in self.stream_executor.variables
+
      def __del__(self):
          self.stream_executor.end()

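Together with the worker-thread try/except above, `ProgramState.error()` and `__contains__` let callers detect failures and missing variables instead of blocking on events that will never fire. A sketch reusing the hypothetical `video_qa` function from earlier; the arguments are illustrative:

    state = video_qa.run(video_path="demo.mp4", question="What happens?")
    state.sync()
    if state.error() is not None:
        raise state.error()        # re-raise the exception captured in the worker thread
    if "answer" in state:          # new __contains__ check for a produced variable
        print(state["answer"])
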
sglang/lang/ir.py CHANGED
@@ -193,17 +193,11 @@ class SglFunction:
          backend = backend or global_config.default_backend
          return trace_program(self, kwargs, backend)

-     def pin(self, backend=None):
-         from sglang.lang.interpreter import pin_program
+     def cache(self, backend=None):
+         from sglang.lang.interpreter import cache_program

          backend = backend or global_config.default_backend
-         return pin_program(self, backend)
-
-     def unpin(self, backend=None):
-         from sglang.lang.interpreter import unpin_program
-
-         backend = backend or global_config.default_backend
-         return unpin_program(self, backend)
+         return cache_program(self, backend)

      def compile(self, *, backend=None):
          from sglang.lang.compiler import compile_func
@@ -336,6 +330,15 @@ class SglImage(SglExpr):
          return f"SglImage({self.path})"


+ class SglVideo(SglExpr):
+     def __init__(self, path, num_frames):
+         self.path = path
+         self.num_frames = num_frames
+
+     def __repr__(self) -> str:
+         return f"SglVideo({self.path}, {self.num_frames})"
+
+
  class SglGen(SglExpr):
      def __init__(
          self,
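
With pin/unpin collapsed into a single entry point, pre-caching a function's shared prefix is one call. A sketch assuming a local sglang server; the function name `few_shot_qa` is illustrative:

    backend = sgl.RuntimeEndpoint("http://localhost:30000")
    sgl.set_default_backend(backend)
    few_shot_qa.cache()  # trace the program and cache its common prefix (if > 64 chars) on the backend
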
sglang/lang/tracer.py CHANGED
@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
      ##################################

      def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
-         assert (size >= 1)
+         assert size >= 1

          if self.only_trace_prefix:
              raise StopTracing()
sglang/launch_server.py CHANGED
@@ -2,11 +2,10 @@ import argparse

  from sglang.srt.server import ServerArgs, launch_server

-
  if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      ServerArgs.add_cli_args(parser)
      args = parser.parse_args()
      server_args = ServerArgs.from_cli_args(args)

-     launch_server(server_args, None)
+     launch_server(server_args, None)
sglang/launch_server_llavavid.py ADDED
@@ -0,0 +1,31 @@
+ import argparse
+ import multiprocessing as mp
+
+ from sglang.srt.server import ServerArgs, launch_server
+
+ if __name__ == "__main__":
+
+     model_overide_args = {}
+
+     model_overide_args["mm_spatial_pool_stride"] = 2
+     model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+     model_overide_args["num_frames"] = 16
+     model_overide_args["model_type"] = "llavavid"
+     if model_overide_args["num_frames"] == 32:
+         model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+         model_overide_args["max_sequence_length"] = 4096 * 2
+         model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+         model_overide_args["model_max_length"] = 4096 * 2
+
+     parser = argparse.ArgumentParser()
+     ServerArgs.add_cli_args(parser)
+     args = parser.parse_args()
+
+     if "34b" in args.model_path.lower():
+         model_overide_args["image_token_index"] = 64002
+
+     server_args = ServerArgs.from_cli_args(args)
+
+     pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
+     launch_server(server_args, pipe_writer, model_overide_args)
sglang/srt/flush_cache.py ADDED
@@ -0,0 +1,16 @@
+ """
+ Usage:
+ python3 -m sglang.srt.flush_cache --url http://localhost:30000
+ """
+
+ import argparse
+
+ import requests
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--url", type=str, default="http://localhost:30000")
+     args = parser.parse_args()
+
+     response = requests.get(args.url + "/flush_cache")
+     assert response.status_code == 200
sglang/srt/hf_transformers_utils.py CHANGED
@@ -30,10 +30,17 @@ def get_config_json(model_path: str):
      return config


- def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
+ def get_config(
+     model: str,
+     trust_remote_code: bool,
+     revision: Optional[str] = None,
+     model_overide_args: Optional[dict] = None,
+ ):
      config = AutoConfig.from_pretrained(
          model, trust_remote_code=trust_remote_code, revision=revision
      )
+     if model_overide_args:
+         config.update(model_overide_args)
      return config

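The `model_overide_args` dict is applied on top of the loaded HuggingFace config, which is how the llavavid launcher above injects its video settings. A sketch; the model path and override values are illustrative:

    config = get_config(
        "lmms-lab/LLaVA-NeXT-Video-7B",
        trust_remote_code=True,
        model_overide_args={"model_type": "llavavid", "num_frames": 16},
    )
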
sglang/srt/managers/io_struct.py CHANGED
@@ -8,7 +8,9 @@ from sglang.srt.sampling_params import SamplingParams
  @dataclass
  class GenerateReqInput:
      # The input prompt
-     text: Union[List[str], str]
+     text: Optional[Union[List[str], str]] = None
+     # The token ids for text; one can either specify text or input_ids
+     input_ids: Optional[Union[List[List[int]], List[int]]] = None
      # The image input
      image_data: Optional[Union[List[str], str]] = None
      # The sampling_params
@@ -28,7 +30,17 @@ class GenerateReqInput:
      # TODO: make all parameters a Union[List[T], T] to allow for batched requests

      def post_init(self):
-         is_single = isinstance(self.text, str)
+
+         if self.text is None:
+             assert self.input_ids is not None, "Either text or input_ids should be provided"
+         else:
+             assert self.input_ids is None, "Either text or input_ids should be provided"
+
+         if self.text is not None:
+             is_single = isinstance(self.text, str)
+         else:
+             is_single = isinstance(self.input_ids[0], int)
+         self.is_single = is_single

          if is_single:
              if self.sampling_params is None:
@@ -42,7 +54,7 @@ class GenerateReqInput:
              if self.top_logprobs_num is None:
                  self.top_logprobs_num = 0
          else:
-             num = len(self.text)
+             num = len(self.text) if self.text is not None else len(self.input_ids)

              if self.image_data is None:
                  self.image_data = [None] * num
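
Since the server constructs a GenerateReqInput from the request body, clients can now send pre-tokenized prompts instead of text. A sketch against the /generate endpoint; the URL and token ids are illustrative:

    import requests

    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "input_ids": [1, 15043, 3186],           # token ids in place of "text"
            "sampling_params": {"max_new_tokens": 8},
        },
    )
    print(resp.json())
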
sglang/srt/managers/router/infer_batch.py CHANGED
@@ -20,6 +20,17 @@ class FinishReason(IntEnum):
      LENGTH = auto()
      STOP_STR = auto()

+     @staticmethod
+     def to_str(reason):
+         if reason == FinishReason.EOS_TOKEN:
+             return None
+         elif reason == FinishReason.LENGTH:
+             return "length"
+         elif reason == FinishReason.STOP_STR:
+             return "stop"
+         else:
+             return None
+

  class Req:
      def __init__(self, rid, input_text, input_ids):
@@ -85,6 +96,9 @@ class Req:
          )
          if first_token.startswith("▁"):
              old_output_str = " " + old_output_str
+         if self.input_text is None:
+             # TODO(lmzheng): This can be wrong. Check with Liangsheng.
+             self.input_text = self.tokenizer.decode(self.input_ids)
          new_input_string = (
              self.input_text
              + self.output_and_jump_forward_str
@@ -332,20 +346,20 @@
              req = self.reqs[idx]
              retracted_reqs.append(req)

-             self.tree_cache.dec_ref_counter(req.last_node)
+             # TODO: apply more fine-grained retraction
+             last_uncached_pos = len(req.prefix_indices)
+             token_indices = self.req_to_token_pool.req_to_token[
+                 req_pool_indices_cpu[idx]
+             ][last_uncached_pos : seq_lens_cpu[idx]]
+             self.token_to_kv_pool.dec_refs(token_indices)
+
+             self.tree_cache.dec_lock_ref(req.last_node)
              req.prefix_indices = None
              req.last_node = None
              req.extend_input_len = 0
              req.output_ids = []
              req.regex_fsm_state = 0

-             # TODO: apply more fine-grained retraction
-
-             token_indices = self.req_to_token_pool.req_to_token[
-                 req_pool_indices_cpu[idx]
-             ][: seq_lens_cpu[idx]]
-             self.token_to_kv_pool.dec_refs(token_indices)
-
          self.filter_batch(sorted_indices)

          return retracted_reqs
@@ -364,20 +378,18 @@
              if len(jump_forward_str) <= 1:
                  continue

-             # insert the old request into tree_cache
-             token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
              if req_pool_indices_cpu is None:
                  req_pool_indices_cpu = self.req_pool_indices.tolist()
-             req_pool_idx = req_pool_indices_cpu[i]
-             indices = self.req_to_token_pool.req_to_token[
-                 req_pool_idx, : len(token_ids_in_memory)
-             ]
-             prefix_len = self.tree_cache.insert(
-                 token_ids_in_memory, indices.clone()
+
+             # insert the old request into tree_cache
+             self.tree_cache.cache_req(
+                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                 last_uncached_pos=len(req.prefix_indices),
+                 req_pool_idx=req_pool_indices_cpu[i],
              )
-             self.token_to_kv_pool.dec_refs(indices[:prefix_len])
-             self.req_to_token_pool.free(req_pool_idx)
-             self.tree_cache.dec_ref_counter(req.last_node)
+
+             # unlock the last node
+             self.tree_cache.dec_lock_ref(req.last_node)

              # jump-forward
              req.jump_forward_and_retokenize(jump_forward_str, next_state)
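
For reference, the new `FinishReason.to_str` maps the internal enum to the strings surfaced in responses, directly from the code above:

    FinishReason.to_str(FinishReason.LENGTH)     # "length"
    FinishReason.to_str(FinishReason.STOP_STR)   # "stop"
    FinishReason.to_str(FinishReason.EOS_TOKEN)  # None
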
sglang/srt/managers/router/manager.py CHANGED
@@ -5,7 +5,7 @@ import uvloop
  import zmq
  import zmq.asyncio

- from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
+ from sglang.global_config import global_config
  from sglang.srt.managers.router.model_rpc import ModelRpcClient
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import get_exception_traceback
@@ -30,7 +30,7 @@ class RouterManager:
          self.recv_reqs = []

          # Init some configs
-         self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
+         self.request_dependency_time = global_config.request_dependency_time

      async def loop_for_forward(self):
          while True:
@@ -46,9 +46,9 @@
              if len(out_pyobjs) != 0:
                  has_finished = any([obj.finished for obj in out_pyobjs])
                  if has_finished:
-                     if self.extend_dependency_time > 0:
+                     if self.request_dependency_time > 0:
                          slept = True
-                         await asyncio.sleep(self.extend_dependency_time)
+                         await asyncio.sleep(self.request_dependency_time)

              if not slept:
                  await asyncio.sleep(0.0006)
@@ -60,9 +60,7 @@


  def start_router_process(
-     server_args: ServerArgs,
-     port_args: PortArgs,
-     pipe_writer,
+     server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
  ):
      logging.basicConfig(
          level=getattr(logging, server_args.log_level.upper()),
@@ -70,7 +68,7 @@ def start_router_process(
      )

      try:
-         model_client = ModelRpcClient(server_args, port_args)
+         model_client = ModelRpcClient(server_args, port_args, model_overide_args)
          router = RouterManager(model_client, port_args)
      except Exception:
          pipe_writer.send(get_exception_traceback())