sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +133 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +32 -21
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +133 -30
- sglang/srt/managers/scheduler.py +273 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +27 -13
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +208 -77
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +124 -28
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +99 -9
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/image_processor.py
@@ -1,55 +0,0 @@
-# TODO: also move pad_input_ids into this module
-import importlib
-import logging
-import pkgutil
-from functools import lru_cache
-
-from transformers import IMAGE_PROCESSOR_MAPPING
-
-from sglang.srt.managers.image_processors.base_image_processor import (
-    BaseImageProcessor,
-    DummyImageProcessor,
-)
-from sglang.srt.server_args import ServerArgs
-
-logger = logging.getLogger(__name__)
-
-
-IMAGE_PROCESSOR_MAPPING = {}
-
-
-def get_image_processor(
-    hf_config, server_args: ServerArgs, processor
-) -> BaseImageProcessor:
-    for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
-        if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
-    raise ValueError(
-        f"No image processor found for architecture: {hf_config.architectures}"
-    )
-
-
-def get_dummy_image_processor():
-    return DummyImageProcessor()
-
-
-@lru_cache()
-def import_image_processors():
-    package_name = "sglang.srt.managers.image_processors"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            try:
-                module = importlib.import_module(name)
-            except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}: " f"{e}")
-                continue
-            if hasattr(module, "ImageProcessorMapping"):
-                entry = module.ImageProcessorMapping
-                if isinstance(entry, dict):
-                    for processor_name, cls in entry.items():
-                        IMAGE_PROCESSOR_MAPPING[processor_name] = cls
-
-
-# also register processors
-import_image_processors()
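The deleted module above implemented a small plugin registry: import_image_processors() walks the image_processors package with pkgutil, imports every submodule, and merges each module-level ImageProcessorMapping dict into a global table that get_image_processor() dispatches on via hf_config.architectures. For reference, a minimal self-contained sketch of that discovery pattern follows; the package name, the PluginMapping attribute, and REGISTRY are hypothetical stand-ins, not sglang APIs.

import importlib
import logging
import pkgutil

logger = logging.getLogger(__name__)

# Registry filled at import time: {model_class: processor_class}
REGISTRY = {}


def import_plugins(package_name: str) -> None:
    """Scan `package_name` and merge each module-level `PluginMapping`
    dict into REGISTRY, skipping modules that fail to import."""
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if ispkg:
            continue
        try:
            module = importlib.import_module(name)
        except Exception as e:  # mirror the permissive behavior in the deleted code
            logger.warning("Skipping %s: %s", name, e)
            continue
        mapping = getattr(module, "PluginMapping", None)
        if isinstance(mapping, dict):
            REGISTRY.update(mapping)


def get_plugin(architectures: list) -> type:
    """Dispatch on model architecture names, as get_image_processor() did."""
    for model_cls, plugin_cls in REGISTRY.items():
        if model_cls.__name__ in architectures:
            return plugin_cls
    raise ValueError(f"No plugin registered for: {architectures}")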
sglang/srt/managers/image_processors/base_image_processor.py
@@ -1,219 +0,0 @@
-import concurrent
-import concurrent.futures
-import dataclasses
-import multiprocessing as mp
-import os
-from abc import ABC, abstractmethod
-from typing import Optional
-
-import PIL
-import transformers
-from decord import VideoReader, cpu
-from PIL import Image
-
-from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import load_image
-from sglang.utils import logger
-
-global global_processor
-
-
-def get_global_processor():
-    global global_processor
-    return global_processor
-
-
-def init_global_processor(sglang_image_processor, server_args: ServerArgs):
-    """Init the global processor for multi-modal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = sglang_image_processor._build_processor(server_args=server_args)
-
-
-@dataclasses.dataclass
-class BaseImageProcessorOutput:
-    image_hashes: list[int]
-    image_sizes: list[tuple[int, int]]
-    all_frames: [PIL.Image]
-    # input_text, with each frame of video/image represented as an image_token
-    input_text: str
-
-
-class BaseImageProcessor(ABC):
-    def __init__(self, hf_config, server_args, _processor):
-        self.hf_config = hf_config
-        self._processor = _processor
-        self.server_args = server_args
-        # FIXME: not accurate, model and image specific
-        self.NUM_TOKEN_PER_FRAME = 330
-
-        self.executor = concurrent.futures.ProcessPoolExecutor(
-            initializer=init_global_processor,
-            mp_context=mp.get_context("fork"),
-            initargs=(
-                self,
-                server_args,
-            ),
-            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
-        )
-
-    def _build_processor(self, server_args):
-        """Init the global processor for multi modal models."""
-        from sglang.srt.hf_transformers_utils import get_processor
-
-        return get_processor(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
-
-    @abstractmethod
-    async def process_images_async(
-        self, image_data, input_text, max_req_input_len, **kwargs
-    ):
-        pass
-
-    def get_estimated_frames_list(self, image_data):
-        """
-        estimate the total frame count from all visual input
-        """
-        # Before processing inputs
-        estimated_frames_list = []
-        for image in image_data:
-            if isinstance(image, str) and image.startswith("video:"):
-                path = image[len("video:") :]
-                # Estimate frames for the video
-                vr = VideoReader(path, ctx=cpu(0))
-                num_frames = len(vr)
-            else:
-                # For images, each contributes one frame
-                num_frames = 1
-            estimated_frames_list.append(num_frames)
-
-        return estimated_frames_list
-
-    @staticmethod
-    def encode_video(video_path, frame_count_limit=None):
-        if not os.path.exists(video_path):
-            logger.error(f"Video {video_path} does not exist")
-            return []
-
-        if frame_count_limit == 0:
-            return []
-
-        def uniform_sample(l, n):
-            gap = len(l) / n
-            idxs = [int(i * gap + gap / 2) for i in range(n)]
-            return [l[i] for i in idxs]
-
-        vr = VideoReader(video_path, ctx=cpu(0))
-        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-        frame_indices = [i for i in range(0, len(vr), sample_fps)]
-        if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-            frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-        frames = vr.get_batch(frame_indices).asnumpy()
-        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-        return frames
-
-    def load_images(
-        self,
-        input_ids: list,
-        image_data,
-        image_token: str,
-        max_req_input_len: int,
-        return_text: Optional[bool] = True,
-        discard_alpha_channel: bool = True,
-    ) -> BaseImageProcessorOutput:
-        """
-        Each frame of video/image will be replaced by a single image token
-
-        Args:
-
-            discard_alpha_channel: if True, discards the alpha channel in the returned images
-
-        """
-        image_hashes, image_sizes = [], []
-        all_frames = []
-        new_text_parts = []
-
-        if isinstance(input_ids, list) and return_text:
-            assert len(input_ids) and isinstance(input_ids[0], int)
-            input_text = self._processor.tokenizer.decode(input_ids)
-        else:
-            input_text = input_ids
-
-        if return_text:
-            text_parts = input_text.split(image_token)
-
-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
-
-        assert len(image_data) == len(estimated_frames_list)
-
-        # Process each input with allocated frames
-        for image_index, (image, estimated_frames) in enumerate(
-            zip(image_data, estimated_frames_list)
-        ):
-            if len(all_frames) >= MAX_NUM_FRAMES:
-                max_frames_to_process = 0
-            else:
-                max_frames_to_process = max(1, int(estimated_frames * scaling_factor))
-
-            if max_frames_to_process == 0:
-                frames = []
-            else:
-                try:
-                    if isinstance(image, str) and image.startswith("video:"):
-                        path = image[len("video:") :]
-                        frames = BaseImageProcessor.encode_video(
-                            path, frame_count_limit=max_frames_to_process
-                        )
-                    else:
-                        raw_image, _size = load_image(image)
-                        if discard_alpha_channel:
-                            raw_image = raw_image.convert("RGB")
-                        frames = [raw_image]
-                    assert len(frames) != 0
-                except FileNotFoundError as e:
-                    print(e)
-                    return None
-
-            image_sizes += [frames[0].size] * len(frames)
-            image_hashes += [hash(image)] * len(frames)
-            all_frames += frames
-
-            if return_text:
-                new_text_parts.append(text_parts[image_index])
-                if max_frames_to_process != 0:
-                    new_text_parts.append(image_token * len(frames))
-                assert max_frames_to_process >= len(frames)
-        if return_text:
-            new_text_parts.append(text_parts[-1])
-
-        input_text = "".join(new_text_parts)
-        return BaseImageProcessorOutput(
-            image_hashes, image_sizes, all_frames, input_text
-        )
-
-
-class DummyImageProcessor(BaseImageProcessor):
-    def __init__(self):
-        pass
-
-    async def process_images_async(self, *args, **kwargs):
-        return None
-
-
-def init_global_processor(
-    sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
-):
-    """Init the global processor for multi-modal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = sglang_image_processor._build_processor(server_args=server_args)
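The load-bearing detail in the base class above is the ProcessPoolExecutor whose initializer runs once in every worker process and stores the expensive Hugging Face processor in a module-level global, so later tasks can reach it without shipping it per call. A minimal runnable sketch of that per-worker-initializer pattern, with a plain dict standing in for the real processor (the names _worker_state, _init_worker, and _square are illustrative):

import concurrent.futures
import multiprocessing as mp

_worker_state = None  # set once per worker by the initializer


def _init_worker(config):
    """Runs in each worker process before any task: build the heavy object once."""
    global _worker_state
    _worker_state = {"config": config, "ready": True}


def _square(x):
    # Every task can use the pre-built state without re-initializing it.
    assert _worker_state is not None and _worker_state["ready"]
    return x * x


if __name__ == "__main__":
    executor = concurrent.futures.ProcessPoolExecutor(
        max_workers=2,
        mp_context=mp.get_context("fork"),  # as in the deleted code; "fork" is POSIX-only
        initializer=_init_worker,
        initargs=({"model": "stub"},),
    )
    print(list(executor.map(_square, range(4))))  # [0, 1, 4, 9]
    executor.shutdown()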
sglang/srt/managers/image_processors/minicpmv.py
@@ -1,86 +0,0 @@
-import asyncio
-from typing import List, Union
-
-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
-    get_global_processor,
-)
-from sglang.srt.models.minicpmv import MiniCPMV
-
-
-class MiniCPMVImageProcessor(BaseImageProcessor):
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
-        self.IMAGE_TOKEN = "(<image>./</image>)"
-
-    @staticmethod
-    def _process_images_task(images, input_text):
-        processor = get_global_processor()
-        result = processor.__call__(text=input_text, images=images, return_tensors="pt")
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": result.pixel_values,
-            "tgt_sizes": result.tgt_sizes,
-        }
-
-    async def _process_images(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                MiniCPMVImageProcessor._process_images_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=input_text, return_tensors="pt"
-            )
-
-        return image_inputs
-
-    async def process_images_async(
-        self,
-        image_data: List[Union[str, bytes]],
-        input_ids,
-        request_obj,
-        max_req_input_len,
-    ):
-        if not image_data:
-            return None
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        base_output = self.load_images(
-            input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
-        )
-        if base_output is None:
-            return None
-
-        if len(base_output.all_frames) == 0:
-            return None
-        res = await self._process_images(
-            images=base_output.all_frames, input_text=base_output.input_text
-        )
-
-        # Collect special token ids
-        tokenizer = self._processor.tokenizer
-        im_start_id = tokenizer.im_start_id
-        im_end_id = tokenizer.im_end_id
-        if tokenizer.slice_start_id:
-            slice_start_id = tokenizer.slice_start_id
-            slice_end_id = tokenizer.slice_end_id
-        return {
-            "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": res["pixel_values"],
-            "tgt_sizes": res["tgt_sizes"],
-            "image_hashes": base_output.image_hashes,
-            "modalities": request_obj.modalities or ["image"],
-            "im_start_id": im_start_id,
-            "im_end_id": im_end_id,
-            "slice_start_id": slice_start_id,
-            "slice_end_id": slice_end_id,
-        }
-
-
-ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}
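MiniCPMVImageProcessor._process_images above shows the standard trick for keeping an asyncio event loop responsive while a blocking preprocessor runs: hand the call to an executor via loop.run_in_executor and await the result. A self-contained sketch of the same offloading pattern; the stub preprocess function and the ThreadPoolExecutor (used here so the demo needs no fork/pickling setup) are stand-ins for the HF processor and the process pool:

import asyncio
import concurrent.futures


def preprocess(text):
    """Stand-in for the blocking processor(...) call."""
    return text.upper()


async def preprocess_async(executor, text):
    loop = asyncio.get_running_loop()  # modern replacement for get_event_loop()
    # The blocking call runs in the executor; the event loop stays free.
    return await loop.run_in_executor(executor, preprocess, text)


async def main():
    with concurrent.futures.ThreadPoolExecutor() as pool:
        print(await preprocess_async(pool, "hello"))  # HELLO


if __name__ == "__main__":
    asyncio.run(main())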
sglang/srt/managers/multi_modality_padding.py
@@ -1,134 +0,0 @@
-from abc import abstractmethod
-from typing import Callable, List, Optional, Tuple
-
-from sglang.srt.managers.schedule_batch import ImageInputs
-from sglang.utils import logger
-
-
-class MultiModalityDataPaddingPattern:
-    """
-    Data tokens (like image tokens) often need special handling during padding
-    to maintain model compatibility. This class provides the interface for
-    implementing different padding strategies for data tokens
-    """
-
-    @abstractmethod
-    def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
-    ) -> List[int]:
-        """
-        Pad the input ids sequence containing data tokens, and replace them with pad_values
-        """
-        pass
-
-
-class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
-    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
-
-    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
-    """
-
-    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
-        self.data_token_id_pairs = data_token_pairs
-
-    def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
-    ) -> List[int]:
-        """
-        This function will replace the data-tokens inbetween with pad_values accordingly
-        """
-        pad_values = image_inputs.pad_values
-        data_token_pairs = self.data_token_id_pairs
-        image_inputs.image_offsets = []
-        if data_token_pairs is None:
-            data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
-        if data_token_pairs is None:
-            logger.warning(
-                "No data_token_pairs provided, RadixAttention might be influenced."
-            )
-            return input_ids
-        start_token_ids = [s for s, _e in data_token_pairs]
-        end_tokens_ids = [e for _s, e in data_token_pairs]
-        # First start token marks new data
-        data_start_token = start_token_ids[0]
-
-        padded_ids = []
-        last_idx = 0
-        data_idx = -1
-
-        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
-        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]
-
-        if len(start_indices) != len(end_indices):
-            return input_ids
-
-        for start_idx, end_idx in zip(start_indices, end_indices):
-            padded_ids.extend(input_ids[last_idx : start_idx + 1])
-
-            if input_ids[start_idx] == data_start_token:
-                data_idx += 1
-                image_inputs.image_offsets += [start_idx]
-
-            num_tokens = end_idx - start_idx - 1
-            pad_value = pad_values[data_idx]
-            padded_ids.extend([pad_value] * num_tokens)
-
-            last_idx = end_idx
-
-        padded_ids.extend(input_ids[last_idx:])
-
-        assert len(input_ids) == len(padded_ids)
-        return padded_ids
-
-
-class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
-    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
-    which needs first to be expanded to multiple tokens, then replaced with their padding values
-
-    This strategy should be used when a single data token represents content that should
-    be expanded to multiple tokens during processing.
-    """
-
-    def __init__(
-        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
-    ) -> None:
-        self.num_data_token_calc_func = num_data_token_calc_func
-
-    def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
-    ) -> List[int]:
-        """
-        This function will follow the procedure of:
-        1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
-        2. the padded data tokens will be replaced with their pad_values
-        """
-        image_grid_thws = image_inputs.image_grid_thws
-        pad_values = image_inputs.pad_values
-
-        image_indices = [
-            idx
-            for idx, token in enumerate(input_ids)
-            if token == image_inputs.im_token_id
-        ]
-
-        image_inputs.image_offsets = []
-
-        input_ids_with_image = []
-        for image_cnt, _ in enumerate(image_grid_thws):
-            print(f"image_cnt {image_cnt}")
-            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
-            if image_cnt == 0:
-                non_image_tokens = input_ids[: image_indices[image_cnt]]
-            else:
-                non_image_tokens = input_ids[
-                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
-                ]
-            input_ids_with_image.extend(non_image_tokens)
-            image_inputs.image_offsets.append(len(input_ids_with_image))
-            pad_ids = pad_values * (
-                (num_image_tokens + len(pad_values)) // len(pad_values)
-            )
-            input_ids_with_image.extend(pad_ids[:num_image_tokens])
-        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
-
-        return input_ids_with_image
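To make the token-pair strategy above concrete: pad_input_tokens keeps each start/end token, overwrites the tokens enclosed between them with that image's pad value, and records the start offsets, so the sequence length never changes. A toy rendition of the core loop (token ids 100/101 and pad values 7/8 are made up; it assumes a single, balanced start/end pair type):

def pad_between_pairs(input_ids, start_id, end_id, pad_values):
    out, data_idx, i = [], -1, 0
    while i < len(input_ids):
        tok = input_ids[i]
        out.append(tok)
        if tok == start_id:
            data_idx += 1
            j = input_ids.index(end_id, i + 1)  # matching end token (assumed present)
            out.extend([pad_values[data_idx]] * (j - i - 1))
            out.append(end_id)
            i = j
        i += 1
    assert len(out) == len(input_ids)  # padding preserves sequence length
    return out


ids = [1, 100, 5, 5, 5, 101, 2, 100, 6, 6, 101, 3]
print(pad_between_pairs(ids, 100, 101, [7, 8]))
# -> [1, 100, 7, 7, 7, 101, 2, 100, 8, 8, 101, 3]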
{sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE: File without changes
{sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt: File without changes