sglang 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +5 -0
  3. sglang/global_config.py +4 -1
  4. sglang/lang/chat_template.py +9 -2
  5. sglang/lang/interpreter.py +52 -19
  6. sglang/lang/ir.py +12 -9
  7. sglang/lang/tracer.py +1 -1
  8. sglang/launch_server.py +1 -2
  9. sglang/launch_server_llavavid.py +31 -0
  10. sglang/srt/flush_cache.py +16 -0
  11. sglang/srt/hf_transformers_utils.py +8 -1
  12. sglang/srt/managers/io_struct.py +15 -3
  13. sglang/srt/managers/router/infer_batch.py +31 -19
  14. sglang/srt/managers/router/manager.py +6 -8
  15. sglang/srt/managers/router/model_rpc.py +59 -23
  16. sglang/srt/managers/router/model_runner.py +6 -6
  17. sglang/srt/managers/router/radix_cache.py +47 -17
  18. sglang/srt/managers/router/scheduler.py +17 -28
  19. sglang/srt/managers/tokenizer_manager.py +54 -22
  20. sglang/srt/model_config.py +4 -0
  21. sglang/srt/models/commandr.py +6 -10
  22. sglang/srt/models/dbrx.py +14 -15
  23. sglang/srt/models/gemma.py +7 -10
  24. sglang/srt/models/llama2.py +7 -10
  25. sglang/srt/models/llava.py +2 -6
  26. sglang/srt/models/llavavid.py +307 -0
  27. sglang/srt/models/mixtral.py +7 -13
  28. sglang/srt/models/qwen.py +20 -13
  29. sglang/srt/models/qwen2.py +7 -10
  30. sglang/srt/models/stablelm.py +13 -12
  31. sglang/srt/models/yivl.py +1 -4
  32. sglang/srt/server.py +32 -18
  33. sglang/srt/server_args.py +9 -6
  34. sglang/srt/utils.py +126 -17
  35. sglang/srt/weight_utils.py +66 -51
  36. sglang/utils.py +77 -26
  37. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/METADATA +9 -5
  38. sglang-0.1.16.dist-info/RECORD +72 -0
  39. sglang-0.1.15.dist-info/RECORD +0 -69
  40. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  41. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  42. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py

@@ -4,12 +4,13 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import Any, Dict, List, Optional, Tuple, Union

 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+
 try:
     from vllm.logger import _default_handler as vllm_default_logger
 except ImportError:
@@ -23,7 +24,7 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
@@ -48,6 +49,7 @@ class ModelRpcServer:
         tp_rank: int,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: Optional[dict] = None,
     ):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]

@@ -62,6 +64,7 @@ class ModelRpcServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
         )

         # For model end global settings
@@ -117,7 +120,11 @@ class ModelRpcServer:
         logger.info(f"server_args: {server_args.print_mode_args()}")

         # Init cache
-        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            disable=server_args.disable_radix_cache,
+        )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -135,6 +142,8 @@ class ModelRpcServer:
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()

         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(
@@ -201,6 +210,8 @@ class ModelRpcServer:
             # Run new fill batch
             self.forward_fill_batch(new_batch)

+            self.cache_filled_batch(new_batch)
+
             if not new_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = new_batch
@@ -211,6 +222,7 @@ class ModelRpcServer:
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
                 for _ in range(10):
+                    self.num_generated_tokens += len(self.running_batch.reqs)
                     self.forward_decode_batch(self.running_batch)

                     if self.running_batch.is_empty():
@@ -226,10 +238,14 @@ class ModelRpcServer:
                             self.token_to_kv_pool.available_size()
                             + self.tree_cache.evictable_size()
                         )
+                        throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+                        self.num_generated_tokens = 0
+                        self.last_stats_tic = time.time()
                         logger.info(
                             f"#running-req: {len(self.running_batch.reqs)}, "
                             f"#token: {num_used}, "
                             f"token usage: {num_used / self.max_total_num_token:.2f}, "
+                            f"gen throughput (token/s): {throuhgput:.2f}, "
                             f"#queue-req: {len(self.forward_queue)}"
                         )
             else:
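
The two counters added above drive the new "gen throughput (token/s)" log line: num_generated_tokens grows by the running batch size once per decode step (one token per request), and the report divides it by the wall-clock time since the last report, then resets both. A minimal standalone sketch of that bookkeeping (ThroughputMeter is an illustrative name, not an sglang class):

import time

class ThroughputMeter:
    def __init__(self):
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

    def on_decode_step(self, batch_size):
        # one new token per running request per decode step
        self.num_generated_tokens += batch_size

    def report(self):
        # tokens generated since the last report / elapsed seconds, then reset
        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()
        return throughput
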
@@ -342,20 +358,19 @@ class ModelRpcServer:
                 and req.extend_input_len + new_batch_input_tokens
                 < self.max_prefill_num_token
             ):
-                delta = self.tree_cache.inc_ref_counter(req.last_node)
+                delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta

                 if not (
                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
-                    # Undo the insertion
-                    delta = self.tree_cache.dec_ref_counter(req.last_node)
+                    # Undo locking
+                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                     available_size += delta
                     break
                 else:
                     # Add this request to the running batch
-                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
                         req.extend_input_len + req.max_new_tokens()
@@ -426,7 +441,9 @@ class ModelRpcServer:
         # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
         if last_logprobs is not None:
             last_token_logprobs = (
-                last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
+                last_logprobs[
+                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+                    next_token_ids].tolist()
             )

         next_token_ids = next_token_ids.tolist()
@@ -468,6 +485,18 @@ class ModelRpcServer:

         self.handle_finished_requests(batch)

+    def cache_filled_batch(self, batch: Batch):
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        for i, req in enumerate(batch.reqs):
+            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
+                del_in_memory_pool=False,
+                old_last_node=req.last_node,
+            )
+            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
     def forward_decode_batch(self, batch: Batch):
         # check if decode out of memory
         if not batch.check_decode_mem():
@@ -586,7 +615,8 @@ class ModelRpcServer:
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
+                    "finish_reason": FinishReason.to_str(req.finish_reason),
+                    "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:
                     (
@@ -626,17 +656,13 @@ class ModelRpcServer:
         req_pool_indices_cpu = batch.req_pool_indices.tolist()
         for i in finished_indices:
             req = batch.reqs[i]
-            req_pool_idx = req_pool_indices_cpu[i]
-            token_ids = tuple(req.input_ids + req.output_ids)
-            seq_len = len(token_ids) - 1
-            indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
-            prefix_len = self.tree_cache.insert(
-                token_ids[:seq_len], indices.clone()
+            self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
             )

-            self.token_to_kv_pool.dec_refs(indices[:prefix_len])
-            self.req_to_token_pool.free(req_pool_idx)
-            self.tree_cache.dec_ref_counter(req.last_node)
+            self.tree_cache.dec_lock_ref(req.last_node)

         # Update batch tensors
         if unfinished_indices:
@@ -650,13 +676,15 @@ class ModelRpcService(rpyc.Service):


 class ModelRpcClient:
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+    ):
         tp_size = server_args.tp_size

         if tp_size == 1:
             # Init model
             self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                0, server_args, port_args
+                0, server_args, port_args, model_overide_args
             )

             # Wrap functions
@@ -677,7 +705,7 @@ class ModelRpcClient:
             # Init model
             def init_model(i):
                 return self.remote_services[i].ModelRpcServer(
-                    i, server_args, port_args
+                    i, server_args, port_args, model_overide_args
                 )

             self.model_servers = executor.map(init_model, range(tp_size))
@@ -700,7 +728,11 @@ def _init_service(port):
     t = ThreadedServer(
         ModelRpcService(),
         port=port,
-        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 1800,
+        },
     )
     t.start()

@@ -716,7 +748,11 @@ def start_model_process(port):
         con = rpyc.connect(
             "localhost",
             port,
-            config={"allow_pickle": True, "sync_request_timeout": 1800},
+            config={
+                "allow_public_attrs": True,
+                "allow_pickle": True,
+                "sync_request_timeout": 1800,
+            },
         )
         break
     except ConnectionRefusedError:
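
The logprob change in the -426/+441 hunk above is a device-placement fix: next_token_ids is a GPU tensor, so the torch.arange row index used for the pairwise gather must be created on the same device. A small standalone illustration of the pattern (not sglang code; it falls back to CPU when no GPU is present):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
last_logprobs = torch.randn(4, 32000, device=device)           # [batch, vocab]
next_token_ids = torch.randint(0, 32000, (4,), device=device)  # sampled token per request

# Build the row index on the same device as the column index, then pick one
# logprob per request and move only that small result to the CPU.
rows = torch.arange(len(next_token_ids), device=next_token_ids.device)
last_token_logprobs = last_logprobs[rows, next_token_ids].tolist()
print(last_token_logprobs)
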
sglang/srt/managers/router/model_runner.py

@@ -9,16 +9,16 @@ from typing import List

 import numpy as np
 import torch
+from vllm.distributed import initialize_model_parallel
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.distributed import initialize_model_parallel

 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model
-from sglang.utils import get_available_gpu_memory
+from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
+

 QUANTIZATION_CONFIG_MAPPING = {
     "awq": AWQConfig,
@@ -110,8 +110,8 @@ class InputMetadata:
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
-        seq_lens_cpu = self.seq_lens.tolist()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
@@ -143,7 +143,7 @@ class InputMetadata:
             self.kv_last_page_len,
             self.model_runner.model_config.num_attention_heads // tp_size,
             self.model_runner.model_config.num_key_value_heads // tp_size,
-            self.model_runner.model_config.head_dim
+            self.model_runner.model_config.head_dim,
         ]

         self.prefill_wrapper.begin_forward(*args)
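
The .tolist() to .cpu().numpy() switch in the InputMetadata hunk keeps the per-request loop that assembles kv_indices but skips building Python lists; it works because NumPy integer scalars implement __index__ and can therefore index and slice a torch tensor directly. A small sketch of the pattern with made-up shapes:

import numpy as np
import torch

req_to_token = torch.arange(32).reshape(4, 8)   # stand-in for the req-to-token table
req_pool_indices = torch.tensor([2, 0])
seq_lens = torch.tensor([3, 5])

req_pool_indices_cpu = req_pool_indices.cpu().numpy()
seq_lens_cpu = seq_lens.cpu().numpy()
kv_indices = torch.cat(
    [
        req_to_token[req_pool_indices_cpu[i], : seq_lens_cpu[i]]
        for i in range(len(seq_lens_cpu))
    ]
)
print(kv_indices)  # first 3 tokens of request 2 followed by first 5 tokens of request 0
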
sglang/srt/managers/router/radix_cache.py

@@ -11,7 +11,7 @@ class TreeNode:
         self.parent = None
         self.key = None
         self.value = None
-        self.ref_counter = 0
+        self.lock_ref = 0
         self.last_access_time = time.time()

     def __lt__(self, other: "TreeNode"):
@@ -28,7 +28,9 @@ def _key_match(key0, key1):


 class RadixCache:
-    def __init__(self, disable: bool = False):
+    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
         self.reset()

@@ -38,7 +40,7 @@ class RadixCache:
         self.root_node = TreeNode()
         self.root_node.key = []
         self.root_node.value = []
-        self.root_node.ref_counter = 1
+        self.root_node.lock_ref = 1
         self.evictable_size_ = 0

     def match_prefix(self, key):
@@ -50,6 +52,8 @@ class RadixCache:
         self._match_prefix_helper(self.root_node, key, value, last_node)
         if value:
             value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int64)
         return value, last_node[0]

     def insert(self, key, value=None):
@@ -60,6 +64,34 @@ class RadixCache:
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

+    def cache_req(
+        self,
+        token_ids,
+        last_uncached_pos,
+        req_pool_idx,
+        del_in_memory_pool=True,
+        old_last_node=None,
+    ):
+        # Insert the request into radix cache
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        new_prefix_len = self.insert(token_ids, indices.clone())
+
+        # Radix Cache takes one ref in memory pool
+        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+
+        if del_in_memory_pool:
+            self.req_to_token_pool.free(req_pool_idx)
+        else:
+            cached_indices, new_last_node = self.match_prefix(token_ids)
+            assert len(cached_indices) == len(token_ids)
+
+            self.req_to_token_pool.req_to_token[
+                req_pool_idx, last_uncached_pos : len(cached_indices)
+            ] = cached_indices[last_uncached_pos:]
+            self.dec_lock_ref(old_last_node)
+            self.inc_lock_ref(new_last_node)
+            return cached_indices, new_last_node
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
@@ -80,7 +112,7 @@ class RadixCache:

             if x == self.root_node:
                 break
-            if x.ref_counter > 0:
+            if x.lock_ref > 0:
                 continue

             num_evicted += evict_callback(x.value)
@@ -89,23 +121,23 @@ class RadixCache:
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)

-    def inc_ref_counter(self, node):
+    def inc_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 0:
+            if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
                 delta -= len(node.value)
-            node.ref_counter += 1
+            node.lock_ref += 1
             node = node.parent
         return delta

-    def dec_ref_counter(self, node):
+    def dec_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 1:
+            if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
                 delta += len(node.value)
-            node.ref_counter -= 1
+            node.lock_ref -= 1
             node = node.parent
         return delta

@@ -131,12 +163,12 @@ class RadixCache:
             last_node[0] = child
             self._match_prefix_helper(child, key[prefix_len:], value, last_node)

-    def _split_node(self, key, child, split_len):
+    def _split_node(self, key, child: TreeNode, split_len):
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
-        new_node.ref_counter = child.ref_counter
+        new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
@@ -176,11 +208,9 @@ class RadixCache:
             self.evictable_size_ += len(value)
         return 0

-    def _print_helper(self, node, indent):
+    def _print_helper(self, node: TreeNode, indent):
         for _, child in node.children.items():
-            print(
-                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
-            )
+            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)

     def _delete_leaf(self, node):
@@ -211,7 +241,7 @@ class RadixCache:


 if __name__ == "__main__":
-    tree = RadixCache()
+    tree = RadixCache(None, None, False)

     tree.insert("Hello")
     tree.insert("Hello")
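
Together with the new cache_req helper, the ref_counter to lock_ref rename makes the locking role explicit: a non-zero lock_ref pins a node (and, via the walk to the root, its ancestors) so the eviction loop skips it, and the pinned tokens stop counting toward evictable_size(). The module's own __main__ block hints at a pool-free toy mode; extending that toy slightly shows the lock bookkeeping (walking the children dicts here is purely for illustration, not an sglang API):

from sglang.srt.managers.router.radix_cache import RadixCache

tree = RadixCache(None, None, False)   # no memory pools: a plain radix tree over characters
tree.insert("Hello")
tree.insert("Help")                    # splits the shared "Hel" prefix into its own node
tree.pretty_print()                    # every node prints r=0: nothing is locked yet

hel_node = next(iter(tree.root_node.children.values()))   # the "Hel" node
leaf = next(iter(hel_node.children.values()))             # one of its children
print(tree.evictable_size())           # all inserted tokens are currently evictable
tree.inc_lock_ref(leaf)                # pin the path root -> "Hel" -> leaf while it is in use
print(tree.evictable_size())           # the pinned path no longer counts as evictable
tree.dec_lock_ref(leaf)                # unpin once the request is done
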
sglang/srt/managers/router/scheduler.py

@@ -27,44 +27,33 @@ class Scheduler:
             return forward_queue
         elif self.schedule_heuristic == "fcfs":
             return forward_queue
-        elif self.schedule_heuristic == "weight":
+        elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
                 last_node_to_reqs[req.last_node].append(req)
-            for node in last_node_to_reqs:
-                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))

             node_to_weight = defaultdict(int)
-            self._calc_weight_recursive(
-                self.tree_cache.root_node, last_node_to_reqs, node_to_weight
-            )
+            for node in last_node_to_reqs:
+                node_to_weight[node] = len(last_node_to_reqs[node])
+            self.calc_weight(self.tree_cache.root_node, node_to_weight)

-            tmp_queue = []
-            self._get_weight_priority_recursive(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
+            q = []
+            self.get_dfs_priority(
+                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(tmp_queue) == len(forward_queue)
-            return tmp_queue
+            assert len(q) == len(forward_queue)
+            return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")

-    def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
-        node_to_weight[cur_node] = 1
-        if cur_node in last_node_to_reqs:
-            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
+    def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
-            self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
+            self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]

-    def _get_weight_priority_recursive(
-        self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue
-    ):
-        visit_list = [child for child in cur_node.children.values()]
-        visit_list.sort(key=lambda x: -node_to_wight[x])
-        # for node in visit_list:
-        #     print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
-        for child in visit_list:
-            self._get_weight_priority_recursive(
-                child, node_to_wight, last_node_to_reqs, tmp_queue
-            )
-        tmp_queue.extend(last_node_to_reqs[cur_node])
+    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+        childs = [child for child in cur_node.children.values()]
+        childs.sort(key=lambda x: -node_to_priority[x])
+        for child in childs:
+            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+        q.extend(last_node_to_reqs[cur_node])
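
The renamed dfs-weight heuristic now weights each radix-tree node only by the number of queued requests whose cached prefix ends there (the old version also gave every node a base weight of 1 and pre-sorted each bucket), then emits requests via a depth-first traversal that visits heavier subtrees first, so requests sharing a prefix are scheduled back to back. A self-contained sketch of the same ordering over a stand-in tree (the Node class and request names are illustrative, not sglang types):

from collections import defaultdict

class Node:
    def __init__(self, name, children=()):
        self.name = name
        self.children = {c.name: c for c in children}

a1, a2 = Node("a1"), Node("a2")
a, b = Node("a", [a1, a2]), Node("b")
root = Node("root", [a, b])

# queued requests grouped by the tree node their cached prefix ends on
last_node_to_reqs = {a1: ["r1", "r2"], a2: ["r3"], b: ["r4"]}

node_to_weight = defaultdict(int)
for node, reqs in last_node_to_reqs.items():
    node_to_weight[node] = len(reqs)

def calc_weight(cur):
    for child in cur.children.values():
        calc_weight(child)
        node_to_weight[cur] += node_to_weight[child]

def get_dfs_priority(cur, q):
    childs = sorted(cur.children.values(), key=lambda x: -node_to_weight[x])
    for child in childs:
        get_dfs_priority(child, q)
    q.extend(last_node_to_reqs.get(cur, []))

calc_weight(root)
q = []
get_dfs_priority(root, q)
print(q)  # ['r1', 'r2', 'r3', 'r4']: the heavier subtree under "a" is drained first
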
sglang/srt/managers/tokenizer_manager.py

@@ -60,21 +60,29 @@ def get_pixel_values(
 ):
     try:
         processor = processor or global_processor
-        image = load_image(image_data)
-        image_hash = hash(image_data)
-        if image_aspect_ratio == "pad":
-            image = expand2square(
-                image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-            )
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        elif image_aspect_ratio == "anyres":
-            pixel_values = process_anyres_image(
-                image, processor.image_processor, image_grid_pinpoints
-            )
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
         else:
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        pixel_values = pixel_values.astype(np.float16)
-        return pixel_values, image_hash, image.size
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
     except Exception:
         print("Exception in TokenizerManager:\n" + get_exception_traceback())

@@ -84,6 +92,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args

@@ -96,9 +105,10 @@ class TokenizerManager:

         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path, trust_remote_code=server_args.trust_remote_code
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
-
         self.context_len = get_context_length(self.hf_config)

         if is_multimodal_model(self.model_path):
@@ -147,11 +157,21 @@ class TokenizerManager:
         if self.to_create_loop:
             await self.create_handle_loop()

-        is_single = isinstance(obj.text, str)
-
+        is_single = obj.is_single
         if is_single:
             rid = obj.rid
-            input_ids = self.tokenizer.encode(obj.text)
+
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(obj.text)
+            else:
+                input_ids = obj.input_ids
+
+            if len(input_ids) >= self.context_len:
+                raise ValueError(
+                    f"The input ({len(input_ids)} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)"
+                )
+
             sampling_params = SamplingParams(**obj.sampling_params)
             if sampling_params.max_new_tokens != 0:
                 sampling_params.normalize(self.tokenizer)
@@ -204,10 +224,22 @@ class TokenizerManager:
                 event.clear()
         else:
             assert obj.stream is False
-            bs = len(obj.text)
+
+            if obj.input_ids is None:
+                bs = len(obj.text)
+            else:
+                bs = len(obj.input_ids)
+
             for i in range(bs):
                 rid = obj.rid[i]
-                input_ids = self.tokenizer.encode(obj.text[i])
+
+                if obj.input_ids is None:
+                    input_text = obj.text[i]
+                    input_ids = self.tokenizer.encode(obj.text[i])
+                else:
+                    input_text = None
+                    input_ids = obj.input_ids[i]
+
                 sampling_params = SamplingParams(**obj.sampling_params[i])
                 if sampling_params.max_new_tokens != 0:
                     sampling_params.normalize(self.tokenizer)
@@ -220,7 +252,7 @@ class TokenizerManager:
                 )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
-                    input_text=obj.text[i],
+                    input_text=input_text,
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
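
With these hunks the tokenizer manager accepts pre-tokenized requests: when obj.input_ids is set it skips tokenizer.encode (recording input_text as None for batch items), and the new length check rejects prompts that already exceed the model's context length. Assuming the /generate HTTP endpoint keeps mapping JSON fields onto GenerateReqInput as in earlier releases (the io_struct.py change listed in the file table is not shown here), a pre-tokenized call might look like the sketch below; host, port, and token values are placeholders:

import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "input_ids": [1, 3087, 278, 3186],  # prompt tokens produced by your own tokenizer
        "sampling_params": {"max_new_tokens": 16, "temperature": 0.0},
    },
)
print(resp.json())
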
sglang/srt/model_config.py

@@ -10,12 +10,16 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
+        model_overide_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)

+        if model_overide_args is not None:
+            self.hf_config.update(model_overide_args)
+
         if context_length is not None:
             self.context_len = context_length
         else:
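
model_overide_args (spelled as in the source) is applied with hf_config.update(...), so it overwrites attributes on the already-loaded HuggingFace config before the context length and other fields are read from it; the new launch_server_llavavid.py in the file table is presumably its motivating caller. A hedged sketch of direct use, with a placeholder model path and an override key that only applies to configs that actually define it:

from sglang.srt.model_config import ModelConfig

config = ModelConfig(
    path="path/to/your/model",                              # placeholder
    model_overide_args={"max_position_embeddings": 8192},   # assumed override key
)
# context_len now reflects the overridden config field when no explicit
# context_length argument is passed.
print(config.context_len)
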
sglang/srt/models/commandr.py

@@ -27,29 +27,25 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 @torch.compile