sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/base_processor.py

@@ -4,14 +4,14 @@ import dataclasses
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional
 
 import numpy as np
 import PIL
-from
-from PIL import Image
+from transformers import BaseImageProcessorFast
 
-from sglang.srt.
+from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.utils import encode_video, load_audio, load_image
 
 
 @dataclasses.dataclass

@@ -78,6 +78,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audios"] = audios
 
         processor = self._processor
+        if hasattr(processor, "image_processor") and isinstance(
+            processor.image_processor, BaseImageProcessorFast
+        ):
+            kwargs["device"] = "cuda"
         result = processor.__call__(
             text=[input_text],
             padding=True,

@@ -96,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
         """
         estimate the total frame count from all visual input
         """
+        # Lazy import because decord is not available on some arm platforms.
+        from decord import VideoReader, cpu
+
         # Before processing inputs
         estimated_frames_list = []
         for image in image_data:

@@ -111,6 +118,84 @@ class BaseMultimodalProcessor(ABC):
 
         return estimated_frames_list
 
+    @staticmethod
+    def _load_single_item(
+        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+    ):
+        """Static method that can be pickled for multiprocessing"""
+        try:
+            if is_audio:
+                return load_audio(data)
+            elif is_video:
+                path = data[len("video:") :]
+                return encode_video(path, frame_count_limit)
+            else:
+                img, _ = load_image(data)
+                return img.convert("RGB") if discard_alpha_channel else img
+        except Exception as e:
+            raise RuntimeError(f"Error while loading data {data}: {e}")
+
+    def submit_data_loading_tasks(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        discard_alpha_channel: bool = True,
+    ):
+        """
+        load multimodal data parallelly
+        """
+
+        # TODO(mick): load from server_args, env, or sampling_params
+        MAX_NUM_FRAMES = 30
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+        assert len(image_data) == len(estimated_frames_list)
+        # Submit all tasks
+        futures = []
+        task_info = []
+        image_index, audio_index = 0, 0
+
+        for text_part in text_parts:
+            if text_part == multimodal_tokens.image_token:
+                data = image_data[image_index]
+                is_video = isinstance(data, str) and data.startswith("video:")
+                estimated_frames = estimated_frames_list[image_index]
+                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        is_video,
+                        False,
+                        frame_count_limit,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.IMAGE, data, frame_count_limit))
+                image_index += 1
+            elif text_part == multimodal_tokens.audio_token:
+                data = audio_data[audio_index]
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        False,
+                        True,
+                        None,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.AUDIO, data, None))
+                audio_index += 1
+
+        return futures, task_info
+
     def load_mm_data(
         self,
         prompt: str,

@@ -155,84 +240,37 @@ class BaseMultimodalProcessor(ABC):
         # split text into list of normal text and special tokens
         text_parts = re.split(pattern, prompt)
 
-
-
-
-
-
-
-
-
-
-
-            image_index, audio_index = 0, 0
-            hashes, image_sizes, images, audios = [], [], [], []
+        futures, task_info = self.submit_data_loading_tasks(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            image_data=image_data,
+            audio_data=audio_data,
+            discard_alpha_channel=discard_alpha_channel,
+        )
+        # Process results
+        image_sizes, images, audios = [], [], []
         new_text = ""
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        frames = []
-                    else:
-                        image_file = image_data[image_index]
-                        if isinstance(image_file, str) and image_file.startswith(
-                            "video:"
-                        ):
-                            # video
-                            path = image_file[len("video:") :]
-                            frames = encode_video(
-                                path, frame_count_limit=frames_to_process
-                            )
-                        else:
-                            # image
-                            raw_image, _size = load_image(image_file)
-                            if discard_alpha_channel:
-                                raw_image = raw_image.convert("RGB")
-                            frames = [raw_image]
-                    if len(frames) == 0:
-                        continue
-
-                    image_sizes += frames[0].size * len(frames)
-
-                    # Generate a hashable value for the image file
-                    if isinstance(image_file, Image.Image):
-                        # For PIL.Image objects, use the ID as a hashable value
-                        hash_value = hash(id(image_file))
-                    else:
-                        # For other types (strings, etc.), use the regular hash
-                        hash_value = hash(image_file)
-
-                    hashes += [hash_value] * len(frames)
-                    images += frames
-                    image_index += 1
-                    if frames_to_process != 0:
+        task_ptr = 0
+
+        for text_part in text_parts:
+            if text_part in multimodal_tokens.collect():
+                task_type, data, frame_limit = task_info[task_ptr]
+                result = futures[task_ptr].result()
+                task_ptr += 1
+
+                if task_type == Modality.IMAGE:
+                    frames = [result] if not isinstance(result, list) else result
+                    if frames:
+                        image_sizes += frames[0].size * len(frames)
+                        images += frames
                         new_text += multimodal_tokens.image_token * len(frames)
-
-
-
-                    audio_file = audio_data[audio_index]
-                    audio = load_audio(audio_file)
-                    hashes += [hash(audio_file)]
-                    audios += [audio]
-                    audio_index += 1
+                elif task_type == Modality.AUDIO:
+                    # audio
+                    audios.append(result)
                     new_text += multimodal_tokens.audio_token
-
-
-
-                new_text += text_part
-
-        except Exception as e:
-            logger.error(f"An exception occurred while loading images: {e}")
-            raise RuntimeError(f"An exception occurred while loading images: {e}")
+                # TODO: handle video
+            else:
+                new_text += text_part
 
         out = BaseMultiModalProcessorOutput(
             images=images,
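
Note: the hunks above replace the old sequential image/audio loading loop with a submit-then-collect pattern — every item is dispatched to `io_executor` up front, and results are consumed later in prompt order. A minimal, self-contained sketch of that pattern is below; the loader `load_one`, the executor, and the sample file names are illustrative stand-ins, not sglang APIs.

```python
from concurrent.futures import ThreadPoolExecutor

def load_one(item: str) -> str:
    # Stand-in for the real per-item loader (image/video/audio decode).
    return item.upper()

def load_all(items):
    with ThreadPoolExecutor(max_workers=4) as pool:
        # Submit every task first so decoding overlaps across items...
        futures = [pool.submit(load_one, it) for it in items]
        task_info = [("ITEM", it) for it in items]
        # ...then collect results in submission order, which keeps the
        # outputs aligned with the prompt's placeholder order.
        return [(info, fut.result()) for info, fut in zip(task_info, futures)]

if __name__ == "__main__":
    print(load_all(["cat.png", "dog.png", "clip.mp4"]))
```

Submitting all items before blocking on any result lets decoding overlap while the original placeholder order is preserved.
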
sglang/srt/managers/multimodal_processors/janus_pro.py

@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         base_out = self.load_mm_data(
             prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=processor.image_token
+            ),
             max_req_input_len=max_req_input_len,
         )
 
sglang/srt/managers/multimodal_processors/mllama4.py

@@ -1,10 +1,8 @@
-from typing import List,
+from typing import List, Union
 
 import torch
-from PIL import Image
-from transformers import Llama4Processor
 from transformers.image_utils import SizeDict
-from transformers.models.llama4.
+from transformers.models.llama4.image_processing_llama4_fast import (
     find_supported_resolutions,
     get_best_fit,
 )

@@ -15,7 +13,6 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
-from sglang.srt.utils import load_image
 
 
 class Mllama4ImageProcessor(BaseMultimodalProcessor):

@@ -25,6 +22,9 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
+        self.boi_token_index = hf_config.boi_token_index
+        self.eoi_token_index = hf_config.eoi_token_index
+        self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token
         )

@@ -54,19 +54,16 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )
 
         # Process the images using the processor
-        processor =
-            self.server_args.model_path, **kwargs
-        )
+        processor = self._processor
 
         # Process the prompt and images
-
-
+        processor_output = self.process_mm_data(
+            input_text=processed_data.input_text,
             images=processed_data.images,
-            return_tensors="pt",
         )
 
         # Handle image resolutions and aspect ratios
-        if "pixel_values" in
+        if "pixel_values" in processor_output:
             image_processor = processor.image_processor
             tokenizer = self._processor.tokenizer
 

@@ -100,8 +97,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             ]
 
             # Add to image_inputs
-
-
+            processor_output["aspect_ratios"] = aspect_ratios
+            processor_output["patches_per_image"] = torch.tensor(patches_per_image)
 
             # Process embed_is_patch
             vocab = tokenizer.get_vocab()

@@ -109,7 +106,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             image_end_id = vocab.get(processor.end_of_img_token, -1)
 
             if patch_id != -1 and image_end_id != -1:
-                input_ids =
+                input_ids = processor_output["input_ids"].view(-1)
 
                 # Remove BOS token if present
                 if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:

@@ -129,33 +126,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
                 for per_image_input_ids in split_input_ids:
                     embed_is_patch.append(per_image_input_ids == patch_id)
 
-
+                processor_output["embed_is_patch"] = embed_is_patch
 
         # Convert to the format expected by SGLang
-
+        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
+
+        processor_output["im_start_id"] = self.boi_token_index
+        processor_output["im_end_id"] = self.eoi_token_index
+        processor_output["im_token_id"] = self.image_token_index
 
         # Add metadata for image processing
-
+        processor_output["mm_items"] = [
             MultimodalDataItem(
-                pixel_values=
+                pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                # Add additional metadata needed for Llama4 vision processing
-                embed_is_patch=image_inputs.get("embed_is_patch", None),
-                aspect_ratios=image_inputs.get("aspect_ratios", None),
-                patches_per_image=image_inputs.get("patches_per_image", None),
             )
         ]
 
-        return
-
-    def get_patch_per_chunk(self):
-        """Calculate patches per chunk based on vision config"""
-        image_size = self.vision_config.image_size
-        patch_size = self.vision_config.patch_size
-
-        assert (
-            image_size % patch_size == 0
-        ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
-
-        ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
-        return (image_size // patch_size) ** 2 // ds_ratio
+        return processor_output
sglang/srt/managers/schedule_batch.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import hashlib
 from enum import Enum, auto
 
 # Copyright 2023-2024 SGLang Team

@@ -44,7 +45,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
-from sglang.srt.disaggregation.
+from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache

@@ -66,7 +67,6 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,

@@ -76,12 +76,12 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }
 
 logger = logging.getLogger(__name__)

@@ -157,7 +157,7 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or
+    A single multimodal data, from a single image/video/audio or others
     """
 
     modality: Modality

@@ -195,17 +195,54 @@ class MultimodalDataItem:
 
     def set_pad_value(self):
         """
-        Set the pad value after first
+        Set the pad value after first hashing the data
        """
 
+        def data_hash(data) -> int:
+            hash_bytes = hashlib.sha256(data).digest()[:8]
+            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+        def tensor_hash(tensor_list) -> int:
+            """
+            hash a tensor or a tensor list
+            """
+            tensor = tensor_list
+            if isinstance(tensor_list, list):
+                tensor_list = flatten_nested_list(tensor_list)
+                tensor_list = [
+                    x.flatten() if isinstance(x, torch.Tensor) else x
+                    for x in tensor_list
+                ]
+                tensor = torch.concat(tensor_list)
+
+            tensor = tensor.detach().contiguous()
+
+            if tensor.dtype == torch.bfloat16:
+                # memoryview() doesn't support PyTorch's BFloat16 dtype
+                tensor = tensor.float()
+
+            assert isinstance(tensor, torch.Tensor)
+            if tensor.is_cuda:
+                # TODO: improve this
+                tensor_cpu = tensor.cpu()
+            else:
+                tensor_cpu = tensor
+
+            mv = memoryview(tensor_cpu.numpy())
+            return data_hash(mv.tobytes())
+
         def hash_feature(f):
             if isinstance(f, list):
-
+                if isinstance(f[0], torch.Tensor):
+                    return tensor_hash(f)
+                return data_hash(tuple(flatten_nested_list(f)))
             elif isinstance(f, np.ndarray):
                 arr = np.ascontiguousarray(f)
                 arr_bytes = arr.tobytes()
-                return
-
+                return data_hash(arr_bytes)
+            elif isinstance(f, torch.Tensor):
+                return tensor_hash([f])
+            return data_hash(f)
 
         if self.is_audio():
             self.hash = hash_feature(self.audio_features)

@@ -230,6 +267,9 @@ class MultimodalDataItem:
             self.modality == Modality.VIDEO
         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
 
+    def is_valid(self) -> bool:
+        return self.is_image() or self.is_video() or self.is_audio()
+
     def validate(self):
         ...
         # TODO

@@ -248,7 +288,7 @@ class MultimodalInputs:
     mrope_position_delta: Optional[torch.Tensor] = None
 
     # image
-    im_token_id: Optional[
+    im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None

@@ -268,11 +308,7 @@ class MultimodalInputs:
         )
 
         assert isinstance(ret.mm_items, list)
-        ret.mm_items = [
-            item
-            for item in ret.mm_items
-            if item.is_audio() or item.is_image() or item.is_video()
-        ]
+        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
 
         assert len(ret.mm_items) != 0
 

@@ -284,7 +320,6 @@ class MultimodalInputs:
             item.set_pad_value()
 
         optional_args = [
-            "modalities",
             "im_token_id",
             "im_start_id",
             "im_end_id",

@@ -307,8 +342,8 @@ class MultimodalInputs:
         """ """
         return any(item.is_audio() for item in self.mm_items)
 
-    def
-        return
+    def contains_mm_input(self) -> bool:
+        return any(True for item in self.mm_items if item.is_valid())
 
     def merge(self, other: MultimodalInputs):
         """

@@ -322,10 +357,8 @@ class MultimodalInputs:
 
         # args needed to be merged
         optional_args = [
-            "
-            "image_offsets",
+            "mm_items",
             "image_pad_len",
-            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)

@@ -354,6 +387,8 @@ class Req:
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
+        bootstrap_host: Optional[str] = None,
+        bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid

@@ -438,6 +473,10 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False
 
+        # Latency Breakdown
+        self.queue_time_start = None
+        self.queue_time_end = None
+
         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None

@@ -483,9 +522,9 @@ class Req:
         self.lora_path = lora_path
 
         # For disaggregation
-        self.bootstrap_host: str =
-        self.bootstrap_room: Optional[int] =
-        self.disagg_kv_sender: Optional[
+        self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_room: Optional[int] = bootstrap_room
+        self.disagg_kv_sender: Optional[BaseKVSender] = None
 
         # used for warmup because we don't have a pair yet when init
         self.skip_kv_transfer: bool = False

@@ -1440,7 +1479,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 global_server_args_dict["use_mla_backend"]
                 and global_server_args_dict["attention_backend"] == "flashinfer"
             )
-            or global_server_args_dict["
+            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
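
For reference, the new `set_pad_value` logic above hashes multimodal features by content rather than by Python object identity: tensors are flattened, moved to CPU, and their raw bytes are SHA-256 hashed, keeping the first 8 bytes as an integer. A standalone sketch of that scheme is below; it is illustrative only, not the sglang implementation itself.

```python
import hashlib
import torch

def data_hash(data: bytes) -> int:
    # Keep the first 8 bytes of the SHA-256 digest as an unsigned integer.
    return int.from_bytes(hashlib.sha256(data).digest()[:8], "big", signed=False)

def tensor_hash(t: torch.Tensor) -> int:
    t = t.detach().contiguous()
    if t.dtype == torch.bfloat16:
        t = t.float()  # memoryview() cannot handle bfloat16 storage
    if t.is_cuda:
        t = t.cpu()
    return data_hash(memoryview(t.numpy()).tobytes())

if __name__ == "__main__":
    x = torch.arange(6, dtype=torch.float32)
    print(tensor_hash(x))  # deterministic for identical tensor contents
```

Unlike the `hash(id(image_file))` used in the removed code, a content hash is stable across processes and across re-loads of the same data.
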