sglang 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- sglang/__init__.py +55 -2
- sglang/api.py +3 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +1 -0
- sglang/lang/chat_template.py +74 -0
- sglang/lang/interpreter.py +40 -16
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +2 -1
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/router/infer_batch.py +70 -33
- sglang/srt/managers/router/manager.py +7 -2
- sglang/srt/managers/router/model_rpc.py +116 -73
- sglang/srt/managers/router/model_runner.py +111 -167
- sglang/srt/managers/router/radix_cache.py +46 -38
- sglang/srt/managers/tokenizer_manager.py +56 -11
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +7 -0
- sglang/srt/models/commandr.py +376 -0
- sglang/srt/models/dbrx.py +413 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +22 -20
- sglang/srt/models/llama2.py +23 -21
- sglang/srt/models/llava.py +12 -10
- sglang/srt/models/mixtral.py +27 -25
- sglang/srt/models/qwen.py +23 -21
- sglang/srt/models/qwen2.py +23 -21
- sglang/srt/models/stablelm.py +20 -21
- sglang/srt/models/yivl.py +6 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +68 -447
- sglang/srt/server_args.py +76 -49
- sglang/srt/utils.py +88 -32
- sglang/srt/weight_utils.py +402 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/METADATA +12 -14
- sglang-0.1.15.dist-info/RECORD +69 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py

```diff
@@ -6,11 +6,15 @@ import warnings
 from concurrent.futures import ThreadPoolExecutor
 from typing import List
 
-import numpy as np
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+try:
+    from vllm.logger import _default_handler as vllm_default_logger
+except ImportError:
+    from vllm.logger import logger as vllm_default_logger
+
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -31,13 +35,15 @@ from sglang.srt.utils import (
     is_multimodal_model,
     set_random_seed,
 )
-
+
 
 logger = logging.getLogger("model_rpc")
+vllm_default_logger.setLevel(logging.WARN)
+logging.getLogger("vllm.utils").setLevel(logging.WARN)
 
 
-class ModelRpcServer(rpyc.Service):
-    def
+class ModelRpcServer:
+    def __init__(
         self,
         tp_rank: int,
         server_args: ServerArgs,
@@ -50,9 +56,6 @@ class ModelRpcServer(rpyc.Service):
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        vllm_default_handler.setLevel(
-            level=getattr(logging, server_args.log_level.upper())
-        )
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -61,7 +64,7 @@ class ModelRpcServer(rpyc.Service):
             context_length=server_args.context_length,
         )
 
-        #
+        # For model end global settings
         server_args_dict = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
@@ -90,7 +93,6 @@ class ModelRpcServer(rpyc.Service):
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
-        self.eos_token_id = self.tokenizer.eos_token_id
         self.max_total_num_token = self.model_runner.max_total_num_token
         self.max_num_running_seq = self.max_total_num_token // 2
         self.max_prefill_num_token = max(
@@ -111,10 +113,11 @@ class ModelRpcServer(rpyc.Service):
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
         )
-
+        if self.tp_rank == 0:
+            logger.info(f"server_args: {server_args.print_mode_args()}")
 
         # Init cache
-        self.tree_cache = RadixCache(server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -161,7 +164,7 @@ class ModelRpcServer(rpyc.Service):
             logger.info("Cache flushed successfully!")
         else:
             warnings.warn(
-                "Cache not flushed because there are pending requests. "
+                f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.forward_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
@@ -262,6 +265,7 @@ class ModelRpcServer(rpyc.Service):
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
+        req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer
 
@@ -348,6 +352,7 @@ class ModelRpcServer(rpyc.Service):
                     # Undo the insertion
                     delta = self.tree_cache.dec_ref_counter(req.last_node)
                     available_size += delta
+                    break
                 else:
                     # Add this request to the running batch
                     self.token_to_kv_pool.add_refs(req.prefix_indices)
@@ -356,7 +361,8 @@ class ModelRpcServer(rpyc.Service):
                         req.extend_input_len + req.max_new_tokens()
                     )
                     new_batch_input_tokens += req.extend_input_len
-
+            else:
+                break
         if len(can_run_list) == 0:
             return None
 
@@ -380,12 +386,12 @@ class ModelRpcServer(rpyc.Service):
                 f"#running_req: {running_req}. "
                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
-            logger.debug(
-
-
-
-
-            )
+            #logger.debug(
+            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
+            #)
 
         new_batch = Batch.init_new(
             can_run_list,
@@ -402,53 +408,63 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-        logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
             logits, (
-
-
+                prefill_token_logprobs,
+                normalized_prompt_logprobs,
+                prefill_top_logprobs,
+                decode_top_logprobs,
                 last_logprobs,
-            ) = self.model_runner.forward(
-
-
-
-            logprobs = prefill_logprobs.cpu().tolist()
-            normalized_logprobs = normalized_logprobs.cpu().tolist()
+            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            if prefill_token_logprobs is not None:
+                prefill_token_logprobs = prefill_token_logprobs.tolist()
+                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
 
             next_token_ids, _ = batch.sample(logits)
-
+
+            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
+            if last_logprobs is not None:
+                last_token_logprobs = (
+                    last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
+                )
+
+            next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logits = logprobs = normalized_logprobs = last_logprobs = None
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        if last_logprobs is not None:
-            last_logprobs = (
-                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
-            )
 
         # Check finish condition
         pt = 0
-        for i, req in enumerate(reqs):
+        for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()
 
-            if
-                req.
-                req.normalized_logprob = normalized_logprobs[i]
+            if req.return_logprob:
+                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
 
-                # If logprob_start_len > 0, then first logprob_start_len prompt tokens
-
-
-
-
-
+                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                req.prefill_token_logprobs = list(
+                    zip(
+                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                        req.input_ids[-req.extend_input_len + 1 :],
+                    )
+                )
+                if req.logprob_start_len == 0:
+                    req.prefill_token_logprobs = [
+                        (None, req.input_ids[0])
+                    ] + req.prefill_token_logprobs
+                req.decode_token_logprobs = [
+                    (last_token_logprobs[i], next_token_ids[i])
+                ]
+
+            if req.top_logprobs_num > 0:
+                req.prefill_top_logprobs = prefill_top_logprobs[i]
                 if req.logprob_start_len == 0:
-                    req.
-
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+                req.decode_top_logprobs = [decode_top_logprobs[i]]
+
+            pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
 
@@ -497,29 +513,33 @@ class ModelRpcServer(rpyc.Service):
         batch.prepare_for_decode()
 
         # Forward
-        logits, (
-
-
-
-
+        logits, (
+            _,
+            _,
+            _,
+            decode_top_logprobs,
+            last_logprobs,
+        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.
+        next_token_ids = next_token_ids.tolist()
 
         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
         if last_logprobs is not None:
-
-                torch.arange(len(reqs)), next_token_ids
+            new_token_logprobs = last_logprobs[
+                torch.arange(len(batch.reqs)), next_token_ids
             ].tolist()
 
         # Check finish condition
-        for i, (req,
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(
+            req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if
-                req.
+            if req.return_logprob:
+                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
+
+            if req.top_logprobs_num > 0:
+                req.decode_top_logprobs.append(decode_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 
@@ -529,6 +549,7 @@ class ModelRpcServer(rpyc.Service):
         output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
+        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished = []
         finished_indices = []
@@ -555,6 +576,9 @@ class ModelRpcServer(rpyc.Service):
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
+                output_spaces_between_special_tokens.append(
+                    req.sampling_params.spaces_between_special_tokens
+                )
 
                 meta_info = {
                     "prompt_tokens": req.prompt_tokens,
@@ -562,11 +586,22 @@ class ModelRpcServer(rpyc.Service):
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                    "finish_reason": str(req.finish_reason), # FIXME: convert to the correct string
                 }
                 if req.return_logprob:
-
-
-
+                    (
+                        meta_info["prefill_token_logprobs"],
+                        meta_info["decode_token_logprobs"],
+                        meta_info["prefill_top_logprobs"],
+                        meta_info["decode_top_logprobs"],
+                        meta_info["normalized_prompt_logprob"],
+                    ) = (
+                        req.prefill_token_logprobs,
+                        req.decode_token_logprobs,
+                        req.prefill_top_logprobs,
+                        req.decode_top_logprobs,
+                        req.normalized_prompt_logprob,
+                    )
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)
 
@@ -579,6 +614,7 @@ class ModelRpcServer(rpyc.Service):
                     output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
+                    output_spaces_between_special_tokens,
                     output_meta_info,
                     output_finished,
                 )
@@ -587,7 +623,7 @@ class ModelRpcServer(rpyc.Service):
         # Remove finished reqs
         if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.
+            req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
                 req_pool_idx = req_pool_indices_cpu[i]
@@ -598,7 +634,7 @@ class ModelRpcServer(rpyc.Service):
                     token_ids[:seq_len], indices.clone()
                 )
 
-                self.token_to_kv_pool.
+                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                 self.req_to_token_pool.free(req_pool_idx)
                 self.tree_cache.dec_ref_counter(req.last_node)
 
@@ -609,14 +645,19 @@ class ModelRpcServer(rpyc.Service):
             batch.reqs = []
 
 
+class ModelRpcService(rpyc.Service):
+    exposed_ModelRpcServer = ModelRpcServer
+
+
 class ModelRpcClient:
     def __init__(self, server_args: ServerArgs, port_args: PortArgs):
         tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
-            self.model_server =
-
+            self.model_server = ModelRpcService().exposed_ModelRpcServer(
+                0, server_args, port_args
+            )
 
             # Wrap functions
             def async_wrap(f):
@@ -630,14 +671,16 @@ class ModelRpcClient:
             with ThreadPoolExecutor(tp_size) as executor:
                 # Launch model processes
                 rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.
+                self.remote_services = [x[0] for x in rets]
                 self.procs = [x[1] for x in rets]
 
                 # Init model
                 def init_model(i):
-                    return self.
+                    return self.remote_services[i].ModelRpcServer(
+                        i, server_args, port_args
+                    )
 
-
+                self.model_servers = executor.map(init_model, range(tp_size))
 
             # Wrap functions
             def async_wrap(func_name):
@@ -655,7 +698,7 @@ class ModelRpcClient:
 
 def _init_service(port):
     t = ThreadedServer(
-
+        ModelRpcService(),
         port=port,
         protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
     )
```