sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -17,21 +17,56 @@ limitations under the License.
 # and "Punica: Multi-Tenant LoRA Serving"


+import logging
 import re
-from dataclasses import dataclass

 import torch

 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
-from sglang.srt.model_executor.forward_batch_info import
-from sglang.srt.utils import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_flashinfer_available, replace_submodule

-
-
+logger = logging.getLogger(__name__)
+
+if is_flashinfer_available():
     from flashinfer import SegmentGEMMWrapper


+def get_module_name(name):
+    # Fallback solution of mapping from config module name to module name in model class.
+    # Please check if it aligns with your base model.
+    # Please implement the function in the model class if it is not.
+    # You can reference this function in llama.py.
+    params_mapping = {
+        "q_proj": "qkv_proj",
+        "k_proj": "qkv_proj",
+        "v_proj": "qkv_proj",
+        "gate_proj": "gate_up_proj",
+        "up_proj": "gate_up_proj",
+    }
+    return params_mapping.get(name, name)
+
+
+def get_hidden_dim(module_name, config):
+    # Fallback solution of get_hidden_dim for different modules
+    # Please check if it aligns with your base model.
+    # Please implement the function in the model class if it is not.
+    # You can reference this function in llama.py.
+    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+        return config.hidden_size, config.hidden_size
+    elif module_name in ["kv_proj"]:
+        return config.hidden_size, config.hidden_size // (
+            config.num_attention_heads // config.num_key_value_heads
+        )
+    elif module_name == "gate_up_proj":
+        return config.hidden_size, config.intermediate_size
+    elif module_name == "down_proj":
+        return config.intermediate_size, config.hidden_size
+    else:
+        raise NotImplementedError()
+
+
 def get_stacked_name(name):
     # origin name -> (name for A, name for B)
     params_mapping = {
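The two fallbacks above are only consulted when the base model does not define the hooks itself; the comments point at llama.py as the reference. Below is a minimal, hypothetical sketch of what such hooks could look like on a model class, assuming a Hugging Face style config with hidden_size, intermediate_size, num_attention_heads, and num_key_value_heads. It mirrors the fallback logic and is not copied from any model file in the package.

```python
# Hypothetical model class exposing the hooks that LoRAManager probes
# with hasattr(); a sketch, not the package's implementation.
class MyLlamaStyleModel:
    def __init__(self, config):
        self.config = config  # Hugging Face style config object

    def get_module_name(self, name: str) -> str:
        # Config-level projection names -> fused module names in the model impl.
        return {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }.get(name, name)

    def get_hidden_dim(self, module_name: str):
        # (input_dim, output_dim) used by LoRAManager to size the A/B buffers.
        cfg = self.config
        if module_name in ("q_proj", "o_proj", "qkv_proj"):
            return cfg.hidden_size, cfg.hidden_size
        if module_name == "kv_proj":
            return cfg.hidden_size, cfg.hidden_size // (
                cfg.num_attention_heads // cfg.num_key_value_heads
            )
        if module_name == "gate_up_proj":
            return cfg.hidden_size, cfg.intermediate_size
        if module_name == "down_proj":
            return cfg.intermediate_size, cfg.hidden_size
        raise NotImplementedError(module_name)
```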
@@ -104,12 +139,20 @@ class LoRAManager:
         self.origin_target_modules = set(self.origin_target_modules) | set(
             self.configs[name].target_modules
         )
-        self.
-
+        if hasattr(self.base_model, "get_module_name"):
+            self.target_modules = {
                 self.base_model.get_module_name(module)
                 for module in self.origin_target_modules
-
-
+            }
+        else:
+            logger.warning(
+                f"WARNING: get_module_name() is not defined, "
+                f"which is used to map config module name to model implementation module name."
+                f"Use the default one, but please check if it is correct for your model."
+            )
+            self.target_modules = {
+                get_module_name(module) for module in self.origin_target_modules
+            }
         self.target_weights = set(
             [get_stacked_name(module) for module in self.origin_target_modules]
         )
@@ -147,7 +190,15 @@ class LoRAManager:
         num_layer = self.base_hf_config.num_hidden_layers
         for module_A, module_B in self.target_weights:
             # init A tensor, column_major=True
-
+            if hasattr(self.base_model, "get_hidden_dim"):
+                hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
+            else:
+                logger.warning(
+                    f"WARNING: get_hidden_dim() is not defined, "
+                    f"which is used to get the hidden dim for different lora modules"
+                    f"Use the default one, but please check if it is correct for your model."
+                )
+                hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
             c = self.loras[-1].get_stacked_multiply(module_A)
             if module_A not in self.A_buffer:
                 self.A_buffer[module_A] = [
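To make the buffer sizing concrete, here is a small worked example with made-up GQA dimensions (not taken from the diff): with hidden_size=4096, 32 attention heads, and 8 KV heads, the fallback get_hidden_dim sizes kv_proj as 4096 in and 4096 // (32 // 8) = 1024 out.

```python
# Hypothetical config values, for illustration only.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8

# kv_proj output width shrinks by the GQA group factor (32 // 8 = 4).
kv_out = hidden_size // (num_attention_heads // num_key_value_heads)
assert kv_out == 1024  # so a kv_proj LoRA maps 4096 -> 1024 in this example
```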
@@ -163,7 +214,15 @@
                 for i in range(num_layer)
             ]
             # init B tensor, column_major=True
-
+            if hasattr(self.base_model, "get_hidden_dim"):
+                _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
+            else:
+                logger.warning(
+                    f"WARNING: get_hidden_dim() is not defined, "
+                    f"which is used to get the hidden dim for different lora modules"
+                    f"Use the default one, but please check if it is correct for your model."
+                )
+                _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
             c = self.loras[-1].get_stacked_multiply(module_B)
             if module_B not in self.B_buffer:
                 self.B_buffer[module_B] = [
@@ -208,33 +267,46 @@ class LoRAManager:
             if lora_weight_name:
                 self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

-    def prepare_lora_batch(self,
+    def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
-        cur_uids = set(
+        cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
+        j = len(self.active_uids)
         evictable_uids = list(self.active_uids)
         for uid in cur_uids:
             if uid not in self.active_uids:
-
-
-
+                if j < self.max_loras_per_batch:
+                    index = j
+                    j += 1
+                else:
+                    while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
+                        i += 1
+                    assert i < len(evictable_uids)
                     self.active_uids.remove(evictable_uids[i])
                     self.buffer_id.pop(evictable_uids[i])
-
+                    index = i
+                    i += 1
+                self.load_lora(uid, index)
                 self.active_uids.add(uid)
-                self.buffer_id[uid] =
-                i += 1
+                self.buffer_id[uid] = index

         if cur_uids == set([None]):
             return

         # setup lora in forward modules
-        bs =
-        seg_lens =
+        bs = forward_batch.batch_size
+        seg_lens = (
+            forward_batch.extend_seq_lens
+            if forward_batch.forward_mode.is_extend()
+            else torch.ones(bs, device="cuda")
+        )
+        # FIXME: reuse the data rather than recompute
+        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
-        for i,
-            weight_indices[i] = self.buffer_id[
+        for i, lora_path in enumerate(forward_batch.lora_paths):
+            weight_indices[i] = self.buffer_id[lora_path]

         for module_name, module in self.lora_modules:
             layer_id = get_layer_id(module_name)
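The rewritten prepare_lora_batch derives everything from the ForwardBatch: per-request segment lengths become a cumulative offset array (seg_indptr) for the segment GEMM, and each request is tagged with the buffer slot holding its LoRA weights (weight_indices). A minimal CPU-side sketch of that indexing, with made-up lengths and slot ids:

```python
import torch

# Hypothetical batch: three extend-mode requests with these prompt lengths,
# mapped to LoRA buffer slots 0, 1, and 0 respectively.
seg_lens = torch.tensor([5, 2, 7], dtype=torch.int32)
buffer_slots = [0, 1, 0]

bs = seg_lens.numel()
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
weight_indices = torch.tensor(buffer_slots, dtype=torch.int64)

print(seg_indptr)      # values [0, 5, 7, 14]: token range boundaries per request
print(weight_indices)  # values [0, 1, 0]: which buffered LoRA each segment uses
```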
@@ -245,7 +317,7 @@ class LoRAManager:
                     self.A_buffer[weight_name][layer_id],
                     self.B_buffer[weight_name][layer_id],
                     bs,
-
+                    seg_indptr,
                     weight_indices,
                 )
             else:
@@ -254,6 +326,6 @@ class LoRAManager:
                     self.B_buffer["q_proj"][layer_id],
                     self.B_buffer["kv_proj"][layer_id],
                     bs,
-
+                    seg_indptr,
                     weight_indices,
                 )
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -15,13 +15,12 @@ limitations under the License.

 """DetokenizerManager is a process that detokenizes the token ids."""

-import asyncio
 import dataclasses
+import logging
+from collections import OrderedDict
 from typing import List

-import uvloop
 import zmq
-import zmq.asyncio

 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
@@ -32,9 +31,10 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback

-
+logger = logging.getLogger(__name__)


 @dataclasses.dataclass
@@ -57,12 +57,12 @@ class DetokenizerManager:
         port_args: PortArgs,
     ):
         # Init inter-process communication
-        context = zmq.
-        self.
-        self.
+        context = zmq.Context(2)
+        self.recv_from_scheduler = context.socket(zmq.PULL)
+        self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")

         self.send_to_tokenizer = context.socket(zmq.PUSH)
-        self.send_to_tokenizer.connect(f"
+        self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")

         if server_args.skip_tokenizer_init:
             self.tokenizer = None
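The detokenizer now wires itself to the scheduler and tokenizer with plain ZeroMQ PULL/PUSH sockets over ipc:// endpoints instead of the old asyncio-based loop. A minimal standalone sketch of the same pattern in pyzmq; the endpoint paths below are placeholders, while the real ones come from PortArgs:

```python
import zmq

# Placeholder endpoints; in sglang these come from PortArgs.
DETOKENIZER_IPC = "ipc:///tmp/demo_detokenizer"
TOKENIZER_IPC = "ipc:///tmp/demo_tokenizer"

context = zmq.Context(2)  # two I/O threads, as in the diff

# Upstream process pushes Python objects to us; we bind and pull.
recv_from_scheduler = context.socket(zmq.PULL)
recv_from_scheduler.bind(DETOKENIZER_IPC)

# We push results downstream; the peer binds, we connect.
send_to_tokenizer = context.socket(zmq.PUSH)
send_to_tokenizer.connect(TOKENIZER_IPC)

# A blocking loop like event_loop() would then look like:
#   obj = recv_from_scheduler.recv_pyobj()
#   ... detokenize ...
#   send_to_tokenizer.send_pyobj(result)
```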
@@ -73,13 +73,13 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )

-        self.decode_status =
+        self.decode_status = LimitedCapacityDict()

-
+    def event_loop(self):
         """The event loop that handles requests"""

         while True:
-            recv_obj =
+            recv_obj = self.recv_from_scheduler.recv_pyobj()

             if isinstance(recv_obj, BatchEmbeddingOut):
                 # If it is embedding model, no detokenization is needed.
@@ -170,16 +170,29 @@ class DetokenizerManager:
             )


-
+class LimitedCapacityDict(OrderedDict):
+    def __init__(self, capacity=1 << 15, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.capacity = capacity
+
+    def __setitem__(self, key, value):
+        if len(self) >= self.capacity:
+            # Remove the oldest element (first item in the dict)
+            self.popitem(last=False)
+        # Set the new item
+        super().__setitem__(key, value)
+
+
+def run_detokenizer_process(
     server_args: ServerArgs,
     port_args: PortArgs,
-    pipe_writer,
 ):
+    configure_logger(server_args)
+
     try:
         manager = DetokenizerManager(server_args, port_args)
+        manager.event_loop()
     except Exception:
-
-
-
-        loop = asyncio.get_event_loop()
-        loop.run_until_complete(manager.handle_loop())
+        msg = get_exception_traceback()
+        logger.error(msg)
+        kill_parent_process()
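LimitedCapacityDict is a FIFO-bounded OrderedDict that caps decode_status so per-request decode state cannot grow without bound (it defaults to 1 << 15 entries). A quick usage sketch with a deliberately tiny capacity:

```python
from collections import OrderedDict


class LimitedCapacityDict(OrderedDict):
    """Same behavior as the class added in the diff above."""

    def __init__(self, capacity=1 << 15, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.capacity = capacity

    def __setitem__(self, key, value):
        if len(self) >= self.capacity:
            self.popitem(last=False)  # evict the oldest (first-inserted) entry
        super().__setitem__(key, value)


d = LimitedCapacityDict(capacity=2)
d["rid-1"] = "decode state 1"
d["rid-2"] = "decode state 2"
d["rid-3"] = "decode state 3"  # "rid-1" is evicted here
print(list(d))                 # ['rid-2', 'rid-3']
```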
sglang/srt/managers/image_processor.py
ADDED
@@ -0,0 +1,187 @@
+# TODO: also move pad_input_ids into this module
+import asyncio
+import concurrent.futures
+import logging
+import multiprocessing as mp
+import os
+from abc import ABC, abstractmethod
+from typing import List, Optional, Union
+
+import numpy as np
+import transformers
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.mm_utils import expand2square, process_anyres_image
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import load_image
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    """Init the global processor for multi modal models."""
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+class BaseImageProcessor(ABC):
+    @abstractmethod
+    async def process_images_async(self, image_data, **kwargs):
+        pass
+
+
+class DummyImageProcessor(BaseImageProcessor):
+    async def process_images_async(self, *args, **kwargs):
+        return None
+
+
+class LlavaImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _image_processor):
+        self.hf_config = hf_config
+        self._image_processor = _image_processor
+        self.executor = concurrent.futures.ProcessPoolExecutor(
+            initializer=init_global_processor,
+            mp_context=mp.get_context("fork"),
+            initargs=(server_args,),
+            max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
+        )
+
+    @staticmethod
+    def _process_single_image_task(
+        image_data: Union[str, bytes],
+        image_aspect_ratio: Optional[str] = None,
+        image_grid_pinpoints: Optional[str] = None,
+        image_processor=None,
+    ):
+        image_processor = image_processor or global_processor.image_processor
+
+        try:
+            image, image_size = load_image(image_data)
+            if image_size is not None:
+                # It is a video with multiple images
+                image_hash = hash(image_data)
+                pixel_values = image_processor(image)["pixel_values"]
+                for _ in range(len(pixel_values)):
+                    pixel_values[_] = pixel_values[_].astype(np.float16)
+                pixel_values = np.stack(pixel_values, axis=0)
+                return pixel_values, image_hash, image_size
+            else:
+                # It is an image
+                image_hash = hash(image_data)
+                if image_aspect_ratio == "pad":
+                    image = expand2square(
+                        image,
+                        tuple(int(x * 255) for x in image_processor.image_mean),
+                    )
+                    pixel_values = image_processor(image.convert("RGB"))[
+                        "pixel_values"
+                    ][0]
+                elif image_aspect_ratio == "anyres" or (
+                    image_aspect_ratio is not None
+                    and "anyres_max" in image_aspect_ratio
+                ):
+                    pixel_values = process_anyres_image(
+                        image, image_processor, image_grid_pinpoints
+                    )
+                else:
+                    pixel_values = image_processor(image)["pixel_values"][0]
+
+                if isinstance(pixel_values, np.ndarray):
+                    pixel_values = pixel_values.astype(np.float16)
+
+                return pixel_values, image_hash, image.size
+        except Exception:
+            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
+
+    async def _process_single_image(
+        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+    ):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                LlavaImageProcessor._process_single_image_task,
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+            )
+        else:
+            return self._process_single_image_task(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+
+    async def process_images_async(
+        self, image_data: List[Union[str, bytes]], request_obj
+    ):
+        if not image_data:
+            return None
+
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
+        grid_pinpoints = (
+            self.hf_config.image_grid_pinpoints
+            if hasattr(self.hf_config, "image_grid_pinpoints")
+            and "anyres" in aspect_ratio
+            else None
+        )
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            # Multiple images
+            if len(image_data) > 1:
+                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                pixel_values, image_hashes, image_sizes = [], [], []
+                res = []
+                for img_data in image_data:
+                    res.append(
+                        self._process_single_image(
+                            img_data, aspect_ratio, grid_pinpoints
+                        )
+                    )
+                res = await asyncio.gather(*res)
+                for pixel_v, image_h, image_s in res:
+                    pixel_values.append(pixel_v)
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+
+                if isinstance(pixel_values[0], np.ndarray):
+                    pixel_values = np.stack(pixel_values, axis=0)
+            else:
+                # A single image
+                pixel_values, image_hash, image_size = await self._process_single_image(
+                    image_data[0], aspect_ratio, grid_pinpoints
+                )
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
+        elif isinstance(image_data, str):
+            # A single image
+            pixel_values, image_hash, image_size = await self._process_single_image(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+            image_hashes = [image_hash]
+            image_sizes = [image_size]
+        else:
+            raise ValueError(f"Invalid image data: {image_data}")
+
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": request_obj.modalities,
+        }
+
+
+def get_image_processor(
+    hf_config, server_args: ServerArgs, _image_processor
+) -> BaseImageProcessor:
+    return LlavaImageProcessor(hf_config, server_args, _image_processor)
+
+
+def get_dummy_image_processor():
+    return DummyImageProcessor()
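The new image_processor module pulls multimodal preprocessing out of the tokenizer path: get_image_processor returns a LlavaImageProcessor that offloads per-image work to a process pool and gathers the results asynchronously. A rough usage sketch follows; the config, server args, and image URL are placeholders, and in the released package this wiring lives inside TokenizerManager rather than being called directly like this:

```python
from sglang.srt.managers.image_processor import get_image_processor


async def preprocess_images(hf_config, server_args, hf_image_processor, request_obj):
    # hf_config / server_args / hf_image_processor stand in for the model's
    # HF config, the ServerArgs, and the Hugging Face image processor.
    processor = get_image_processor(hf_config, server_args, hf_image_processor)

    # image_data entries may be URLs, file paths, or raw bytes; more than one
    # entry triggers the interleaved-image / video handling shown above.
    return await processor.process_images_async(
        image_data=["https://example.com/cat.png"],  # placeholder URL
        request_obj=request_obj,
    )
    # -> {"pixel_values": ..., "image_hashes": ..., "image_sizes": ..., "modalities": ...}
```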