sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,12 +1,17 @@
+"""DetokenizerManager is a process that detokenizes the token ids."""
+
 import asyncio
+import inspect

 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.
+from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -32,48 +37,43 @@ class DetokenizerManager:

     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_router.recv_pyobj()
+            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            assert isinstance(recv_obj, BatchTokenIDOut)

-
-
+            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )

-
-
-
-
-                )
+            # Trim stop str
+            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                if recv_obj.finished_reason[i] is None:
+                    new_text = find_printable_text(new_text)
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)

-
-
-
-
-
-
-
-
-
-
-
-                        )
-                        if not isinstance(first_token, str):
-                            first_token = first_token.decode("utf-8", errors="ignore")
-                        if first_token.startswith("▁"):
-                            output_strs[i] = " " + output_strs[i]
-
-                    output_strs[i] = (
-                        recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
-                    )
-
-                self.send_to_tokenizer.send_pyobj(
-                    BatchStrOut(
-                        recv_obj.rids,
-                        output_strs,
-                        recv_obj.meta_info,
-                        recv_obj.finished,
-                    )
+                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
+                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
+
+            self.send_to_tokenizer.send_pyobj(
+                BatchStrOut(
+                    rids=recv_obj.rids,
+                    output_strs=output_strs,
+                    meta_info=recv_obj.meta_info,
+                    finished_reason=recv_obj.finished_reason,
                 )
-
-                raise ValueError(f"Invalid object: {recv_obj}")
+            )


 def start_detokenizer_process(
@@ -81,9 +81,11 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    graceful_registry(inspect.currentframe().f_code.co_name)
+
     try:
         manager = DetokenizerManager(server_args, port_args)
-    except Exception
+    except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")
sglang/srt/managers/io_struct.py
CHANGED
@@ -1,14 +1,22 @@
+"""
+The definition of objects transfered between different
+processes (TokenizerManager, DetokenizerManager, Controller).
+"""
+
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

+from sglang.srt.managers.controller.infer_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams


 @dataclass
 class GenerateReqInput:
     # The input prompt
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
@@ -19,13 +27,24 @@ class GenerateReqInput:
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # The number of top logprobs to return
+    top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in logprobs
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False

     def post_init(self):
-
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single

         if is_single:
             if self.sampling_params is None:
@@ -36,8 +55,10 @@ class GenerateReqInput:
                 self.return_logprob = False
             if self.logprob_start_len is None:
                 self.logprob_start_len = 0
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = 0
         else:
-            num = len(self.text)
+            num = len(self.text) if self.text is not None else len(self.input_ids)

             if self.image_data is None:
                 self.image_data = [None] * num
@@ -52,7 +73,8 @@ class GenerateReqInput:
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
         else:
-
+            if not isinstance(self.rid, list):
+                raise ValueError("The rid should be a list.")

         if self.return_logprob is None:
             self.return_logprob = [False] * num
@@ -64,6 +86,11 @@ class GenerateReqInput:
         elif not isinstance(self.logprob_start_len, list):
             self.logprob_start_len = [self.logprob_start_len] * num

+        if self.top_logprobs_num is None:
+            self.top_logprobs_num = [0] * num
+        elif not isinstance(self.top_logprobs_num, list):
+            self.top_logprobs_num = [self.top_logprobs_num] * num
+

 @dataclass
 class TokenizedGenerateReqInput:
@@ -76,26 +103,28 @@ class TokenizedGenerateReqInput:
     sampling_params: SamplingParams
     return_logprob: bool
     logprob_start_len: int
+    top_logprobs_num: int
     stream: bool


 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-
-
-
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
-
+    finished_reason: List[BaseFinishReason]


 @dataclass
 class BatchStrOut:
     rids: List[str]
-
+    output_strs: List[str]
     meta_info: List[Dict]
-
+    finished_reason: List[BaseFinishReason]


 @dataclass
@@ -103,6 +132,11 @@ class FlushCacheReq:
     pass


+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -1,15 +1,20 @@
+"""TokenizerManager is a process that tokenizes the text."""
+
 import asyncio
 import concurrent.futures
 import dataclasses
+import logging
 import multiprocessing as mp
 import os
-from typing import List
+from typing import Dict, List

 import numpy as np
 import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+from fastapi import BackgroundTasks
+
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,
@@ -17,8 +22,9 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
-
+    BatchTokenIDOut,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
@@ -26,54 +32,19 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import is_multimodal_model, load_image
+from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
-    lock: asyncio.Lock
-
-
-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image = load_image(image_data)
-        image_hash = hash(image_data)
-        if image_aspect_ratio == "pad":
-            image = expand2square(
-                image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-            )
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        elif image_aspect_ratio == "anyres":
-            pixel_values = process_anyres_image(
-                image, processor.image_processor, image_grid_pinpoints
-            )
-        else:
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        pixel_values = pixel_values.astype(np.float16)
-        return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())


 class TokenizerManager:
@@ -81,6 +52,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args

@@ -93,9 +65,10 @@ class TokenizerManager:

         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path,
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
-
         self.context_len = get_context_length(self.hf_config)

         if is_multimodal_model(self.model_path):
@@ -119,7 +92,7 @@ class TokenizerManager:
         )

         self.to_create_loop = True
-        self.rid_to_state
+        self.rid_to_state: Dict[str, ReqState] = {}

     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -140,15 +113,26 @@ class TokenizerManager:
             image_data, aspect_ratio, grid_pinpoints, self.processor
         )

-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-
-
-        is_single = isinstance(obj.text, str)
+            self.create_handle_loop()

+        obj.post_init()
+        is_single = obj.is_single
         if is_single:
             rid = obj.rid
-
+
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(obj.text)
+            else:
+                input_ids = obj.input_ids
+
+            if len(input_ids) >= self.context_len:
+                raise ValueError(
+                    f"The input ({len(input_ids)} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
+                )
+
             sampling_params = SamplingParams(**obj.sampling_params)
             if sampling_params.max_new_tokens != 0:
                 sampling_params.normalize(self.tokenizer)
@@ -174,29 +158,64 @@ class TokenizerManager:
                 sampling_params=sampling_params,
                 return_logprob=obj.return_logprob,
                 logprob_start_len=obj.logprob_start_len,
+                top_logprobs_num=obj.top_logprobs_num,
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)

-            lock = asyncio.Lock()
             event = asyncio.Event()
-            state = ReqState([], False, event
+            state = ReqState([], False, event)
             self.rid_to_state[rid] = state

             while True:
-
-
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=4)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )
+
+                if self.server_args.log_requests and state.finished:
+                    logger.info(f"in={obj.text}, out={out}")
+
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
+
+                    yield out
+
                     break
+
                 event.clear()
+
+                yield out
         else:
-
-
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")
+
+            if obj.input_ids is None:
+                bs = len(obj.text)
+            else:
+                bs = len(obj.input_ids)
+
             for i in range(bs):
                 rid = obj.rid[i]
-
+
+                if obj.input_ids is None:
+                    input_text = obj.text[i]
+                    input_ids = self.tokenizer.encode(obj.text[i])
+                else:
+                    input_text = None
+                    input_ids = obj.input_ids[i]
+
                 sampling_params = SamplingParams(**obj.sampling_params[i])
                 if sampling_params.max_new_tokens != 0:
                     sampling_params.normalize(self.tokenizer)
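Annotation (not part of the diff): the 4-second timeout on asyncio.wait_for turns the wait on the request's event into a periodic checkpoint at which a disconnected HTTP client can be noticed and the request aborted. A self-contained sketch of the same pattern in plain asyncio (never_disconnected stands in for FastAPI's request.is_disconnected):

    import asyncio

    async def never_disconnected() -> bool:
        return False

    async def wait_with_abort(event, is_disconnected, timeout=4.0):
        # Wait for the producer to signal output; on every timeout, check
        # whether the client went away before going back to waiting.
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=timeout)
                return
            except asyncio.TimeoutError:
                if await is_disconnected():
                    raise ValueError("Abort request")

    async def demo():
        event = asyncio.Event()
        asyncio.get_running_loop().call_later(0.5, event.set)  # producer side
        await wait_with_abort(event, never_disconnected, timeout=0.2)
        print("got output without aborting")

    asyncio.run(demo())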
@@ -209,7 +228,7 @@ class TokenizerManager:
                 )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
-                    input_text=
+                    input_text=input_text,
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
@@ -217,53 +236,176 @@ class TokenizerManager:
                     sampling_params=sampling_params,
                     return_logprob=obj.return_logprob[i],
                     logprob_start_len=obj.logprob_start_len[i],
+                    top_logprobs_num=obj.top_logprobs_num[i],
                     stream=obj.stream,
                 )
                 self.send_to_router.send_pyobj(tokenized_obj)

-                lock = asyncio.Lock()
                 event = asyncio.Event()
-                state = ReqState([], False, event
+                state = ReqState([], False, event)
                 self.rid_to_state[rid] = state

             output_list = []
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-
-
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=4)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
+                output_list.append(
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[i],
+                        obj.top_logprobs_num[i],
+                        obj.return_text_in_logprobs,
+                    )
+                )
                 assert state.finished
                 del self.rid_to_state[rid]

             yield output_list

-
-
-
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)
+
+    def abort_request(self, rid):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_abort_task(self, obj: GenerateReqInput):
+        # Abort the request if the client is disconnected.
+        async def abort_request():
+            await asyncio.sleep(3)
+            if obj.is_single:
+                self.abort_request(obj.rid)
+            else:
+                for rid in obj.rids:
+                    self.abort_request(rid)

-
-
-
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(abort_request)
+        return background_tasks

-
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())

     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-
-
-
-
-
-
-
-
-
-
-
-
+            recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
+            assert isinstance(recv_obj, BatchStrOut)
+
+            for i, rid in enumerate(recv_obj.rids):
+                state = self.rid_to_state.get(rid, None)
+                if state is None:
+                    continue
+
+                recv_obj.meta_info[i]["id"] = rid
+                out_dict = {
+                    "text": recv_obj.output_strs[i],
+                    "meta_info": recv_obj.meta_info[i],
+                }
+                state.out_list.append(out_dict)
+                state.finished = recv_obj.finished_reason[i] is not None
+                state.event.set()
+
+    def convert_logprob_style(
+        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+    ):
+        if return_logprob:
+            ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+            )
+            ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+            )
+
+            if top_logprobs_num > 0:
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
+        return ret
+
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+        if not decode_to_text:
+            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
+
+        token_ids = [tid for _, tid in token_logprobs]
+        token_texts = self.tokenizer.batch_decode(token_ids)
+        return [
+            (logprob, token_id, token_text)
+            for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+        ]
+
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+        for i, t in enumerate(top_logprobs):
+            if t:
+                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+        return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size is not None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
                )
             else:
-
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())