sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/base_processor.py +114 -77

@@ -4,14 +4,16 @@ import dataclasses
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional

 import numpy as np
 import PIL
 from decord import VideoReader, cpu
 from PIL import Image
+from transformers import BaseImageProcessorFast

-from sglang.srt.
+from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.utils import encode_video, load_audio, load_image


 @dataclasses.dataclass
@@ -78,6 +80,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audios"] = audios

         processor = self._processor
+        if hasattr(processor, "image_processor") and isinstance(
+            processor.image_processor, BaseImageProcessorFast
+        ):
+            kwargs["device"] = "cuda"
         result = processor.__call__(
             text=[input_text],
             padding=True,
@@ -111,6 +117,84 @@ class BaseMultimodalProcessor(ABC):

         return estimated_frames_list

+    @staticmethod
+    def _load_single_item(
+        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+    ):
+        """Static method that can be pickled for multiprocessing"""
+        try:
+            if is_audio:
+                return load_audio(data)
+            elif is_video:
+                path = data[len("video:") :]
+                return encode_video(path, frame_count_limit)
+            else:
+                img, _ = load_image(data)
+                return img.convert("RGB") if discard_alpha_channel else img
+        except Exception as e:
+            raise RuntimeError(f"Error while loading data {data}: {e}")
+
+    def submit_data_loading_tasks(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        discard_alpha_channel: bool = True,
+    ):
+        """
+        load multimodal data parallelly
+        """
+
+        # TODO(mick): load from server_args, env, or sampling_params
+        MAX_NUM_FRAMES = 30
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+        assert len(image_data) == len(estimated_frames_list)
+        # Submit all tasks
+        futures = []
+        task_info = []
+        image_index, audio_index = 0, 0
+
+        for text_part in text_parts:
+            if text_part == multimodal_tokens.image_token:
+                data = image_data[image_index]
+                is_video = isinstance(data, str) and data.startswith("video:")
+                estimated_frames = estimated_frames_list[image_index]
+                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        is_video,
+                        False,
+                        frame_count_limit,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.IMAGE, data, frame_count_limit))
+                image_index += 1
+            elif text_part == multimodal_tokens.audio_token:
+                data = audio_data[audio_index]
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        False,
+                        True,
+                        None,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.AUDIO, data, None))
+                audio_index += 1
+
+        return futures, task_info
+
     def load_mm_data(
         self,
         prompt: str,
@@ -155,84 +239,37 @@ class BaseMultimodalProcessor(ABC):
         # split text into list of normal text and special tokens
         text_parts = re.split(pattern, prompt)

-[... 10 removed lines not captured in this view ...]
-        image_index, audio_index = 0, 0
-        hashes, image_sizes, images, audios = [], [], [], []
+        futures, task_info = self.submit_data_loading_tasks(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            image_data=image_data,
+            audio_data=audio_data,
+            discard_alpha_channel=discard_alpha_channel,
+        )
+        # Process results
+        image_sizes, images, audios = [], [], []
         new_text = ""
-[... 13 removed lines not captured in this view ...]
-                        frames = []
-                    else:
-                        image_file = image_data[image_index]
-                        if isinstance(image_file, str) and image_file.startswith(
-                            "video:"
-                        ):
-                            # video
-                            path = image_file[len("video:") :]
-                            frames = encode_video(
-                                path, frame_count_limit=frames_to_process
-                            )
-                        else:
-                            # image
-                            raw_image, _size = load_image(image_file)
-                            if discard_alpha_channel:
-                                raw_image = raw_image.convert("RGB")
-                            frames = [raw_image]
-                        if len(frames) == 0:
-                            continue
-
-                    image_sizes += frames[0].size * len(frames)
-
-                    # Generate a hashable value for the image file
-                    if isinstance(image_file, Image.Image):
-                        # For PIL.Image objects, use the ID as a hashable value
-                        hash_value = hash(id(image_file))
-                    else:
-                        # For other types (strings, etc.), use the regular hash
-                        hash_value = hash(image_file)
-
-                    hashes += [hash_value] * len(frames)
-                    images += frames
-                    image_index += 1
-                    if frames_to_process != 0:
+        task_ptr = 0
+
+        for text_part in text_parts:
+            if text_part in multimodal_tokens.collect():
+                task_type, data, frame_limit = task_info[task_ptr]
+                result = futures[task_ptr].result()
+                task_ptr += 1
+
+                if task_type == Modality.IMAGE:
+                    frames = [result] if not isinstance(result, list) else result
+                    if frames:
+                        image_sizes += frames[0].size * len(frames)
+                        images += frames
                         new_text += multimodal_tokens.image_token * len(frames)
-[... 3 removed lines not captured in this view ...]
-                    audio_file = audio_data[audio_index]
-                    audio = load_audio(audio_file)
-                    hashes += [hash(audio_file)]
-                    audios += [audio]
-                    audio_index += 1
+                elif task_type == Modality.AUDIO:
+                    # audio
+                    audios.append(result)
                     new_text += multimodal_tokens.audio_token
-[... 3 removed lines not captured in this view ...]
-                    new_text += text_part
-
-            except Exception as e:
-                logger.error(f"An exception occurred while loading images: {e}")
-                raise RuntimeError(f"An exception occurred while loading images: {e}")
+                # TODO: handle video
+            else:
+                new_text += text_part

         out = BaseMultiModalProcessorOutput(
             images=images,

sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1

@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         base_out = self.load_mm_data(
             prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=processor.image_token
+            ),
             max_req_input_len=max_req_input_len,
         )

sglang/srt/managers/multimodal_processors/mllama4.py +146 -0 (new file)

@@ -0,0 +1,146 @@
+from typing import List, Union
+
+import torch
+from transformers.image_utils import SizeDict
+from transformers.models.llama4.image_processing_llama4_fast import (
+    find_supported_resolutions,
+    get_best_fit,
+)
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
+
+
+class Mllama4ImageProcessor(BaseMultimodalProcessor):
+    models = [Llama4ForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.vision_config = hf_config.vision_config
+        self.text_config = hf_config.text_config
+        self.boi_token_index = hf_config.boi_token_index
+        self.eoi_token_index = hf_config.eoi_token_index
+        self.image_token_index = hf_config.image_token_index
+        self.multimodal_tokens = MultimodalSpecialTokens(
+            image_token=_processor.image_token
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        max_req_input_len=None,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        # Process images and text using the base processor's load_mm_data method
+        processed_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=self.multimodal_tokens,
+            max_req_input_len=max_req_input_len or 4096,
+            image_data=image_data,
+            return_text=True,
+        )
+
+        # Process the images using the processor
+        processor = self._processor
+
+        # Process the prompt and images
+        processor_output = self.process_mm_data(
+            input_text=processed_data.input_text,
+            images=processed_data.images,
+        )
+
+        # Handle image resolutions and aspect ratios
+        if "pixel_values" in processor_output:
+            image_processor = processor.image_processor
+            tokenizer = self._processor.tokenizer
+
+            # Calculate tile size and find supported resolutions
+            tile_size = self.vision_config.image_size
+            max_num_tiles = getattr(self.vision_config, "max_patches", 1)
+
+            possible_resolutions = find_supported_resolutions(
+                max_num_chunks=max_num_tiles,
+                patch_size=SizeDict(height=tile_size, width=tile_size),
+            )
+
+            # Find best fit for each image
+            best_fit_sizes = [
+                get_best_fit(
+                    (image.size[1], image.size[0]),  # (height, width)
+                    torch.tensor(possible_resolutions),
+                    resize_to_max_canvas=image_processor.resize_to_max_canvas,
+                )
+                for image in processed_data.images
+            ]
+
+            # Calculate aspect ratios and patches per image
+            aspect_ratios = [
+                (image_size[0] // tile_size, image_size[1] // tile_size)
+                for image_size in best_fit_sizes
+            ]
+
+            patches_per_image = [
+                1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
+            ]
+
+            # Add to image_inputs
+            processor_output["aspect_ratios"] = aspect_ratios
+            processor_output["patches_per_image"] = torch.tensor(patches_per_image)
+
+            # Process embed_is_patch
+            vocab = tokenizer.get_vocab()
+            patch_id = vocab.get(processor.img_patch_token, -1)
+            image_end_id = vocab.get(processor.end_of_img_token, -1)
+
+            if patch_id != -1 and image_end_id != -1:
+                input_ids = processor_output["input_ids"].view(-1)
+
+                # Remove BOS token if present
+                if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
+                    input_ids = input_ids[1:]
+
+                # Find image end indices and split input_ids
+                image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
+
+                if image_end_indices.size(0) > 0:
+                    # Split at image boundaries
+                    split_indices = (image_end_indices + 1)[:-1]
+                    split_input_ids = torch.tensor_split(input_ids, split_indices)
+                    split_input_ids = [x for x in split_input_ids if x.numel() > 0]
+
+                    # Create embed_is_patch for each image
+                    embed_is_patch = []
+                    for per_image_input_ids in split_input_ids:
+                        embed_is_patch.append(per_image_input_ids == patch_id)
+
+                    processor_output["embed_is_patch"] = embed_is_patch
+
+        # Convert to the format expected by SGLang
+        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
+
+        processor_output["im_start_id"] = self.boi_token_index
+        processor_output["im_end_id"] = self.eoi_token_index
+        processor_output["im_token_id"] = self.image_token_index
+
+        # Add metadata for image processing
+        processor_output["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=processor_output["pixel_values"],
+                modality=Modality.IMAGE,
+            )
+        ]
+
+        return processor_output
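
The embed_is_patch section of the new processor splits the flat token sequence after each end-of-image token and records, per image, which positions are patch tokens. Below is a standalone sketch of that splitting step; the token ids are invented for the example and do not correspond to the real Llama 4 vocabulary.

import torch

# Hypothetical token ids, chosen only for this example.
PATCH_ID, IMAGE_END_ID = 7, 9

# A prompt with two images: patch tokens followed by an end-of-image marker each.
input_ids = torch.tensor([1, 7, 7, 9, 2, 3, 7, 7, 7, 9, 4])

# Indices of end-of-image tokens; split right after each one.
image_end_indices = (input_ids == IMAGE_END_ID).nonzero().view(-1)
split_indices = (image_end_indices + 1)[:-1]
chunks = [c for c in torch.tensor_split(input_ids, split_indices) if c.numel() > 0]

# One boolean mask per image chunk: True where the position holds a patch token.
embed_is_patch = [chunk == PATCH_ID for chunk in chunks]
for mask in embed_is_patch:
    print(mask.tolist())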

sglang/srt/managers/schedule_batch.py +62 -21

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import hashlib
 from enum import Enum, auto

 # Copyright 2023-2024 SGLang Team
@@ -44,7 +45,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
-from sglang.srt.disaggregation.
+from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -82,6 +83,7 @@ global_server_args_dict = {
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
     "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }

 logger = logging.getLogger(__name__)
@@ -157,7 +159,7 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or
+    A single multimodal data, from a single image/video/audio or others
     """

     modality: Modality
@@ -195,17 +197,54 @@ class MultimodalDataItem:

     def set_pad_value(self):
         """
-        Set the pad value after first
+        Set the pad value after first hashing the data
         """

+        def data_hash(data) -> int:
+            hash_bytes = hashlib.sha256(data).digest()[:8]
+            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+        def tensor_hash(tensor_list) -> int:
+            """
+            hash a tensor or a tensor list
+            """
+            tensor = tensor_list
+            if isinstance(tensor_list, list):
+                tensor_list = flatten_nested_list(tensor_list)
+                tensor_list = [
+                    x.flatten() if isinstance(x, torch.Tensor) else x
+                    for x in tensor_list
+                ]
+                tensor = torch.concat(tensor_list)
+
+            tensor = tensor.detach().contiguous()
+
+            if tensor.dtype == torch.bfloat16:
+                # memoryview() doesn't support PyTorch's BFloat16 dtype
+                tensor = tensor.float()
+
+            assert isinstance(tensor, torch.Tensor)
+            if tensor.is_cuda:
+                # TODO: improve this
+                tensor_cpu = tensor.cpu()
+            else:
+                tensor_cpu = tensor
+
+            mv = memoryview(tensor_cpu.numpy())
+            return data_hash(mv.tobytes())
+
         def hash_feature(f):
             if isinstance(f, list):
-[... 1 removed line not captured in this view ...]
+                if isinstance(f[0], torch.Tensor):
+                    return tensor_hash(f)
+                return data_hash(tuple(flatten_nested_list(f)))
             elif isinstance(f, np.ndarray):
                 arr = np.ascontiguousarray(f)
                 arr_bytes = arr.tobytes()
-                return
-[... 1 removed line not captured in this view ...]
+                return data_hash(arr_bytes)
+            elif isinstance(f, torch.Tensor):
+                return tensor_hash([f])
+            return data_hash(f)

         if self.is_audio():
             self.hash = hash_feature(self.audio_features)
@@ -230,6 +269,9 @@ class MultimodalDataItem:
             self.modality == Modality.VIDEO
         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

+    def is_valid(self) -> bool:
+        return self.is_image() or self.is_video() or self.is_audio()
+
     def validate(self):
         ...
         # TODO
@@ -248,7 +290,7 @@ class MultimodalInputs:
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
-    im_token_id: Optional[
+    im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
@@ -268,11 +310,7 @@ class MultimodalInputs:
         )

         assert isinstance(ret.mm_items, list)
-        ret.mm_items = [
-            item
-            for item in ret.mm_items
-            if item.is_audio() or item.is_image() or item.is_video()
-        ]
+        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

         assert len(ret.mm_items) != 0

@@ -284,7 +322,6 @@ class MultimodalInputs:
             item.set_pad_value()

         optional_args = [
-            "modalities",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -307,8 +344,8 @@ class MultimodalInputs:
         """ """
         return any(item.is_audio() for item in self.mm_items)

-    def
-        return
+    def contains_mm_input(self) -> bool:
+        return any(True for item in self.mm_items if item.is_valid())

     def merge(self, other: MultimodalInputs):
         """
@@ -322,10 +359,8 @@ class MultimodalInputs:

         # args needed to be merged
         optional_args = [
-            "
-            "image_offsets",
+            "mm_items",
             "image_pad_len",
-            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
@@ -354,6 +389,8 @@ class Req:
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
+        bootstrap_host: Optional[str] = None,
+        bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -438,6 +475,10 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False

+        # Latency Breakdown
+        self.queue_time_start = None
+        self.queue_time_end = None
+
         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
@@ -483,9 +524,9 @@ class Req:
         self.lora_path = lora_path

         # For disaggregation
-        self.bootstrap_host: str =
-        self.bootstrap_room: Optional[int] =
-        self.disagg_kv_sender: Optional[
+        self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_room: Optional[int] = bootstrap_room
+        self.disagg_kv_sender: Optional[BaseKVSender] = None

         # used for warmup because we don't have a pair yet when init
         self.skip_kv_transfer: bool = False