sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +57 -2
- sglang/api.py +8 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +83 -2
- sglang/lang/interpreter.py +92 -35
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +6 -4
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +10 -2
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +27 -3
- sglang/srt/managers/router/infer_batch.py +97 -48
- sglang/srt/managers/router/manager.py +11 -8
- sglang/srt/managers/router/model_rpc.py +169 -90
- sglang/srt/managers/router/model_runner.py +110 -166
- sglang/srt/managers/router/radix_cache.py +89 -51
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +110 -33
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +11 -0
- sglang/srt/models/commandr.py +372 -0
- sglang/srt/models/dbrx.py +412 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +24 -25
- sglang/srt/models/llama2.py +25 -26
- sglang/srt/models/llava.py +8 -10
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +29 -33
- sglang/srt/models/qwen.py +34 -25
- sglang/srt/models/qwen2.py +25 -26
- sglang/srt/models/stablelm.py +26 -26
- sglang/srt/models/yivl.py +3 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +91 -456
- sglang/srt/server_args.py +79 -49
- sglang/srt/utils.py +212 -47
- sglang/srt/weight_utils.py +417 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- sglang/utils.py +77 -26
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py +169 -90

@@ -4,13 +4,18 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import Any, Dict, List, Optional, Tuple, Union

-import numpy as np
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+
+try:
+    from vllm.logger import _default_handler as vllm_default_logger
+except ImportError:
+    from vllm.logger import logger as vllm_default_logger
+
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -19,7 +24,7 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
@@ -31,17 +36,20 @@ from sglang.srt.utils import (
     is_multimodal_model,
     set_random_seed,
 )
-
+

 logger = logging.getLogger("model_rpc")
+vllm_default_logger.setLevel(logging.WARN)
+logging.getLogger("vllm.utils").setLevel(logging.WARN)


-class ModelRpcServer(rpyc.Service):
-    def
+class ModelRpcServer:
+    def __init__(
         self,
         tp_rank: int,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: Optional[dict] = None,
     ):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]

@@ -50,18 +58,16 @@ class ModelRpcServer(rpyc.Service):
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        vllm_default_handler.setLevel(
-            level=getattr(logging, server_args.log_level.upper())
-        )

         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
         )

-        #
+        # For model end global settings
         server_args_dict = {
             "enable_flashinfer": server_args.enable_flashinfer,
             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
@@ -90,7 +96,6 @@ class ModelRpcServer(rpyc.Service):
             tokenizer_mode=server_args.tokenizer_mode,
             trust_remote_code=server_args.trust_remote_code,
         )
-        self.eos_token_id = self.tokenizer.eos_token_id
         self.max_total_num_token = self.model_runner.max_total_num_token
         self.max_num_running_seq = self.max_total_num_token // 2
         self.max_prefill_num_token = max(
@@ -111,10 +116,15 @@ class ModelRpcServer(rpyc.Service):
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
         )
-
+        if self.tp_rank == 0:
+            logger.info(f"server_args: {server_args.print_mode_args()}")

         # Init cache
-        self.tree_cache = RadixCache(
+        self.tree_cache = RadixCache(
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            disable=server_args.disable_radix_cache,
+        )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
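
For context, the new `RadixCache` constructor receives the two memory pools directly plus a `disable` switch wired to `server_args.disable_radix_cache`. The stand-in below is only a sketch of what a disable-able prefix cache looks like, using a flat dict instead of a radix tree; the real implementation lives in sglang/srt/managers/router/radix_cache.py and, judging by the constructor arguments, maps token-id prefixes to entries in `token_to_kv_pool`.

```python
# Illustrative stand-in only: a flat-dict prefix cache with the same constructor
# shape and `disable` behaviour as the RadixCache call above. The real class is
# tree-based and also returns a tree node used for lock-ref tracking.
class ToyPrefixCache:
    def __init__(self, req_to_token_pool=None, token_to_kv_pool=None, disable=False):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool
        self.disable = disable
        self._store = {}  # tuple(token_ids) -> cached value (e.g. KV indices)

    def insert(self, token_ids, value):
        # When disabled, nothing is ever cached, so every lookup misses.
        if not self.disable:
            self._store[tuple(token_ids)] = value

    def match_prefix(self, token_ids):
        """Return (longest cached prefix, cached value); always a miss when disabled."""
        if self.disable:
            return (), None
        for k in range(len(token_ids), 0, -1):
            prefix = tuple(token_ids[:k])
            if prefix in self._store:
                return prefix, self._store[prefix]
        return (), None


cache = ToyPrefixCache(disable=False)
cache.insert([1, 2, 3, 4], "kv-indices-for-1-2-3-4")
print(cache.match_prefix([1, 2, 3, 4, 5, 6]))  # hit on the 4-token prefix
```
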
@@ -132,6 +142,8 @@ class ModelRpcServer(rpyc.Service):
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()

         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(
@@ -161,7 +173,7 @@ class ModelRpcServer(rpyc.Service):
             logger.info("Cache flushed successfully!")
         else:
             warnings.warn(
-                "Cache not flushed because there are pending requests. "
+                f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.forward_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
@@ -198,6 +210,8 @@ class ModelRpcServer(rpyc.Service):
             # Run new fill batch
             self.forward_fill_batch(new_batch)

+            self.cache_filled_batch(new_batch)
+
             if not new_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = new_batch
@@ -208,6 +222,7 @@ class ModelRpcServer(rpyc.Service):
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
                 for _ in range(10):
+                    self.num_generated_tokens += len(self.running_batch.reqs)
                     self.forward_decode_batch(self.running_batch)

                     if self.running_batch.is_empty():
@@ -223,10 +238,14 @@ class ModelRpcServer(rpyc.Service):
                             self.token_to_kv_pool.available_size()
                             + self.tree_cache.evictable_size()
                         )
+                        throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+                        self.num_generated_tokens = 0
+                        self.last_stats_tic = time.time()
                         logger.info(
                             f"#running-req: {len(self.running_batch.reqs)}, "
                             f"#token: {num_used}, "
                             f"token usage: {num_used / self.max_total_num_token:.2f}, "
+                            f"gen throughput (token/s): {throuhgput:.2f}, "
                             f"#queue-req: {len(self.forward_queue)}"
                         )
             else:
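
The two counters initialized earlier (`num_generated_tokens`, `last_stats_tic`) feed the new throughput line: tokens generated since the last report divided by wall-clock seconds since the last report, with both values reset afterwards (the source keeps the `throuhgput` spelling). A small self-contained sketch of the same bookkeeping, with made-up batch sizes:

```python
import time

# Same bookkeeping as the stats code above: count tokens emitted between reports,
# divide by elapsed wall-clock time, then reset both counters.
num_generated_tokens = 0
last_stats_tic = time.time()

def record_decode_step(batch_size: int) -> None:
    global num_generated_tokens
    num_generated_tokens += batch_size  # one token per running request per step

def report_throughput() -> float:
    global num_generated_tokens, last_stats_tic
    throughput = num_generated_tokens / (time.time() - last_stats_tic)
    num_generated_tokens = 0
    last_stats_tic = time.time()
    return throughput

# Example: 10 decode steps over a batch of 32 requests.
for _ in range(10):
    record_decode_step(32)
    time.sleep(0.01)
print(f"gen throughput (token/s): {report_throughput():.2f}")
```
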
@@ -262,6 +281,7 @@ class ModelRpcServer(rpyc.Service):
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
+        req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

@@ -338,25 +358,26 @@ class ModelRpcServer(rpyc.Service):
                 and req.extend_input_len + new_batch_input_tokens
                 < self.max_prefill_num_token
             ):
-                delta = self.tree_cache.
+                delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta

                 if not (
                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
-                    # Undo
-                    delta = self.tree_cache.
+                    # Undo locking
+                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                     available_size += delta
+                    break
                 else:
                     # Add this request to the running batch
-                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
                         req.extend_input_len + req.max_new_tokens()
                     )
                     new_batch_input_tokens += req.extend_input_len
-
+            else:
+                break
         if len(can_run_list) == 0:
             return None

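
The admission loop now pins a candidate request's cached prefix with `inc_lock_ref` before re-checking whether the request fits, and releases the pin with `dec_lock_ref` (and stops scheduling) if it does not; the old `token_to_kv_pool.add_refs` bookkeeping is gone. The returned delta is added to `available_size` because locking a node makes its tokens non-evictable (negative delta) and unlocking makes them evictable again (positive delta). A toy illustration of that accounting, with made-up sizes; the real cache presumably updates every node on the path to the root, while this sketch uses a single node per request:

```python
# Toy lock-ref accounting in the spirit of inc_lock_ref / dec_lock_ref above.
class Node:
    def __init__(self, num_tokens: int):
        self.num_tokens = num_tokens
        self.lock_ref = 0

class ToyTreeCache:
    def __init__(self, nodes):
        self.nodes = nodes

    def evictable_size(self) -> int:
        # Cached tokens that no running request is currently holding on to.
        return sum(n.num_tokens for n in self.nodes if n.lock_ref == 0)

    def inc_lock_ref(self, node: Node) -> int:
        before = self.evictable_size()
        node.lock_ref += 1
        return self.evictable_size() - before  # negative: fewer tokens evictable

    def dec_lock_ref(self, node: Node) -> int:
        before = self.evictable_size()
        node.lock_ref -= 1
        return self.evictable_size() - before  # positive: tokens evictable again

prefix = Node(num_tokens=128)                 # the request's cached prefix
cache = ToyTreeCache([prefix, Node(num_tokens=64)])

free_kv_slots = 1000
available_size = free_kv_slots + cache.evictable_size()  # 1192
available_size += cache.inc_lock_ref(prefix)             # pin the prefix -> 1064
available_size += cache.dec_lock_ref(prefix)             # unpin it      -> 1192
print(available_size)
```
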
@@ -380,12 +401,12 @@ class ModelRpcServer(rpyc.Service):
                 f"#running_req: {running_req}. "
                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
-            logger.debug(
-                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-                f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-                f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            )
+            #logger.debug(
+            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
+            #)

         new_batch = Batch.init_new(
             can_run_list,
@@ -402,56 +423,80 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )

-        logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
             logits, (
-
-
+                prefill_token_logprobs,
+                normalized_prompt_logprobs,
+                prefill_top_logprobs,
+                decode_top_logprobs,
                 last_logprobs,
-            ) = self.model_runner.forward(
-
-
-
-            logprobs = prefill_logprobs.cpu().tolist()
-            normalized_logprobs = normalized_logprobs.cpu().tolist()
+            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            if prefill_token_logprobs is not None:
+                prefill_token_logprobs = prefill_token_logprobs.tolist()
+                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

             next_token_ids, _ = batch.sample(logits)
-
+
+            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
+            if last_logprobs is not None:
+                last_token_logprobs = (
+                    last_logprobs[
+                        torch.arange(len(batch.reqs), device=next_token_ids.device),
+                        next_token_ids].tolist()
+                )
+
+            next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logits = logprobs = normalized_logprobs = last_logprobs = None
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        if last_logprobs is not None:
-            last_logprobs = (
-                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
-            )

         # Check finish condition
         pt = 0
-        for i, req in enumerate(reqs):
+        for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()

-            if
-            req.
-
+            if req.return_logprob:
+                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+
+                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                req.prefill_token_logprobs = list(
+                    zip(
+                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                        req.input_ids[-req.extend_input_len + 1 :],
+                    )
+                )
+                if req.logprob_start_len == 0:
+                    req.prefill_token_logprobs = [
+                        (None, req.input_ids[0])
+                    ] + req.prefill_token_logprobs
+                req.decode_token_logprobs = [
+                    (last_token_logprobs[i], next_token_ids[i])
+                ]

-
-
-                prompt_token_len = len(req.logprob)
-                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
-                token_logprobs = req.logprob + [last_logprobs[i]]
-                req.token_logprob = list(zip(token_ids, token_logprobs))
+            if req.top_logprobs_num > 0:
+                req.prefill_top_logprobs = prefill_top_logprobs[i]
                 if req.logprob_start_len == 0:
-                    req.
-
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+                req.decode_top_logprobs = [decode_top_logprobs[i]]
+
+            pt += req.extend_input_len

         self.handle_finished_requests(batch)

+    def cache_filled_batch(self, batch: Batch):
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        for i, req in enumerate(batch.reqs):
+            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
+                del_in_memory_pool=False,
+                old_last_node=req.last_node,
+            )
+            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
     def forward_decode_batch(self, batch: Batch):
         # check if decode out of memory
         if not batch.check_decode_mem():
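
In the new `forward_fill_batch`, prompt logprobs are returned per extended token and paired with the token they score. Position i's logits predict token i+1, so an extend of length L yields L-1 logprobs, which are zipped against `input_ids` shifted by one; when `logprob_start_len == 0` a `(None, first_token)` entry is prepended because the first token has no preceding context. A worked example of that alignment with made-up values:

```python
# Made-up values, just to show how the (logprob, token_id) pairs line up.
extend_input_len = 4
input_ids = [101, 7, 8, 9]                    # prompt tokens of this request
prefill_token_logprobs = [-0.5, -1.2, -0.3]   # logprobs of tokens 7, 8, 9 (L - 1 values)
logprob_start_len = 0
pt = 0                                        # offset of this request within the batch

pairs = list(
    zip(
        prefill_token_logprobs[pt : pt + extend_input_len - 1],
        input_ids[-extend_input_len + 1 :],
    )
)
if logprob_start_len == 0:
    # The first prompt token is not predicted by anything, so it gets no logprob.
    pairs = [(None, input_ids[0])] + pairs

print(pairs)
# [(None, 101), (-0.5, 7), (-1.2, 8), (-0.3, 9)]
```
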
@@ -497,29 +542,33 @@ class ModelRpcServer(rpyc.Service):
         batch.prepare_for_decode()

         # Forward
-        logits, (
-
-
-
-
+        logits, (
+            _,
+            _,
+            _,
+            decode_top_logprobs,
+            last_logprobs,
+        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.
+        next_token_ids = next_token_ids.tolist()

         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
         if last_logprobs is not None:
-
-                torch.arange(len(reqs)), next_token_ids
+            new_token_logprobs = last_logprobs[
+                torch.arange(len(batch.reqs)), next_token_ids
             ].tolist()

         # Check finish condition
-        for i, (req,
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(
+            req.output_ids.append(next_token_id)
             req.check_finished()

-            if
-            req.
+            if req.return_logprob:
+                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
+
+            if req.top_logprobs_num > 0:
+                req.decode_top_logprobs.append(decode_top_logprobs[i])

         self.handle_finished_requests(batch)

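
Both the prefill and decode paths above keep the full `last_logprobs` tensor on the GPU and transfer only the sampled token's logprob for each request, using a fancy-indexed gather. A minimal torch illustration with made-up shapes and values:

```python
import torch

batch_size, vocab_size = 3, 8
# One row of vocab logprobs per request in the batch (random stand-in values).
last_logprobs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)
next_token_ids = torch.tensor([2, 5, 0])  # the token sampled for each request

# Gather exactly one logprob per request, then move that (batch_size,) result to
# the CPU instead of copying the whole (batch_size, vocab_size) tensor.
new_token_logprobs = last_logprobs[torch.arange(batch_size), next_token_ids].tolist()
print(new_token_logprobs)                 # three floats, one per request
```
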
@@ -529,6 +578,7 @@ class ModelRpcServer(rpyc.Service):
         output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
+        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished = []
         finished_indices = []
@@ -555,6 +605,9 @@ class ModelRpcServer(rpyc.Service):
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
+                output_spaces_between_special_tokens.append(
+                    req.sampling_params.spaces_between_special_tokens
+                )

                 meta_info = {
                     "prompt_tokens": req.prompt_tokens,
@@ -562,11 +615,23 @@ class ModelRpcServer(rpyc.Service):
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                    "finish_reason": FinishReason.to_str(req.finish_reason),
+                    "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:
-
-
-
+                    (
+                        meta_info["prefill_token_logprobs"],
+                        meta_info["decode_token_logprobs"],
+                        meta_info["prefill_top_logprobs"],
+                        meta_info["decode_top_logprobs"],
+                        meta_info["normalized_prompt_logprob"],
+                    ) = (
+                        req.prefill_token_logprobs,
+                        req.decode_token_logprobs,
+                        req.prefill_top_logprobs,
+                        req.decode_top_logprobs,
+                        req.normalized_prompt_logprob,
+                    )
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)

@@ -579,6 +644,7 @@ class ModelRpcServer(rpyc.Service):
                     output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
+                    output_spaces_between_special_tokens,
                     output_meta_info,
                     output_finished,
                 )
@@ -587,20 +653,16 @@ class ModelRpcServer(rpyc.Service):
         # Remove finished reqs
         if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.
+            req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
-
-
-
-
-                prefix_len = self.tree_cache.insert(
-                    token_ids[:seq_len], indices.clone()
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )

-                self.
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
+                self.tree_cache.dec_lock_ref(req.last_node)

         # Update batch tensors
         if unfinished_indices:
@@ -609,14 +671,21 @@ class ModelRpcServer(rpyc.Service):
             batch.reqs = []


+class ModelRpcService(rpyc.Service):
+    exposed_ModelRpcServer = ModelRpcServer
+
+
 class ModelRpcClient:
-    def __init__(
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+    ):
         tp_size = server_args.tp_size

         if tp_size == 1:
             # Init model
-            self.model_server =
-
+            self.model_server = ModelRpcService().exposed_ModelRpcServer(
+                0, server_args, port_args, model_overide_args
+            )

             # Wrap functions
             def async_wrap(f):
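
`ModelRpcServer` no longer subclasses `rpyc.Service`; the thin `ModelRpcService` now exposes it through an `exposed_`-prefixed class attribute, so the same class can be constructed directly in-process (`ModelRpcService().exposed_ModelRpcServer(...)` when tp_size == 1) or remotely through a connection's root object. A minimal, self-contained sketch of that rpyc pattern; the `Adder` class, port number, and config values are illustrative, not part of sglang:

```python
import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer

class Adder:
    """Plain class published through the service below, like ModelRpcServer."""
    def __init__(self, base):
        self.base = base

    def add(self, x):
        return self.base + x

class AdderService(rpyc.Service):
    # The `exposed_` prefix is what makes the class reachable as conn.root.Adder.
    exposed_Adder = Adder

if __name__ == "__main__":
    server = ThreadedServer(
        AdderService(),
        port=18861,
        protocol_config={"allow_public_attrs": True, "allow_pickle": True},
    )
    threading.Thread(target=server.start, daemon=True).start()
    time.sleep(0.5)                  # give the listener a moment to come up

    # In-process use (the tp_size == 1 path above): no RPC involved at all.
    local = AdderService().exposed_Adder(10)
    print(local.add(5))              # 15

    # Remote use (the tp_size > 1 path): instantiate inside the server process.
    conn = rpyc.connect(
        "localhost", 18861,
        config={"allow_public_attrs": True, "allow_pickle": True},
    )
    remote = conn.root.Adder(10)
    print(remote.add(5))             # 15, computed on the server side
```
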
@@ -630,14 +699,16 @@ class ModelRpcClient:
             with ThreadPoolExecutor(tp_size) as executor:
                 # Launch model processes
                 rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.
+                self.remote_services = [x[0] for x in rets]
                 self.procs = [x[1] for x in rets]

                 # Init model
                 def init_model(i):
-                    return self.
+                    return self.remote_services[i].ModelRpcServer(
+                        i, server_args, port_args, model_overide_args
+                    )

-
+                self.model_servers = executor.map(init_model, range(tp_size))

                 # Wrap functions
                 def async_wrap(func_name):
@@ -655,9 +726,13 @@ class ModelRpcClient:


 def _init_service(port):
     t = ThreadedServer(
-
+        ModelRpcService(),
         port=port,
-        protocol_config={
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 1800,
+        },
     )
     t.start()

@@ -673,7 +748,11 @@ def start_model_process(port):
         con = rpyc.connect(
             "localhost",
             port,
-            config={
+            config={
+                "allow_public_attrs": True,
+                "allow_pickle": True,
+                "sync_request_timeout": 1800,
+            },
         )
         break
     except ConnectionRefusedError:
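
The client-side `config` passed to `rpyc.connect` mirrors the server's `protocol_config`: both ends allow pickling and public attribute access, and the `sync_request_timeout` of 1800 seconds keeps long remote calls (such as model loading) from hitting rpyc's default per-request timeout. A sketch of the retry pattern used by `start_model_process`, wrapped in a hypothetical `connect_with_retry` helper:

```python
import time

import rpyc

def connect_with_retry(port, retries=20, delay=0.5):
    """Hypothetical helper: retry rpyc.connect until the server starts listening."""
    for _ in range(retries):
        try:
            return rpyc.connect(
                "localhost",
                port,
                config={
                    "allow_public_attrs": True,
                    "allow_pickle": True,
                    "sync_request_timeout": 1800,  # seconds per remote call
                },
            )
        except ConnectionRefusedError:
            time.sleep(delay)
    raise ConnectionRefusedError(f"no rpyc server listening on port {port}")
```
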