sglang-0.1.15-py3-none-any.whl → sglang-0.1.16-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +5 -0
- sglang/global_config.py +4 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +52 -19
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +8 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/router/infer_batch.py +31 -19
- sglang/srt/managers/router/manager.py +6 -8
- sglang/srt/managers/router/model_rpc.py +59 -23
- sglang/srt/managers/router/model_runner.py +6 -6
- sglang/srt/managers/router/radix_cache.py +47 -17
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +54 -22
- sglang/srt/model_config.py +4 -0
- sglang/srt/models/commandr.py +6 -10
- sglang/srt/models/dbrx.py +14 -15
- sglang/srt/models/gemma.py +7 -10
- sglang/srt/models/llama2.py +7 -10
- sglang/srt/models/llava.py +2 -6
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +7 -13
- sglang/srt/models/qwen.py +20 -13
- sglang/srt/models/qwen2.py +7 -10
- sglang/srt/models/stablelm.py +13 -12
- sglang/srt/models/yivl.py +1 -4
- sglang/srt/server.py +32 -18
- sglang/srt/server_args.py +9 -6
- sglang/srt/utils.py +126 -17
- sglang/srt/weight_utils.py +66 -51
- sglang/utils.py +77 -26
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/METADATA +9 -5
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py CHANGED
@@ -4,12 +4,13 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+
 try:
     from vllm.logger import _default_handler as vllm_default_logger
 except ImportError:
@@ -23,7 +24,7 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
@@ -48,6 +49,7 @@ class ModelRpcServer:
         tp_rank: int,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: Optional[dict] = None,
     ):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
 
@@ -62,6 +64,7 @@ class ModelRpcServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
         )
 
         # For model end global settings
@@ -117,7 +120,11 @@ class ModelRpcServer:
         logger.info(f"server_args: {server_args.print_mode_args()}")
 
         # Init cache
-        self.tree_cache = RadixCache(
+        self.tree_cache = RadixCache(
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            disable=server_args.disable_radix_cache,
+        )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -135,6 +142,8 @@ class ModelRpcServer:
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
 
         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(
@@ -201,6 +210,8 @@ class ModelRpcServer:
             # Run new fill batch
             self.forward_fill_batch(new_batch)
 
+            self.cache_filled_batch(new_batch)
+
             if not new_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = new_batch
@@ -211,6 +222,7 @@ class ModelRpcServer:
         if self.running_batch is not None:
             # Run a few decode batches continuously for reducing overhead
             for _ in range(10):
+                self.num_generated_tokens += len(self.running_batch.reqs)
                 self.forward_decode_batch(self.running_batch)
 
                 if self.running_batch.is_empty():
@@ -226,10 +238,14 @@ class ModelRpcServer:
                         self.token_to_kv_pool.available_size()
                         + self.tree_cache.evictable_size()
                     )
+                    throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+                    self.num_generated_tokens = 0
+                    self.last_stats_tic = time.time()
                     logger.info(
                         f"#running-req: {len(self.running_batch.reqs)}, "
                         f"#token: {num_used}, "
                         f"token usage: {num_used / self.max_total_num_token:.2f}, "
+                        f"gen throughput (token/s): {throuhgput:.2f}, "
                         f"#queue-req: {len(self.forward_queue)}"
                     )
                 else:
@@ -342,20 +358,19 @@ class ModelRpcServer:
                 and req.extend_input_len + new_batch_input_tokens
                 < self.max_prefill_num_token
             ):
-                delta = self.tree_cache.
+                delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
 
                 if not (
                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
-                    # Undo
-                    delta = self.tree_cache.
+                    # Undo locking
+                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                     available_size += delta
                     break
                 else:
                     # Add this request to the running batch
-                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
                         req.extend_input_len + req.max_new_tokens()
@@ -426,7 +441,9 @@ class ModelRpcServer:
         # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
         if last_logprobs is not None:
             last_token_logprobs = (
-                last_logprobs[
+                last_logprobs[
+                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+                    next_token_ids].tolist()
             )
 
         next_token_ids = next_token_ids.tolist()
@@ -468,6 +485,18 @@ class ModelRpcServer:
 
         self.handle_finished_requests(batch)
 
+    def cache_filled_batch(self, batch: Batch):
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        for i, req in enumerate(batch.reqs):
+            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
+                del_in_memory_pool=False,
+                old_last_node=req.last_node,
+            )
+            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
     def forward_decode_batch(self, batch: Batch):
         # check if decode out of memory
         if not batch.check_decode_mem():
@@ -586,7 +615,8 @@ class ModelRpcServer:
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason":
+                    "finish_reason": FinishReason.to_str(req.finish_reason),
+                    "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:
                     (
@@ -626,17 +656,13 @@ class ModelRpcServer:
         req_pool_indices_cpu = batch.req_pool_indices.tolist()
         for i in finished_indices:
             req = batch.reqs[i]
-
-
-
-
-            prefix_len = self.tree_cache.insert(
-                token_ids[:seq_len], indices.clone()
+            self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
             )
 
-            self.
-            self.req_to_token_pool.free(req_pool_idx)
-            self.tree_cache.dec_ref_counter(req.last_node)
+            self.tree_cache.dec_lock_ref(req.last_node)
 
         # Update batch tensors
         if unfinished_indices:
@@ -650,13 +676,15 @@ class ModelRpcService(rpyc.Service):
 
 
 class ModelRpcClient:
-    def __init__(
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+    ):
         tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
             self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                0, server_args, port_args
+                0, server_args, port_args, model_overide_args
             )
 
             # Wrap functions
@@ -677,7 +705,7 @@ class ModelRpcClient:
             # Init model
             def init_model(i):
                 return self.remote_services[i].ModelRpcServer(
-                    i, server_args, port_args
+                    i, server_args, port_args, model_overide_args
                 )
 
             self.model_servers = executor.map(init_model, range(tp_size))
@@ -700,7 +728,11 @@ def _init_service(port):
    t = ThreadedServer(
        ModelRpcService(),
        port=port,
-        protocol_config={
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 1800,
+        },
    )
    t.start()
 
@@ -716,7 +748,11 @@ def start_model_process(port):
            con = rpyc.connect(
                "localhost",
                port,
-                config={
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 1800,
+                },
            )
            break
        except ConnectionRefusedError:
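Two small additions above are easy to miss: the router now tracks `num_generated_tokens` and `last_stats_tic` so the periodic stats line can report decode throughput, and the rpyc server and client both set `allow_pickle` plus a long `sync_request_timeout`. A minimal, self-contained sketch of the throughput bookkeeping (the `DecodeStats` name is illustrative, not part of sglang):

    import time

    class DecodeStats:
        """Track tokens generated between two logging ticks (illustrative helper)."""

        def __init__(self):
            self.num_generated_tokens = 0
            self.last_stats_tic = time.time()

        def on_decode_step(self, batch_size):
            # One decode step emits one token per running request.
            self.num_generated_tokens += batch_size

        def snapshot(self):
            # Tokens per second since the previous snapshot, then reset the window.
            elapsed = time.time() - self.last_stats_tic
            throughput = self.num_generated_tokens / elapsed if elapsed > 0 else 0.0
            self.num_generated_tokens = 0
            self.last_stats_tic = time.time()
            return throughput

Resetting both counters at every log line makes the reported number a per-window average rather than a lifetime one, which matches how the diff zeroes them right before `logger.info`.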
sglang/srt/managers/router/model_runner.py CHANGED
@@ -9,16 +9,16 @@ from typing import List
 
 import numpy as np
 import torch
+from vllm.distributed import initialize_model_parallel
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.distributed import initialize_model_parallel
 
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model
-
+from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
+
 
 QUANTIZATION_CONFIG_MAPPING = {
     "awq": AWQConfig,
@@ -110,8 +110,8 @@ class InputMetadata:
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().
-        seq_lens_cpu = self.seq_lens.
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
@@ -143,7 +143,7 @@ class InputMetadata:
             self.kv_last_page_len,
             self.model_runner.model_config.num_attention_heads // tp_size,
             self.model_runner.model_config.num_key_value_heads // tp_size,
-            self.model_runner.model_config.head_dim
+            self.model_runner.model_config.head_dim,
         ]
 
         self.prefill_wrapper.begin_forward(*args)
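The `InputMetadata` change is mostly mechanical: the request-pool indices and sequence lengths are now pulled to the CPU as numpy arrays before the per-request KV indices are gathered into one flat tensor for the flashinfer wrappers. A rough sketch of that gather with a toy `req_to_token` table (shapes and values here are made-up assumptions, not the real pool layout):

    import torch

    # Toy request-to-token table: one row per request slot, one column per token slot.
    req_to_token = torch.arange(4 * 8).reshape(4, 8)

    req_pool_indices = torch.tensor([2, 0])  # rows that belong to the current batch
    seq_lens = torch.tensor([5, 3])          # tokens held by each request

    # Small index tensors are moved to CPU/numpy, as in the new code,
    # then each request's KV indices are gathered and concatenated.
    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
    seq_lens_cpu = seq_lens.cpu().numpy()

    kv_indices = torch.cat(
        [
            req_to_token[req_pool_indices_cpu[i], : seq_lens_cpu[i]]
            for i in range(len(seq_lens_cpu))
        ]
    )
    print(kv_indices)  # five indices from row 2 followed by three from row 0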
sglang/srt/managers/router/radix_cache.py CHANGED
@@ -11,7 +11,7 @@ class TreeNode:
         self.parent = None
         self.key = None
         self.value = None
-        self.
+        self.lock_ref = 0
         self.last_access_time = time.time()
 
     def __lt__(self, other: "TreeNode"):
@@ -28,7 +28,9 @@ def _key_match(key0, key1):
 
 
 class RadixCache:
-    def __init__(self, disable: bool = False):
+    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
         self.reset()
 
@@ -38,7 +40,7 @@ class RadixCache:
         self.root_node = TreeNode()
         self.root_node.key = []
         self.root_node.value = []
-        self.root_node.
+        self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
     def match_prefix(self, key):
@@ -50,6 +52,8 @@ class RadixCache:
         self._match_prefix_helper(self.root_node, key, value, last_node)
         if value:
             value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int64)
         return value, last_node[0]
 
     def insert(self, key, value=None):
@@ -60,6 +64,34 @@ class RadixCache:
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
+    def cache_req(
+        self,
+        token_ids,
+        last_uncached_pos,
+        req_pool_idx,
+        del_in_memory_pool=True,
+        old_last_node=None,
+    ):
+        # Insert the request into radix cache
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        new_prefix_len = self.insert(token_ids, indices.clone())
+
+        # Radix Cache takes one ref in memory pool
+        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+
+        if del_in_memory_pool:
+            self.req_to_token_pool.free(req_pool_idx)
+        else:
+            cached_indices, new_last_node = self.match_prefix(token_ids)
+            assert len(cached_indices) == len(token_ids)
+
+            self.req_to_token_pool.req_to_token[
+                req_pool_idx, last_uncached_pos : len(cached_indices)
+            ] = cached_indices[last_uncached_pos:]
+            self.dec_lock_ref(old_last_node)
+            self.inc_lock_ref(new_last_node)
+            return cached_indices, new_last_node
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
@@ -80,7 +112,7 @@ class RadixCache:
 
             if x == self.root_node:
                 break
-            if x.
+            if x.lock_ref > 0:
                 continue
 
             num_evicted += evict_callback(x.value)
@@ -89,23 +121,23 @@ class RadixCache:
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)
 
-    def
+    def inc_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.
+            if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
                 delta -= len(node.value)
-            node.
+            node.lock_ref += 1
             node = node.parent
         return delta
 
-    def
+    def dec_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.
+            if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
                 delta += len(node.value)
-            node.
+            node.lock_ref -= 1
             node = node.parent
         return delta
 
@@ -131,12 +163,12 @@ class RadixCache:
             last_node[0] = child
             self._match_prefix_helper(child, key[prefix_len:], value, last_node)
 
-    def _split_node(self, key, child, split_len):
+    def _split_node(self, key, child: TreeNode, split_len):
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
-        new_node.
+        new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
@@ -176,11 +208,9 @@ class RadixCache:
             self.evictable_size_ += len(value)
         return 0
 
-    def _print_helper(self, node, indent):
+    def _print_helper(self, node: TreeNode, indent):
         for _, child in node.children.items():
-            print(
-                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
-            )
+            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)
 
     def _delete_leaf(self, node):
@@ -211,7 +241,7 @@ class RadixCache:
 
 
 if __name__ == "__main__":
-    tree = RadixCache()
+    tree = RadixCache(None, None, False)
 
     tree.insert("Hello")
     tree.insert("Hello")
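The `ref_counter` field on tree nodes becomes `lock_ref`, and `inc_lock_ref` / `dec_lock_ref` walk from a node up to the root, moving the tokens on that path out of (or back into) the evictable pool. A stripped-down sketch of the locking walk, assuming nodes with `parent`, `value`, and `lock_ref` attributes as in the diff (the `Node` class and `state` dict are illustrative, not sglang code):

    class Node:
        def __init__(self, parent=None, value=()):
            self.parent = parent
            self.value = list(value)  # cached token slots held by this node
            self.lock_ref = 0

    root = Node()
    root.lock_ref = 1  # the root is never evicted

    def inc_lock_ref(node, state):
        # Lock the path from `node` to the root; newly locked nodes stop being evictable.
        delta = 0
        while node is not root:
            if node.lock_ref == 0:
                state["evictable_size"] -= len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(node, state):
        # Release the path; nodes whose last lock is dropped become evictable again.
        delta = 0
        while node is not root:
            if node.lock_ref == 1:
                state["evictable_size"] += len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta

    state = {"evictable_size": 7}
    a = Node(root, [1, 2, 3])
    b = Node(a, [4, 5, 6, 7])
    inc_lock_ref(b, state)  # locks b and a: evictable_size drops to 0
    dec_lock_ref(b, state)  # releases them: evictable_size returns to 7

The returned delta is what `model_rpc.py` adds back onto its `available_size` estimate when it locks or unlocks a request's prefix during batch selection.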
sglang/srt/managers/router/scheduler.py CHANGED
@@ -27,44 +27,33 @@ class Scheduler:
             return forward_queue
         elif self.schedule_heuristic == "fcfs":
             return forward_queue
-        elif self.schedule_heuristic == "weight":
+        elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
                 last_node_to_reqs[req.last_node].append(req)
-            for node in last_node_to_reqs:
-                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))
 
             node_to_weight = defaultdict(int)
-
-
-            )
+            for node in last_node_to_reqs:
+                node_to_weight[node] = len(last_node_to_reqs[node])
+            self.calc_weight(self.tree_cache.root_node, node_to_weight)
 
-
-            self.
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs,
+            q = []
+            self.get_dfs_priority(
+                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(
-            return
+            assert len(q) == len(forward_queue)
+            return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
 
-    def
-        node_to_weight[cur_node] = 1
-        if cur_node in last_node_to_reqs:
-            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
+    def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
-            self.
+            self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]
 
-    def
-
-
-
-
-
-        # print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
-        for child in visit_list:
-            self._get_weight_priority_recursive(
-                child, node_to_wight, last_node_to_reqs, tmp_queue
-            )
-        tmp_queue.extend(last_node_to_reqs[cur_node])
+    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+        childs = [child for child in cur_node.children.values()]
+        childs.sort(key=lambda x: -node_to_priority[x])
+        for child in childs:
+            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+        q.extend(last_node_to_reqs[cur_node])
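The `weight` heuristic is renamed `dfs-weight`: pending requests are grouped under the radix-tree node they last matched, subtree weights are accumulated bottom-up, and requests are emitted in a depth-first order that visits heavier subtrees first. A toy version of the two helpers, using a plain `Node` class in place of the real tree (illustrative only):

    from collections import defaultdict

    class Node:
        def __init__(self):
            self.children = {}

    def calc_weight(cur_node, node_to_weight):
        # Accumulate each node's weight from its children; leaves already carry
        # the number of requests that end at them.
        for child in cur_node.children.values():
            calc_weight(child, node_to_weight)
            node_to_weight[cur_node] += node_to_weight[child]

    def get_dfs_priority(cur_node, node_to_priority, last_node_to_reqs, q):
        # Visit heavier subtrees first, then append the requests anchored here.
        childs = sorted(cur_node.children.values(), key=lambda x: -node_to_priority[x])
        for child in childs:
            get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
        q.extend(last_node_to_reqs[cur_node])

    # Tiny example: two branches, with more requests ending under `b` than `a`.
    root, a, b = Node(), Node(), Node()
    root.children = {"a": a, "b": b}
    last_node_to_reqs = defaultdict(list, {a: ["req1"], b: ["req2", "req3"]})

    node_to_weight = defaultdict(int)
    for node, reqs in last_node_to_reqs.items():
        node_to_weight[node] = len(reqs)
    calc_weight(root, node_to_weight)

    q = []
    get_dfs_priority(root, node_to_weight, last_node_to_reqs, q)
    print(q)  # ['req2', 'req3', 'req1'] -- the heavier branch is scheduled first

The depth-first grouping keeps requests that share a cached prefix adjacent in the queue, which is presumably why the heuristic is tied to the radix tree rather than to arrival order.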
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -60,21 +60,29 @@ def get_pixel_values(
 ):
     try:
         processor = processor or global_processor
-        image = load_image(image_data)
-
-
-
-
-
-        pixel_values =
-
-        pixel_values = process_anyres_image(
-            image, processor.image_processor, image_grid_pinpoints
-        )
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
         else:
-
-
-
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
     except Exception:
         print("Exception in TokenizerManager:\n" + get_exception_traceback())
 
@@ -84,6 +92,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args
 
@@ -96,9 +105,10 @@ class TokenizerManager:
 
         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path,
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
-
         self.context_len = get_context_length(self.hf_config)
 
         if is_multimodal_model(self.model_path):
@@ -147,11 +157,21 @@ class TokenizerManager:
         if self.to_create_loop:
             await self.create_handle_loop()
 
-        is_single =
-
+        is_single = obj.is_single
         if is_single:
             rid = obj.rid
-
+
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(obj.text)
+            else:
+                input_ids = obj.input_ids
+
+            if len(input_ids) >= self.context_len:
+                raise ValueError(
+                    f"The input ({len(input_ids)} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)"
+                )
+
             sampling_params = SamplingParams(**obj.sampling_params)
             if sampling_params.max_new_tokens != 0:
                 sampling_params.normalize(self.tokenizer)
@@ -204,10 +224,22 @@ class TokenizerManager:
                     event.clear()
         else:
             assert obj.stream is False
-
+
+            if obj.input_ids is None:
+                bs = len(obj.text)
+            else:
+                bs = len(obj.input_ids)
+
             for i in range(bs):
                 rid = obj.rid[i]
-
+
+                if obj.input_ids is None:
+                    input_text = obj.text[i]
+                    input_ids = self.tokenizer.encode(obj.text[i])
+                else:
+                    input_text = None
+                    input_ids = obj.input_ids[i]
+
                 sampling_params = SamplingParams(**obj.sampling_params[i])
                 if sampling_params.max_new_tokens != 0:
                     sampling_params.normalize(self.tokenizer)
@@ -220,7 +252,7 @@ class TokenizerManager:
                 )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
-                    input_text=
+                    input_text=input_text,
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
sglang/srt/model_config.py CHANGED
@@ -10,12 +10,16 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
+        model_overide_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
 
+        if model_overide_args is not None:
+            self.hf_config.update(model_overide_args)
+
         if context_length is not None:
             self.context_len = context_length
         else:
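`model_overide_args` (spelled that way throughout the codebase) is threaded from the server entry point down to `ModelConfig`, where it patches fields of the loaded Hugging Face config; the new `launch_server_llavavid.py` presumably relies on this to adjust a LLaVA-video configuration. A minimal illustration of the update step with `transformers`; the override keys below are made up for the example:

    from transformers import PretrainedConfig

    hf_config = PretrainedConfig(hidden_size=4096, num_hidden_layers=32)

    # Hypothetical overrides passed in by a launch script.
    model_overide_args = {"num_frames": 16, "architectures": ["LlavaVidForCausalLM"]}

    if model_overide_args is not None:
        # PretrainedConfig.update() copies each key onto the config object.
        hf_config.update(model_overide_args)

    print(hf_config.num_frames, hf_config.architectures)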
sglang/srt/models/commandr.py CHANGED
@@ -27,29 +27,25 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 @torch.compile