sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,4 +1,5 @@
 import asyncio
+import inspect
 
 import uvloop
 import zmq
@@ -7,7 +8,8 @@ import zmq.asyncio
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.…
+from sglang.utils import get_exception_traceback, graceful_registry
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -33,51 +35,47 @@ class DetokenizerManager:
 
     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_router.recv_pyobj()
[19 deleted lines not captured]
-                            output_strs[i] = output_strs[i][:pos]
-
-                    if len(output_tokens[i]) > 0:
-                        first_token = self.tokenizer.convert_ids_to_tokens(
-                            int(output_tokens[i][0])
-                        )
-                        if not isinstance(first_token, str):
-                            first_token = first_token.decode("utf-8", errors="ignore")
-                        if first_token.startswith("▁"):
-                            output_strs[i] = " " + output_strs[i]
-
-                    output_strs[i] = (
-                        recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
-                    )
-
-                self.send_to_tokenizer.send_pyobj(
-                    BatchStrOut(
-                        recv_obj.rids,
-                        output_strs,
-                        recv_obj.meta_info,
-                        recv_obj.finished,
+            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            assert isinstance(recv_obj, BatchTokenIDOut)
+
+            output_tokens = recv_obj.output_tokens
+
+            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
+            output_strs = self.tokenizer.batch_decode(
+                output_tokens,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
+                    0
+                ],
+            )
+
+            # Trim stop str
+            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            for i in range(len(output_strs)):
+                if len(output_tokens[i]) > 0:
+                    first_token = self.tokenizer.convert_ids_to_tokens(
+                        int(output_tokens[i][0])
                     )
+                    if not isinstance(first_token, str):
+                        first_token = first_token.decode("utf-8", errors="ignore")
+                    if first_token.startswith("▁"):
+                        output_strs[i] = " " + output_strs[i]
+
+                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+
+                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
+                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
+
+            self.send_to_tokenizer.send_pyobj(
+                BatchStrOut(
+                    rids=recv_obj.rids,
+                    output_str=output_strs,
+                    meta_info=recv_obj.meta_info,
+                    finished_reason=recv_obj.finished_reason,
                 )
-            else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+            )
 
 
 def start_detokenizer_process(
@@ -85,9 +83,11 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    graceful_registry(inspect.currentframe().f_code.co_name)
+
    try:
        manager = DetokenizerManager(server_args, port_args)
-    except Exception…
+    except Exception:
        pipe_writer.send(get_exception_traceback())
        raise
    pipe_writer.send("init ok")
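The net effect of the handle_loop rewrite: the detokenizer now receives the previously detokenized text (prev_output_strs) and a structured finish reason per request, instead of pre-concatenated jump-forward strings and a raw hit_stop_str. Below is a minimal standalone sketch of the per-request merge-and-trim step, with a hypothetical FinishMatchedStr standing in for the imported FINISH_MATCHED_STR (assumed, from the hunk, to expose the matched stop string as .matched):

from dataclasses import dataclass

@dataclass
class FinishMatchedStr:          # stand-in for FINISH_MATCHED_STR
    matched: str

def merge_and_trim(prev_output_str, new_chunk, finished_reason):
    # Prepend the text detokenized in earlier iterations ...
    s = prev_output_str + new_chunk
    # ... and cut at the stop string if this request finished by matching one.
    if isinstance(finished_reason, FinishMatchedStr):
        pos = s.find(finished_reason.matched)
        if pos != -1:
            s = s[:pos]
    return s

print(merge_and_trim("Hello", " world###tail", FinishMatchedStr("###")))  # "Hello world"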
sglang/srt/managers/io_struct.py
CHANGED
@@ -3,12 +3,15 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.managers.controller.infer_batch import BaseFinishReason
 
 
 @dataclass
 class GenerateReqInput:
     # The input prompt
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
@@ -25,10 +28,19 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
-    # TODO: make all parameters a Union[List[T], T] to allow for batched requests
 
     def post_init(self):
-        …
+        …
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single
 
         if is_single:
             if self.sampling_params is None:
@@ -42,7 +54,7 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            num = len(self.text)
+            num = len(self.text) if self.text is not None else len(self.input_ids)
 
         if self.image_data is None:
             self.image_data = [None] * num
@@ -57,7 +69,8 @@ class GenerateReqInput:
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
         else:
-            …
+            if not isinstance(self.rid, list):
+                raise ValueError("The rid should be a list.")
 
         if self.return_logprob is None:
             self.return_logprob = [False] * num
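The new contract for GenerateReqInput is exactly one of text or input_ids, with is_single derived from whichever is present. A trimmed, runnable sketch of just that validation (GenerateReqInputSketch is a stand-in, not the shipped class; the both-set/both-unset check is condensed to a single equality):

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class GenerateReqInputSketch:
    text: Optional[Union[List[str], str]] = None
    input_ids: Optional[Union[List[List[int]], List[int]]] = None

    def post_init(self):
        # Exactly one of text / input_ids must be provided.
        if (self.text is None) == (self.input_ids is None):
            raise ValueError("Either text or input_ids should be provided.")
        if self.text is not None:
            self.is_single = isinstance(self.text, str)
        else:
            # A flat list of ints is one prompt; a list of lists is a batch.
            self.is_single = isinstance(self.input_ids[0], int)

req = GenerateReqInputSketch(input_ids=[1, 2, 3])
req.post_init()
print(req.is_single)  # True: a single pre-tokenized prompt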
@@ -93,21 +106,19 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    prev_output_strs: List[str]
     output_tokens: List[List[int]]
-    output_and_jump_forward_strs: List[str]
-    hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
-    …
-
+    finished_reason: List[BaseFinishReason]
 
 @dataclass
 class BatchStrOut:
     rids: List[str]
     output_str: List[str]
     meta_info: List[Dict]
-    …
+    finished_reason: List[BaseFinishReason]
 
 
 @dataclass
@@ -115,6 +126,11 @@ class FlushCacheReq:
     pass
 
 
+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
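Message flow after this change: the router emits a BatchTokenIDOut per step, the detokenizer converts it to a BatchStrOut, and the new AbortReq cancels a request by rid alone. A sketch of the new shapes, with plain objects standing in for BaseFinishReason (the assumption here is that a slot holds None while its request is still running):

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class AbortReqSketch:            # mirrors the new AbortReq
    rid: str

@dataclass
class BatchStrOutSketch:         # mirrors the new BatchStrOut
    rids: List[str]
    output_str: List[str]
    meta_info: List[Dict]
    finished_reason: List[object]  # BaseFinishReason instances once finished

out = BatchStrOutSketch(
    rids=["req-0"], output_str=["Hello"], meta_info=[{}], finished_reason=[None]
)
# A request is still streaming while its finished_reason slot is None.
print(out.finished_reason[0] is None, AbortReqSketch(rid="req-0").rid)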
sglang/srt/managers/router/infer_batch.py
CHANGED
@@ -19,18 +19,32 @@ class FinishReason(IntEnum):
     EOS_TOKEN = auto()
     LENGTH = auto()
     STOP_STR = auto()
+    ABORT = auto()
+
+    @staticmethod
+    def to_str(reason):
+        if reason == FinishReason.EOS_TOKEN:
+            return None
+        elif reason == FinishReason.LENGTH:
+            return "length"
+        elif reason == FinishReason.STOP_STR:
+            return "stop"
+        elif reason == FinishReason.ABORT:
+            return "abort"
+        else:
+            return None
 
 
 class Req:
-    def __init__(self, rid, input_text, input_ids):
+    def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
-        self.input_text = input_text
-        self.input_ids = input_ids
+        self.origin_input_text = origin_input_text
+        self.origin_input_ids = origin_input_ids
+        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
+        self.prev_output_str = ""
+        self.prev_output_ids = []
         self.output_ids = []
-
-        # Since jump forward may retokenize the prompt with partial outputs,
-        # we maintain the original prompt length to report the correct usage.
-        self.prompt_tokens = len(input_ids)
+        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
 
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
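FinishReason.to_str gives the string an OpenAI-style response reports as finish_reason; note that EOS_TOKEN deliberately maps to None, as does any unknown value. The table is easy to exercise standalone (same mapping as the hunk, rewritten as a dict lookup):

from enum import IntEnum, auto

class FinishReason(IntEnum):     # copy of the enum as it appears in this hunk
    EOS_TOKEN = auto()
    LENGTH = auto()
    STOP_STR = auto()
    ABORT = auto()

    @staticmethod
    def to_str(reason):
        # Equivalent to the if/elif chain above; .get() defaults to None.
        return {
            FinishReason.LENGTH: "length",
            FinishReason.STOP_STR: "stop",
            FinishReason.ABORT: "abort",
        }.get(reason)

assert FinishReason.to_str(FinishReason.LENGTH) == "length"
assert FinishReason.to_str(FinishReason.EOS_TOKEN) is None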
@@ -52,6 +66,7 @@ class Req:
         self.finish_reason = None
         self.hit_stop_str = None
 
+        # Prefix info
         self.extend_input_len = 0
         self.prefix_indices = []
         self.last_node = None
@@ -62,67 +77,36 @@ class Req:
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
         self.prefill_token_logprobs = None
-        self.decode_token_logprobs = None
         self.prefill_top_logprobs = None
-        self.…
+        self.decode_token_logprobs = []
+        self.decode_top_logprobs = []
+        # The tokens is prefilled but need to be considered as decode tokens
+        # and should be updated for the decode logprobs
+        self.last_update_decode_tokens = 0
 
         # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
-        self.output_and_jump_forward_str = ""
-
-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
 
-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        …
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
+    def partial_decode(self, ids):
+        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
         first_token = (
             first_token.decode() if isinstance(first_token, bytes) else first_token
         )
-        if first_token.startswith("▁"):
-            old_output_str = " " + old_output_str
-        new_input_string = (
-            self.input_text
-            + self.output_and_jump_forward_str
-            + old_output_str
-            + jump_forward_str
-        )
-        new_input_ids = self.tokenizer.encode(new_input_string)
-        if self.pixel_values is not None:
-            # NOTE: This is a hack because the old input_ids contains the image padding
-            jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
-        else:
-            jump_forward_tokens_len = (
-                len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
-            )
-
-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
 
-        …
-        self.…
-        self.sampling_params.max_new_tokens = max(
-            self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
-        )
-        self.regex_fsm_state = next_state
-        self.output_and_jump_forward_str = (
-            self.output_and_jump_forward_str + old_output_str + jump_forward_str
-        )
-
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
+    def max_new_tokens(self):
+        return self.sampling_params.max_new_tokens
 
     def check_finished(self):
         if self.finished:
             return
 
-        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+        if (
+            len(self.prev_output_ids) + len(self.output_ids)
+            >= self.sampling_params.max_new_tokens
+        ):
             self.finished = True
             self.finish_reason = FinishReason.LENGTH
             return
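partial_decode addresses a SentencePiece quirk: a word-initial piece is marked with "▁", but decode on an isolated span of ids drops the leading space, so gluing incremental decodes together would fuse words. A self-contained illustration with a toy tokenizer (ToyTokenizer is invented here; any HF tokenizer exposing convert_ids_to_tokens and decode plays the same role):

class ToyTokenizer:
    # Tiny stand-in exposing the two methods partial_decode relies on.
    _tokens = {1: "▁Hello", 2: "▁world"}

    def convert_ids_to_tokens(self, idx):
        return self._tokens[idx]

    def decode(self, ids):
        # Like SentencePiece: join pieces, map "▁" to spaces, strip the lead.
        return "".join(self._tokens[i] for i in ids).replace("▁", " ").lstrip()

def partial_decode(tokenizer, ids):
    # Mirror of Req.partial_decode: restore the leading space decode swallows.
    first_token = tokenizer.convert_ids_to_tokens(ids[0])
    if isinstance(first_token, bytes):
        first_token = first_token.decode()
    return (" " if first_token.startswith("▁") else "") + tokenizer.decode(ids)

tok = ToyTokenizer()
print("Hi" + partial_decode(tok, [1, 2]))  # "Hi Hello world", boundary preserved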
@@ -141,14 +125,66 @@ class Req:
         )
 
         for stop_str in self.sampling_params.stop_strs:
-            if stop_str in tail_str:
+            # FIXME: (minor) try incremental match in prev_output_str
+            if stop_str in tail_str or stop_str in self.prev_output_str:
                 self.finished = True
                 self.finish_reason = FinishReason.STOP_STR
                 self.hit_stop_str = stop_str
                 return
 
+    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        # FIXME: This logic does not really solve the problem of determining whether
+        # there should be a leading space.
+        cur_output_str = self.partial_decode(self.output_ids)
+
+        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
+        if self.origin_input_text is None:
+            # Recovering text can only use unpadded ids
+            self.origin_input_text = self.tokenizer.decode(
+                self.origin_input_ids_unpadded
+            )
+
+        all_text = (
+            self.origin_input_text
+            + self.prev_output_str
+            + cur_output_str
+            + jump_forward_str
+        )
+        all_ids = self.tokenizer.encode(all_text)
+        prompt_tokens = len(self.origin_input_ids_unpadded)
+        self.origin_input_ids = all_ids[:prompt_tokens]
+        self.origin_input_ids_unpadded = self.origin_input_ids
+        # NOTE: the output ids may not strictly correspond to the output text
+        old_prev_output_ids = self.prev_output_ids
+        self.prev_output_ids = all_ids[prompt_tokens:]
+        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
+        self.output_ids = []
+
+        self.regex_fsm_state = next_state
+
+        if self.return_logprob:
+            # For fast-forward part's logprobs
+            k = 0
+            for i, old_id in enumerate(old_prev_output_ids):
+                if old_id == self.prev_output_ids[i]:
+                    k = k + 1
+                else:
+                    break
+            self.decode_token_logprobs = self.decode_token_logprobs[:k]
+            self.decode_top_logprobs = self.decode_top_logprobs[:k]
+            self.logprob_start_len = prompt_tokens + k
+            self.last_update_decode_tokens = len(self.prev_output_ids) - k
+
+        # print("=" * 100)
+        # print(f"Catch jump forward:\n{jump_forward_str}")
+        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
+        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+
+        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
+        # print("*" * 100)
+
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
+        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
 @dataclass
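The logprob bookkeeping in jump_forward_and_retokenize hinges on the common-prefix length k: re-tokenizing the concatenated text can change the tail of prev_output_ids, so only the first k cached decode logprobs remain valid and the remaining len(prev_output_ids) - k tokens must be re-scored. The prefix computation in isolation (a sketch equivalent to the k loop above):

def common_prefix_len(old_ids, new_ids):
    # Length of the shared prefix between the pre- and post-retokenize ids.
    k = 0
    for a, b in zip(old_ids, new_ids):
        if a != b:
            break
        k += 1
    return k

old = [11, 42, 7, 99]
new = [11, 42, 8, 100, 101]      # retokenization changed the tail
k = common_prefix_len(old, new)
print(k)                         # 2: logprobs for the first 2 ids are kept;
                                 # last_update_decode_tokens = len(new) - k = 3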
@@ -319,6 +355,7 @@ class Batch:
 
     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
+        # TODO(lsyin): improve the priority of retraction
         sorted_indices.sort(
             key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
             reverse=True,
@@ -332,25 +369,34 @@ class Batch:
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
-            …
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = len(req.prefix_indices)
+            token_indices = self.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[idx]
+            ][last_uncached_pos : seq_lens_cpu[idx]]
+            self.token_to_kv_pool.dec_refs(token_indices)
+
+            # release the last node
+            self.tree_cache.dec_lock_ref(req.last_node)
+
+            cur_output_str = req.partial_decode(req.output_ids)
+            req.prev_output_str = req.prev_output_str + cur_output_str
+            req.prev_output_ids.extend(req.output_ids)
+
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
-            req.regex_fsm_state = 0
-
-            # TODO: apply more fine-grained retraction
 
[3 deleted lines not captured]
-            self.token_to_kv_pool.dec_refs(token_indices)
+            # For incremental logprobs
+            req.last_update_decode_tokens = 0
+            req.logprob_start_len = 10**9
 
         self.filter_batch(sorted_indices)
 
         return retracted_reqs
 
-    def check_for_jump_forward(self):
+    def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
 
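retract_decode returns a running request's KV memory under pressure: dec_refs frees the per-token slots written since the last cache hit, while dec_lock_ref merely unpins the radix-cache prefix so it stays cached but reclaimable, and the decoded text is folded into prev_output_str/prev_output_ids so the request can re-prefill later. A toy refcounted pool showing the dec_refs side (ToyKVPool is illustrative only, not the real token_to_kv_pool API):

class ToyKVPool:
    # Reference-counted KV slots; a slot is reusable once its count hits zero.
    def __init__(self, size):
        self.refs = [0] * size

    def inc_refs(self, indices):
        for i in indices:
            self.refs[i] += 1

    def dec_refs(self, indices):
        freed = 0
        for i in indices:
            self.refs[i] -= 1
            freed += self.refs[i] == 0   # slot returned to the allocator
        return freed

pool = ToyKVPool(8)
token_indices = [3, 4, 5]        # KV slots of the request's uncached suffix
pool.inc_refs(token_indices)
print(pool.dec_refs(token_indices))  # 3 slots freed for other requests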
@@ -364,24 +410,34 @@ class Batch:
                 if len(jump_forward_str) <= 1:
                     continue
 
-                # insert the old request into tree_cache
-                token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                 if req_pool_indices_cpu is None:
                     req_pool_indices_cpu = self.req_pool_indices.tolist()
[6 deleted lines not captured]
+
+                # insert the old request into tree_cache
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )
-                …
-                …
-                self.tree_cache.…
+
+                # unlock the last node
+                self.tree_cache.dec_lock_ref(req.last_node)
 
                 # jump-forward
                 req.jump_forward_and_retokenize(jump_forward_str, next_state)
 
+                # re-applying image padding
+                if req.pixel_values is not None:
+                    (
+                        req.origin_input_ids,
+                        req.image_offset,
+                    ) = model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded,
+                        req.pad_value,
+                        req.pixel_values.shape,
+                        req.image_size,
+                    )
+
                 jump_forward_reqs.append(req)
                 filter_indices.remove(i)
 
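Because jump-forward rewrites origin_input_ids from re-encoded text, a multimodal request must redo its image padding from the unpadded ids. A rough, hypothetical rendering of what such a pad step does (the real pad_input_ids in the llava model takes pad_value, the pixel tensor shape, and the image size to size the placeholder region; IMAGE_TOKEN_ID and the single-sentinel layout below are invented for illustration):

IMAGE_TOKEN_ID = -100            # hypothetical sentinel for "<image>"

def pad_input_ids_sketch(unpadded_ids, pad_value, num_image_slots):
    # Expand the image sentinel into num_image_slots pad tokens and report
    # the offset where the vision embeddings will later be written back.
    offset = unpadded_ids.index(IMAGE_TOKEN_ID)
    padded = (
        unpadded_ids[:offset]
        + [pad_value] * num_image_slots
        + unpadded_ids[offset + 1 :]
    )
    return padded, offset

ids, off = pad_input_ids_sketch([1, 2, IMAGE_TOKEN_ID, 3], pad_value=0, num_image_slots=4)
print(ids, off)                  # [1, 2, 0, 0, 0, 0, 3] 2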
sglang/srt/managers/router/manager.py
CHANGED
@@ -5,10 +5,10 @@ import uvloop
 import zmq
 import zmq.asyncio
 
-from sglang.…
+from sglang.global_config import global_config
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.…
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -30,7 +30,7 @@ class RouterManager:
         self.recv_reqs = []
 
         # Init some configs
-        self.…
+        self.request_dependency_time = global_config.request_dependency_time
 
     async def loop_for_forward(self):
         while True:
@@ -46,9 +46,9 @@ class RouterManager:
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
-                    if self.…
+                    if self.request_dependency_time > 0:
                         slept = True
-                        await asyncio.sleep(self.…
+                        await asyncio.sleep(self.request_dependency_time)
 
             if not slept:
                 await asyncio.sleep(0.0006)
@@ -60,9 +60,7 @@
 
 
 def start_router_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -70,7 +68,7 @@ def start_router_process(
     )
 
     try:
-        model_client = ModelRpcClient(server_args, port_args)
+        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
         router = RouterManager(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
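In loop_for_forward, request_dependency_time is a deliberate pause after any request finishes; the intent, as the name suggests, is to give a dependent follow-up request (for example, the next call of the same program) time to arrive and reuse the radix cache before the next batch forms. The control flow reduced to a runnable skeleton (Out and the finite step list are stand-ins for the real out_pyobjs stream):

import asyncio
from dataclasses import dataclass

request_dependency_time = 0.02   # seconds; stand-in for global_config's value

@dataclass
class Out:
    finished: bool

async def loop_for_forward(steps):
    # Skeleton of RouterManager.loop_for_forward over a finite step list.
    for out_pyobjs in steps:
        slept = False
        if out_pyobjs and any(obj.finished for obj in out_pyobjs):
            if request_dependency_time > 0:
                slept = True
                # Let a dependent follow-up request arrive and hit the
                # radix cache before the next batch is formed.
                await asyncio.sleep(request_dependency_time)
        if not slept:
            await asyncio.sleep(0.0006)

asyncio.run(loop_for_forward([[Out(False)], [Out(True)], []]))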