sglang 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +5 -0
- sglang/global_config.py +4 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +52 -19
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +8 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/router/infer_batch.py +31 -19
- sglang/srt/managers/router/manager.py +6 -8
- sglang/srt/managers/router/model_rpc.py +59 -23
- sglang/srt/managers/router/model_runner.py +6 -6
- sglang/srt/managers/router/radix_cache.py +47 -17
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +54 -22
- sglang/srt/model_config.py +4 -0
- sglang/srt/models/commandr.py +6 -10
- sglang/srt/models/dbrx.py +14 -15
- sglang/srt/models/gemma.py +7 -10
- sglang/srt/models/llama2.py +7 -10
- sglang/srt/models/llava.py +2 -6
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +7 -13
- sglang/srt/models/qwen.py +20 -13
- sglang/srt/models/qwen2.py +7 -10
- sglang/srt/models/stablelm.py +13 -12
- sglang/srt/models/yivl.py +1 -4
- sglang/srt/server.py +32 -18
- sglang/srt/server_args.py +9 -6
- sglang/srt/utils.py +126 -17
- sglang/srt/weight_utils.py +66 -51
- sglang/utils.py +77 -26
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/METADATA +9 -5
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.15"
+__version__ = "0.1.16"

 # SGL API Components
 from sglang.api import (
@@ -19,6 +19,7 @@ from sglang.api import (
     user,
     user_begin,
     user_end,
+    video,
 )

 # SGL Backends
@@ -46,6 +47,7 @@ __all__ = [
     "gen_int",
     "gen_string",
     "image",
+    "video",
     "select",
     "system",
     "user",
sglang/api.py
CHANGED
@@ -15,6 +15,7 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglVideo,
 )


@@ -151,6 +152,10 @@ def image(expr: SglExpr):
     return SglImage(expr)


+def video(path: str, num_frames: int):
+    return SglVideo(path, num_frames)
+
+
 def select(
     name: Optional[str] = None,
     choices: List[str] = None,
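The new `video` primitive mirrors `image` but carries a frame count. A minimal sketch of how it composes with the rest of the DSL; the function name, prompt, and path below are illustrative, not part of this release:

    import sglang as sgl

    @sgl.function
    def describe_video(s, video_path):
        # video(path, num_frames) is encoded via encode_video_base64
        # and inserted into the prompt like an image.
        s += sgl.user(sgl.video(video_path, 16) + "Describe this clip.")
        s += sgl.assistant(sgl.gen("answer", max_tokens=128))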
sglang/global_config.py
CHANGED
@@ -16,7 +16,7 @@ class GlobalConfig:

         # Optimization configs
         self.eager_fill_image = False
-        self.
+        self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True

@@ -25,5 +25,8 @@ class GlobalConfig:
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"

+        # Request dependency time due to network delay
+        self.request_dependency_time = 0.03
+

 global_config = GlobalConfig()
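Both additions are plain attributes on the module-level singleton, so they can be tuned before running any program; a sketch:

    from sglang.global_config import global_config

    # Skip tracing-based prefix pre-caching for batched runs.
    global_config.enable_precache_with_tracing = False
    # Allow more slack for network delay between dependent requests (seconds).
    global_config.request_dependency_time = 0.05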
sglang/lang/chat_template.py
CHANGED
@@ -259,6 +259,8 @@ def match_vicuna(model_path: str):
         return get_chat_template("vicuna_v1.1")
     if "llava-v1.5" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
+    if "llava-next-video-7b" in model_path.lower():
+        return get_chat_template("vicuna_v1.1")


 @register_chat_template_matching_function
@@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str):

 @register_chat_template_matching_function
 def match_chat_ml(model_path: str):
+    # import pdb;pdb.set_trace()
     model_path = model_path.lower()
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
-    if
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+    ):
         return get_chat_template("chatml-llava")


 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
-    if "yi" in model_path:
+    if "yi" in model_path and "llava" not in model_path:
         return get_chat_template("yi")

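The effect of these matcher edits can be checked through the module's lookup helper; the model paths below are illustrative:

    from sglang.lang.chat_template import get_chat_template_by_model_path

    # 7B video models now resolve to vicuna_v1.1, 34B variants to chatml-llava,
    # and LLaVA-flavored Yi checkpoints no longer fall into the plain "yi" template.
    print(get_chat_template_by_model_path("lmms-lab/LLaVA-NeXT-Video-7B").name)
    print(get_chat_template_by_model_path("lmms-lab/LLaVA-NeXT-Video-34B").name)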
sglang/lang/interpreter.py
CHANGED
@@ -28,8 +28,9 @@ from sglang.lang.ir import (
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
+    SglVideo,
 )
-from sglang.utils import encode_image_base64
+from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback


 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -86,9 +87,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint

-    #
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)

     # Run all programs
     if num_threads == "auto":
@@ -154,21 +155,12 @@ def run_program_batch(
     return rets


-def pin_program(program, backend):
-
-    # TODO: handle multiple backends
-    from sglang.lang.tracer import extract_prefix_by_tracing
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing

-
-
-
-    program.pin_prefix_rid = prefix_rid
-    return prefix_rid
-    return None
-
-
-def unpin_program(program, backend):
-    pass
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)


 class StreamExecutor:
@@ -195,6 +187,7 @@ class StreamExecutor:
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
+        self.error = None

         # For completion
         self.text_ = ""  # The full text
@@ -310,17 +303,39 @@ class StreamExecutor:
             self.backend.end_program(self)

     def _thread_worker_func(self):
+        error = None
+
         while True:
             expr = self.queue.get()
             if expr is None:
                 self.queue.task_done()
                 break

-            self._execute(expr)
+            try:
+                self._execute(expr)
+            except Exception as e:
+                # print(f"Error in stream_executor: {get_exception_traceback()}")
+                error = e
+                break
             self.queue.task_done()
             if self.stream_text_event:
                 self.stream_text_event.set()

+        # Clean the queue and events
+        if error is not None:
+            try:
+                while True:
+                    self.queue.task_done()
+                    self.queue.get_nowait()
+            except queue.Empty:
+                pass
+            for name in self.variable_event:
+                self.variable_event[name].set()
+            if self.stream_var_event:
+                for name in self.stream_var_event:
+                    self.stream_var_event[name].set()
+            self.error = error
+
         if self.stream_text_event:
             self.stream_text_event.set()

@@ -347,6 +362,8 @@ class StreamExecutor:
             self._execute_role_end(other)
         elif isinstance(other, SglImage):
             self._execute_image(other)
+        elif isinstance(other, SglVideo):
+            self._execute_video(other)
         elif isinstance(other, SglVariable):
             self._execute_variable(other)
         elif isinstance(other, SglVarScopeBegin):
@@ -383,6 +400,16 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token

+    def _execute_video(self, expr: SglVideo):
+        path = expr.path
+        num_frames = expr.num_frames
+
+        base64_data = encode_video_base64(path, num_frames)
+
+        self.images_.append((path, base64_data))
+        self.cur_images.append((path, base64_data))
+        self.text_ += self.chat_template.image_token
+
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)

@@ -681,6 +708,9 @@ class ProgramState:
     def sync(self):
         return self.stream_executor.sync()

+    def error(self):
+        return self.stream_executor.error
+
     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
             prev = 0
@@ -769,6 +799,9 @@ class ProgramState:
     def __setitem__(self, name, value):
         self.set_var(name, value)

+    def __contains__(self, name):
+        return name in self.stream_executor.variables
+
     def __del__(self):
         self.stream_executor.end()

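With the new error plumbing, a failing expression no longer strands consumers blocked on variable events: the worker drains its queue, sets every event, and records the exception, which callers can inspect through ProgramState.error(). A hypothetical calling pattern (my_func is a placeholder program):

    state = my_func.run(question="What is 2 + 2?")
    state.sync()
    if state.error() is not None:   # new in 0.1.16
        raise state.error()
    if "answer" in state:           # __contains__ is also new
        print(state["answer"])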
sglang/lang/ir.py
CHANGED
@@ -193,17 +193,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)

-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program

         backend = backend or global_config.default_backend
-        return pin_program(self, backend)
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
-
-        backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+        return cache_program(self, backend)

     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
@@ -336,6 +330,15 @@ class SglImage(SglExpr):
         return f"SglImage({self.path})"


+class SglVideo(SglExpr):
+    def __init__(self, path, num_frames):
+        self.path = path
+        self.num_frames = num_frames
+
+    def __repr__(self) -> str:
+        return f"SglVideo({self.path}, {self.num_frames})"
+
+
 class SglGen(SglExpr):
     def __init__(
         self,
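pin()/unpin() collapse into a single cache(), which delegates to cache_program and caches the traced prefix only when it is longer than 64 characters. Hypothetical usage against a running server (my_func and questions are placeholders):

    import sglang as sgl

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    my_func.cache()  # replaces my_func.pin(); there is no unpin step anymore
    states = my_func.run_batch([{"question": q} for q in questions])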
sglang/lang/tracer.py
CHANGED
@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
     ##################################

     def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
-        assert
+        assert size >= 1

         if self.only_trace_prefix:
             raise StopTracing()
sglang/launch_server.py
CHANGED
@@ -2,11 +2,10 @@ import argparse

 from sglang.srt.server import ServerArgs, launch_server

-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)

-    launch_server(server_args, None)
+    launch_server(server_args, None)
sglang/launch_server_llavavid.py
ADDED
@@ -0,0 +1,31 @@
+import argparse
+import multiprocessing as mp
+
+from sglang.srt.server import ServerArgs, launch_server
+
+if __name__ == "__main__":
+
+    model_overide_args = {}
+
+    model_overide_args["mm_spatial_pool_stride"] = 2
+    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_overide_args["num_frames"] = 16
+    model_overide_args["model_type"] = "llavavid"
+    if model_overide_args["num_frames"] == 32:
+        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_overide_args["max_sequence_length"] = 4096 * 2
+        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_overide_args["model_max_length"] = 4096 * 2
+
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    if "34b" in args.model_path.lower():
+        model_overide_args["image_token_index"] = 64002
+
+    server_args = ServerArgs.from_cli_args(args)
+
+    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
+    launch_server(server_args, pipe_writer, model_overide_args)
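This entry point wraps launch_server with the LLaVA-NeXT-Video overrides baked in, so it is presumably invoked like the regular launcher; the model path below is illustrative:

    python3 -m sglang.launch_server_llavavid \
        --model-path lmms-lab/LLaVA-NeXT-Video-7B --port 30000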
sglang/srt/flush_cache.py
ADDED
@@ -0,0 +1,16 @@
+"""
+Usage:
+python3 -m sglang.srt.flush_cache --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    args = parser.parse_args()
+
+    response = requests.get(args.url + "/flush_cache")
+    assert response.status_code == 200
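The script is a thin client for the server's /flush_cache endpoint, so the same effect is available from any HTTP client, for example:

    import requests

    # Assumes a server on the default port; raises on a non-2xx response.
    requests.get("http://localhost:30000/flush_cache").raise_for_status()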
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -30,10 +30,17 @@ def get_config_json(model_path: str):
     return config


-def get_config(
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    model_overide_args: Optional[dict] = None,
+):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if model_overide_args:
+        config.update(model_overide_args)
     return config

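model_overide_args (spelled as in the source) is merged into the Hugging Face config via config.update, which is how launch_server_llavavid swaps in the LlavaVid architecture. A sketch with an illustrative model path:

    from sglang.srt.hf_transformers_utils import get_config

    config = get_config(
        "lmms-lab/LLaVA-NeXT-Video-7B",
        trust_remote_code=True,
        model_overide_args={"model_type": "llavavid", "num_frames": 16},
    )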
sglang/srt/managers/io_struct.py
CHANGED
@@ -8,7 +8,9 @@ from sglang.srt.sampling_params import SamplingParams
 @dataclass
 class GenerateReqInput:
     # The input prompt
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
@@ -28,7 +30,17 @@ class GenerateReqInput:
     # TODO: make all parameters a Union[List[T], T] to allow for batched requests

     def post_init(self):
-        is_single = isinstance(self.text, str)
+
+        if self.text is None:
+            assert self.input_ids is not None, "Either text or input_ids should be provided"
+        else:
+            assert self.input_ids is None, "Either text or input_ids should be provided"
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single

         if is_single:
             if self.sampling_params is None:
@@ -42,7 +54,7 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            num = len(self.text)
+            num = len(self.text) if self.text is not None else len(self.input_ids)

             if self.image_data is None:
                 self.image_data = [None] * num
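With input_ids accepted, /generate requests can skip server-side tokenization; exactly one of text or input_ids must be set. An illustrative request (the token ids are placeholders):

    import requests

    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "input_ids": [1, 3087, 1139],  # pre-tokenized prompt
            "sampling_params": {"max_new_tokens": 32},
        },
    )
    print(resp.json())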
sglang/srt/managers/router/infer_batch.py
CHANGED
@@ -20,6 +20,17 @@ class FinishReason(IntEnum):
     LENGTH = auto()
     STOP_STR = auto()

+    @staticmethod
+    def to_str(reason):
+        if reason == FinishReason.EOS_TOKEN:
+            return None
+        elif reason == FinishReason.LENGTH:
+            return "length"
+        elif reason == FinishReason.STOP_STR:
+            return "stop"
+        else:
+            return None
+

 class Req:
     def __init__(self, rid, input_text, input_ids):
@@ -85,6 +96,9 @@ class Req:
         )
         if first_token.startswith("▁"):
             old_output_str = " " + old_output_str
+        if self.input_text is None:
+            # TODO(lmzheng): This can be wrong. Check with Liangsheng.
+            self.input_text = self.tokenizer.decode(self.input_ids)
         new_input_string = (
             self.input_text
             + self.output_and_jump_forward_str
@@ -332,20 +346,20 @@ class Batch:
             req = self.reqs[idx]
             retracted_reqs.append(req)

-
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = len(req.prefix_indices)
+            token_indices = self.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[idx]
+            ][last_uncached_pos : seq_lens_cpu[idx]]
+            self.token_to_kv_pool.dec_refs(token_indices)
+
+            self.tree_cache.dec_lock_ref(req.last_node)
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
             req.regex_fsm_state = 0

-            # TODO: apply more fine-grained retraction
-
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[idx]
-            ][: seq_lens_cpu[idx]]
-            self.token_to_kv_pool.dec_refs(token_indices)
-
         self.filter_batch(sorted_indices)

         return retracted_reqs
@@ -364,20 +378,18 @@ class Batch:
             if len(jump_forward_str) <= 1:
                 continue

-            # insert the old request into tree_cache
-            token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
             if req_pool_indices_cpu is None:
                 req_pool_indices_cpu = self.req_pool_indices.tolist()
-
-
-
-
-
+
+            # insert the old request into tree_cache
+            self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
             )
-
-
-            self.tree_cache.
+
+            # unlock the last node
+            self.tree_cache.dec_lock_ref(req.last_node)

             # jump-forward
             req.jump_forward_and_retokenize(jump_forward_str, next_state)
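FinishReason.to_str maps the internal finish states onto the finish_reason strings familiar from OpenAI-style responses, with EOS reported as None:

    from sglang.srt.managers.router.infer_batch import FinishReason

    assert FinishReason.to_str(FinishReason.LENGTH) == "length"
    assert FinishReason.to_str(FinishReason.STOP_STR) == "stop"
    assert FinishReason.to_str(FinishReason.EOS_TOKEN) is None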
sglang/srt/managers/router/manager.py
CHANGED
@@ -5,7 +5,7 @@ import uvloop
 import zmq
 import zmq.asyncio

-from sglang.
+from sglang.global_config import global_config
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_exception_traceback
@@ -30,7 +30,7 @@ class RouterManager:
         self.recv_reqs = []

         # Init some configs
-        self.
+        self.request_dependency_time = global_config.request_dependency_time

     async def loop_for_forward(self):
         while True:
@@ -46,9 +46,9 @@ class RouterManager:
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
-                    if self.
+                    if self.request_dependency_time > 0:
                         slept = True
-                        await asyncio.sleep(self.
+                        await asyncio.sleep(self.request_dependency_time)

             if not slept:
                 await asyncio.sleep(0.0006)
@@ -60,9 +60,7 @@ class RouterManager:


 def start_router_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -70,7 +68,7 @@ def start_router_process(
     )

     try:
-        model_client = ModelRpcClient(server_args, port_args)
+        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
         router = RouterManager(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
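start_router_process now takes model_overide_args positionally, so any caller spawning the router must thread it through. A sketch of the presumed call site in server.py; the variable names here are assumed from context, not quoted from this diff:

    proc_router = multiprocessing.Process(
        target=start_router_process,
        args=(server_args, port_args, pipe_router_writer, model_overide_args),
    )
    proc_router.start()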