sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/managers/detokenizer_manager.py CHANGED

```diff
@@ -1,13 +1,17 @@
+"""DetokenizerManager is a process that detokenizes the token ids."""
+
 import asyncio
+import inspect
 
 import uvloop
 import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.
+from sglang.utils import get_exception_traceback, graceful_registry
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -33,51 +37,41 @@ class DetokenizerManager:
 
     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_router.recv_pyobj()
-            [32 removed lines not captured]
-                )
-
-                self.send_to_tokenizer.send_pyobj(
-                    BatchStrOut(
-                        recv_obj.rids,
-                        output_strs,
-                        recv_obj.meta_info,
-                        recv_obj.finished,
-                    )
+            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            assert isinstance(recv_obj, BatchTokenIDOut)
+
+            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+
+            # Trim stop str
+            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)
+
+                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
+                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
+
+            self.send_to_tokenizer.send_pyobj(
+                BatchStrOut(
+                    rids=recv_obj.rids,
+                    output_str=output_strs,
+                    meta_info=recv_obj.meta_info,
+                    finished_reason=recv_obj.finished_reason,
                 )
-
-            raise ValueError(f"Invalid object: {recv_obj}")
+            )
 
 
 def start_detokenizer_process(
@@ -85,9 +79,11 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    graceful_registry(inspect.currentframe().f_code.co_name)
+
     try:
         manager = DetokenizerManager(server_args, port_args)
-    except Exception
+    except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")
```
sglang/srt/managers/io_struct.py CHANGED

```diff
@@ -1,7 +1,13 @@
+"""
+The definition of objects transfered between different
+processes (TokenizerManager, DetokenizerManager, Controller).
+"""
+
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
+from sglang.srt.managers.controller.infer_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
 
@@ -27,14 +33,12 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
-    # TODO: make all parameters a Union[List[T], T] to allow for batched requests
 
     def post_init(self):
-        [4 removed lines not captured]
-        assert self.input_ids is None, "Either text or input_ids should be provided"
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")
 
         if self.text is not None:
             is_single = isinstance(self.text, str)
@@ -69,7 +73,8 @@ class GenerateReqInput:
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                [1 removed line not captured]
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
@@ -105,13 +110,13 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-    [3 removed lines not captured]
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
-    [1 removed line not captured]
+    finished_reason: List[BaseFinishReason]
 
 
 @dataclass
@@ -119,7 +124,7 @@ class BatchStrOut:
     rids: List[str]
     output_str: List[str]
     meta_info: List[Dict]
-    [1 removed line not captured]
+    finished_reason: List[BaseFinishReason]
 
 
 @dataclass
@@ -127,6 +132,11 @@ class FlushCacheReq:
     pass
 
 
+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
```
sglang/srt/managers/tokenizer_manager.py CHANGED

```diff
@@ -1,16 +1,19 @@
+"""TokenizerManager is a process that tokenizes the text."""
+
 import asyncio
 import concurrent.futures
 import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import List
+from typing import Dict, List
 
 import numpy as np
 import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+from fastapi import BackgroundTasks
 
 from sglang.srt.hf_transformers_utils import (
     get_config,
@@ -19,8 +22,9 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
-    [1 removed line not captured]
+    BatchTokenIDOut,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
@@ -28,7 +32,8 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import is_multimodal_model, load_image
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -42,51 +47,6 @@ class ReqState:
     event: asyncio.Event
 
 
-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size != None:
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-                )
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            elif image_aspect_ratio == "anyres":
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
-            return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
 class TokenizerManager:
     def __init__(
         self,
@@ -132,7 +92,7 @@ class TokenizerManager:
         )
 
         self.to_create_loop = True
-        self.rid_to_state
+        self.rid_to_state: Dict[str, ReqState] = {}
 
     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -153,10 +113,11 @@ class TokenizerManager:
             image_data, aspect_ratio, grid_pinpoints, self.processor
         )
 
-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-            [1 removed line not captured]
+            self.create_handle_loop()
 
+        obj.post_init()
         is_single = obj.is_single
         if is_single:
             rid = obj.rid
@@ -169,7 +130,7 @@ class TokenizerManager:
             if len(input_ids) >= self.context_len:
                 raise ValueError(
                     f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)"
+                    f"model's context length ({self.context_len} tokens)."
                 )
 
             sampling_params = SamplingParams(**obj.sampling_params)
@@ -207,23 +168,38 @@ class TokenizerManager:
             self.rid_to_state[rid] = state
 
             while True:
-                [5 removed lines not captured]
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=4)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )
 
                 if self.server_args.log_requests and state.finished:
                     logger.info(f"in={obj.text}, out={out}")
 
-                yield out
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
+
+                    yield out
+
                     break
+
                 event.clear()
+
+                yield out
         else:
-            [1 removed line not captured]
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")
 
             if obj.input_ids is None:
                 bs = len(obj.text)
@@ -273,45 +249,83 @@ class TokenizerManager:
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-                [1 removed line not captured]
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=4)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
                 output_list.append(
-                    self.convert_logprob_style(
-                    [3 removed lines not captured]
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[i],
+                        obj.top_logprobs_num[i],
+                        obj.return_text_in_logprobs,
+                    )
+                )
                 assert state.finished
                 del self.rid_to_state[rid]
 
             yield output_list
 
-            [2 removed lines not captured]
-            self.send_to_router.send_pyobj(
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)
+
+    def abort_request(self, rid):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_abort_task(self, obj: GenerateReqInput):
+        # Abort the request if the client is disconnected.
+        async def abort_request():
+            await asyncio.sleep(3)
+            if obj.is_single:
+                self.abort_request(obj.rid)
+            else:
+                for rid in obj.rids:
+                    self.abort_request(rid)
 
-        [1 removed line not captured]
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(abort_request)
+        return background_tasks
+
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
 
     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-            [16 removed lines not captured]
+            recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
+            assert isinstance(recv_obj, BatchStrOut)
+
+            for i, rid in enumerate(recv_obj.rids):
+                state = self.rid_to_state.get(rid, None)
+                if state is None:
+                    continue
+
+                recv_obj.meta_info[i]["id"] = rid
+                out_dict = {
+                    "text": recv_obj.output_str[i],
+                    "meta_info": recv_obj.meta_info[i],
+                }
+                state.out_list.append(out_dict)
+                state.finished = recv_obj.finished_reason[i] is not None
+                state.event.set()
+
+    def convert_logprob_style(
+        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+    ):
         if return_logprob:
             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
@@ -320,10 +334,14 @@ class TokenizerManager:
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
             if top_logprobs_num > 0:
-                ret["meta_info"][
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
                     ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
                 )
-                ret["meta_info"][
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
                     ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
                 )
         return ret
@@ -344,3 +362,49 @@ class TokenizerManager:
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
         return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
```
sglang/srt/model_config.py CHANGED

```diff
@@ -1,5 +1,7 @@
 from typing import Optional
 
+from transformers import PretrainedConfig
+
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
 
@@ -15,11 +17,14 @@ class ModelConfig:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.
-        [4 removed lines not captured]
+        self.model_overide_args = model_overide_args
+        self.hf_config = get_config(
+            self.path,
+            trust_remote_code,
+            revision,
+            model_overide_args=model_overide_args,
+        )
+        self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
             self.context_len = context_length
         else:
@@ -45,3 +50,76 @@ class ModelConfig:
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
         self.vocab_size = self.hf_config.vocab_size
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
+        # For GPTBigCode & Falcon:
+        # NOTE: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
+        new_decoder_arch_falcon = (
+            self.hf_config.model_type in falcon_model_types
+            and getattr(self.hf_config, "new_decoder_architecture", False)
+        )
+        if not new_decoder_arch_falcon and getattr(
+            self.hf_text_config, "multi_query", False
+        ):
+            # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
+            return 1
+
+        # For DBRX and MPT
+        if self.hf_config.model_type in ["mpt"]:
+            if "kv_n_heads" in self.hf_config.attn_config:
+                return self.hf_config.attn_config["kv_n_heads"]
+            return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
+            return getattr(
+                self.hf_config.attn_config,
+                "kv_n_heads",
+                self.hf_config.num_attention_heads,
+            )
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_text_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_text_config.num_attention_heads
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
+    def get_num_kv_heads(self, tensor_parallel_size) -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1, total_num_kv_heads // tensor_parallel_size)
+
+
+def get_hf_text_config(config: PretrainedConfig):
+    """Get the "sub" config relevant to llm for multi modal models.
+    No op for pure text models.
+    """
+    if hasattr(config, "text_config"):
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(config.text_config, "num_attention_heads")
+        return config.text_config
+    else:
+        return config
```