sglang-0.4.3.post4-py3-none-any.whl → sglang-0.4.4.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/image_processors/qwen_vl.py
ADDED
@@ -0,0 +1,161 @@
+import asyncio
+import math
+from typing import List, Union
+
+from PIL import Image
+
+from sglang.srt.managers.image_processor import BaseImageProcessor
+from sglang.srt.managers.image_processors.base_image_processor import (
+    get_global_processor,
+)
+from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
+
+
+# Compatible with Qwen2VL and Qwen2_5VL
+class Qwen2_5VLImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
+        self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
+        self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
+        self.image_token_id = hf_config.image_token_id
+        self.video_token_id = hf_config.video_token_id
+        self.NUM_TOKEN_PER_FRAME = 770
+        self.IMAGE_FACTOR = 28
+        self.MIN_PIXELS = 4 * 28 * 28
+        self.MAX_PIXELS = 16384 * 28 * 28
+        self.MAX_PIXELS = 16384 * 28 * 28
+        self.MAX_RATIO = 200
+
+    @staticmethod
+    def _process_images_task(images, input_text, _hf_config):
+        if isinstance(images, list) and len(images) == 0:
+            images = None
+        result = get_global_processor().__call__(
+            text=[input_text], images=images, padding=True, return_tensors="pt"
+        )
+
+        return {
+            "input_ids": result.input_ids,
+            "pixel_values": getattr(result, "pixel_values", None),
+            "image_grid_thw": getattr(result, "image_grid_thw", None),
+            "second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
+            "video_grid_thws": getattr(result, "video_grid_thws", None),
+        }
+
+    async def _process_images(self, images, input_text) -> dict:
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                Qwen2_5VLImageProcessor._process_images_task,
+                images,
+                input_text,
+                self.hf_config,
+            )
+        else:
+            return self._process_images_task(images, input_text, self.hf_config)
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_ids,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        image_token = self.IMAGE_TOKEN
+        base_output = self.load_images(
+            input_ids,
+            image_data,
+            image_token,
+            max_req_input_len,
+        )
+
+        def smart_resize(
+            height: int,
+            width: int,
+            factor: int = self.IMAGE_FACTOR,
+            min_pixels: int = self.MIN_PIXELS,
+            max_pixels: int = self.MAX_PIXELS,
+        ) -> tuple[int, int]:
+            """
+            Rescales the image so that the following conditions are met:
+
+            1. Both dimensions (height and width) are divisible by 'factor'.
+
+            2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+            3. The aspect ratio of the image is maintained as closely as possible.
+            """
+            if max(height, width) / min(height, width) > self.MAX_RATIO:
+                raise ValueError(
+                    f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
+                )
+            h_bar = max(factor, round_by_factor(height, factor))
+            w_bar = max(factor, round_by_factor(width, factor))
+            if h_bar * w_bar > max_pixels:
+                beta = math.sqrt((height * width) / max_pixels)
+                h_bar = floor_by_factor(height / beta, factor)
+                w_bar = floor_by_factor(width / beta, factor)
+            elif h_bar * w_bar < min_pixels:
+                beta = math.sqrt(min_pixels / (height * width))
+                h_bar = ceil_by_factor(height * beta, factor)
+                w_bar = ceil_by_factor(width * beta, factor)
+            return h_bar, w_bar
+
+        def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
+            width, height = image.size
+            min_pixels = self.MIN_PIXELS
+            max_pixels = self.MAX_PIXELS
+            resized_height, resized_width = smart_resize(
+                height,
+                width,
+                factor=size_factor,
+                min_pixels=min_pixels,
+                max_pixels=max_pixels,
+            )
+            image = image.resize((resized_width, resized_height))
+            return image
+
+        def round_by_factor(number: int, factor: int) -> int:
+            """Returns the closest integer to 'number' that is divisible by 'factor'."""
+            return round(number / factor) * factor
+
+        def ceil_by_factor(number: int, factor: int) -> int:
+            """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+            return math.ceil(number / factor) * factor
+
+        def floor_by_factor(number: int, factor: int) -> int:
+            """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+            return math.floor(number / factor) * factor
+
+        images = [resize_image(image) for image in base_output.all_frames]
+
+        ret = await self._process_images(images, base_output.input_text)
+        return {
+            "input_ids": ret["input_ids"].flatten().tolist(),
+            "pixel_values": ret["pixel_values"],
+            "image_hashes": base_output.image_hashes,
+            "modalities": request_obj.modalities or ["image"],
+            "image_grid_thws": ret["image_grid_thw"],
+            "video_grid_thws": ret["video_grid_thws"],
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "im_token_id": self.image_token_id,
+            "video_token_id": self.video_token_id,
+            "second_per_grid_ts": ret["second_per_grid_ts"],
+        }
+
+
+ImageProcessorMapping = {
+    Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
+    Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
+}
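The core of this new file is `smart_resize`: both sides are snapped to a multiple of the patch factor (28), then uniformly rescaled and re-snapped if the pixel count falls outside the [MIN_PIXELS, MAX_PIXELS] budget. A minimal standalone sketch of the same arithmetic, outside the class (the 1024x768 input is a made-up example, not from the diff):

import math

# Standalone sketch of the smart_resize arithmetic from the new file above.
# Defaults mirror IMAGE_FACTOR / MIN_PIXELS / MAX_PIXELS.
def round_by_factor(number: float, factor: int) -> int:
    return round(number / factor) * factor

def smart_resize(height: int, width: int, factor: int = 28,
                 min_pixels: int = 4 * 28 * 28,
                 max_pixels: int = 16384 * 28 * 28) -> tuple[int, int]:
    # Snap both sides to the nearest multiple of `factor`.
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    # If outside the pixel budget, rescale uniformly and snap again:
    # floor when shrinking, ceil when growing, so the result stays in range.
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar

print(smart_resize(1024, 768))  # (1036, 756): both divisible by 28

Flooring on the way down and ceiling on the way up is what guarantees the snapped result still respects the budget, at the cost of a slight aspect-ratio drift.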
sglang/srt/managers/io_struct.py
CHANGED
@@ -293,6 +293,8 @@ class TokenizedGenerateReqInput:
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
+    # The image input. It can be a file name, a url, or base64 encoded string.
+    image_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
@@ -303,28 +305,40 @@ class EmbeddingReqInput:
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
+    # The modalities of the image data [image, multi-images, video]
+    modalities: Optional[List[str]] = None

     def normalize_batch_and_arguments(self):
-
-
-
-
+        # at least one of text, input_ids, or image should be provided
+        if self.text is None and self.input_ids is None and self.image_data is None:
+            raise ValueError(
+                "At least one of text, input_ids, or image should be provided"
+            )
+
+        # text and input_ids cannot be provided at the same time
+        if self.text is not None and self.input_ids is not None:
+            raise ValueError("text and input_ids cannot be provided at the same time")

         # Derive the batch size
+        self.batch_size = 0
+        self.is_single = True
+
+        # check the batch size of text
         if self.text is not None:
-            if isinstance(self.text,
-            self.
-            self.batch_size = 1
+            if isinstance(self.text, list):
+                self.batch_size += len(self.text)
             else:
-            self.
-
-
-
-
-            self.batch_size
+                self.batch_size += 1
+
+        # check the batch size of input_ids
+        if self.input_ids is not None:
+            if isinstance(self.input_ids[0], list):
+                self.batch_size += len(self.input_ids)
             else:
-            self.
-
+                self.batch_size += 1
+
+        if self.batch_size > 1:
+            self.is_single = False

         # Fill in default arguments
         if self.is_single:
@@ -352,6 +366,7 @@ class EmbeddingReqInput:
         return EmbeddingReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            image_data=self.image_data[i] if self.image_data is not None else None,
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
         )
@@ -365,6 +380,8 @@ class TokenizedEmbeddingReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
+    # The image inputs
+    image_inputs: dict
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams

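The rewritten `normalize_batch_and_arguments` makes the batch size additive over `text` and `input_ids` and, together with the new `image_data` field, admits image-only embedding requests. A hypothetical standalone condensation of just the new derivation rules (the real logic is a method on `EmbeddingReqInput` and goes on to normalize `rid` and `sampling_params`):

from typing import List, Optional, Union

# Hypothetical free-function version of the batch-size rules added above.
def derive_batch(
    text: Optional[Union[List[str], str]] = None,
    input_ids: Optional[Union[List[List[int]], List[int]]] = None,
    image_data: Optional[Union[List[str], str]] = None,
) -> tuple[int, bool]:
    if text is None and input_ids is None and image_data is None:
        raise ValueError("At least one of text, input_ids, or image should be provided")
    if text is not None and input_ids is not None:
        raise ValueError("text and input_ids cannot be provided at the same time")
    batch_size = 0
    if text is not None:
        batch_size += len(text) if isinstance(text, list) else 1
    if input_ids is not None:
        batch_size += len(input_ids) if isinstance(input_ids[0], list) else 1
    return batch_size, batch_size <= 1  # (batch_size, is_single)

print(derive_batch(text=["hello", "world"]))  # (2, False)
print(derive_batch(input_ids=[1, 2, 3]))      # (1, True)
print(derive_batch(image_data="cat.png"))     # (0, True): image-only request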
sglang/srt/managers/multi_modality_padding.py
ADDED
@@ -0,0 +1,134 @@
+from abc import abstractmethod
+from typing import Callable, List, Optional, Tuple
+
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.utils import logger
+
+
+class MultiModalityDataPaddingPattern:
+    """
+    Data tokens (like image tokens) often need special handling during padding
+    to maintain model compatibility. This class provides the interface for
+    implementing different padding strategies for data tokens
+    """
+
+    @abstractmethod
+    def pad_input_tokens(
+        self, input_ids: List[int], image_inputs: ImageInputs
+    ) -> List[int]:
+        """
+        Pad the input ids sequence containing data tokens, and replace them with pad_values
+        """
+        pass
+
+
+class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
+    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
+
+    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
+    """
+
+    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
+        self.data_token_id_pairs = data_token_pairs
+
+    def pad_input_tokens(
+        self, input_ids: List[int], image_inputs: ImageInputs
+    ) -> List[int]:
+        """
+        This function will replace the data-tokens in between with pad_values accordingly
+        """
+        pad_values = image_inputs.pad_values
+        data_token_pairs = self.data_token_id_pairs
+        image_inputs.image_offsets = []
+        if data_token_pairs is None:
+            data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
+        if data_token_pairs is None:
+            logger.warning(
+                "No data_token_pairs provided, RadixAttention might be influenced."
+            )
+            return input_ids
+        start_token_ids = [s for s, _e in data_token_pairs]
+        end_tokens_ids = [e for _s, e in data_token_pairs]
+        # First start token marks new data
+        data_start_token = start_token_ids[0]
+
+        padded_ids = []
+        last_idx = 0
+        data_idx = -1
+
+        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
+        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]
+
+        if len(start_indices) != len(end_indices):
+            return input_ids
+
+        for start_idx, end_idx in zip(start_indices, end_indices):
+            padded_ids.extend(input_ids[last_idx : start_idx + 1])
+
+            if input_ids[start_idx] == data_start_token:
+                data_idx += 1
+                image_inputs.image_offsets += [start_idx]
+
+            num_tokens = end_idx - start_idx - 1
+            pad_value = pad_values[data_idx]
+            padded_ids.extend([pad_value] * num_tokens)
+
+            last_idx = end_idx
+
+        padded_ids.extend(input_ids[last_idx:])
+
+        assert len(input_ids) == len(padded_ids)
+        return padded_ids
+
+
+class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
+    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
+    which needs first to be expanded to multiple tokens, then replaced with their padding values
+
+    This strategy should be used when a single data token represents content that should
+    be expanded to multiple tokens during processing.
+    """
+
+    def __init__(
+        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
+    ) -> None:
+        self.num_data_token_calc_func = num_data_token_calc_func
+
+    def pad_input_tokens(
+        self, input_ids: List[int], image_inputs: ImageInputs
+    ) -> List[int]:
+        """
+        This function will follow the procedure of:
+        1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
+        2. the padded data tokens will be replaced with their pad_values
+        """
+        image_grid_thws = image_inputs.image_grid_thws
+        pad_values = image_inputs.pad_values
+
+        image_indices = [
+            idx
+            for idx, token in enumerate(input_ids)
+            if token == image_inputs.im_token_id
+        ]
+
+        image_inputs.image_offsets = []
+
+        input_ids_with_image = []
+        for image_cnt, _ in enumerate(image_grid_thws):
+            print(f"image_cnt {image_cnt}")
+            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
+            if image_cnt == 0:
+                non_image_tokens = input_ids[: image_indices[image_cnt]]
+            else:
+                non_image_tokens = input_ids[
+                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
+                ]
+            input_ids_with_image.extend(non_image_tokens)
+            image_inputs.image_offsets.append(len(input_ids_with_image))
+            pad_ids = pad_values * (
+                (num_image_tokens + len(pad_values)) // len(pad_values)
+            )
+            input_ids_with_image.extend(pad_ids[:num_image_tokens])
+        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
+
+        return input_ids_with_image
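As a usage illustration for the token-pair pattern: constructing a real `ImageInputs` requires the scheduler machinery, so the sketch below substitutes a `SimpleNamespace` carrying only the attributes the pattern reads and writes (`pad_values`, `image_offsets`); every token id is invented for the example.

from types import SimpleNamespace

from sglang.srt.managers.multi_modality_padding import (
    MultiModalityDataPaddingPatternTokenPairs,
)

# Invented ids: 100 = vision start, 101 = vision end; 7777 stands in for the
# image-hash-derived pad value that tells one image from another.
image_inputs = SimpleNamespace(pad_values=[7777])
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(100, 101)])

input_ids = [1, 2, 100, 51, 52, 53, 101, 3]
padded = pattern.pad_input_tokens(input_ids, image_inputs)
print(padded)                      # [1, 2, 100, 7777, 7777, 7777, 101, 3]
print(image_inputs.image_offsets)  # [2]: index of the start token

The text and the start/end markers survive unchanged; only the tokens between each pair are overwritten with that image's pad value, which is what lets RadixAttention-style prefix caching distinguish requests that contain different images.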