sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -23,13 +23,16 @@ import psutil
 import setproctitle
 import zmq

+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
 from sglang.utils import get_exception_traceback

@@ -174,6 +177,10 @@ class DataParallelController:
         if not server_args.enable_dp_attention:
             logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")

+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
@@ -208,7 +215,8 @@ class DataParallelController:
                 target=run_scheduler_process,
                 args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)

@@ -220,9 +228,14 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

-    def round_robin_scheduler(self, req):
-        self.workers[self.round_robin_counter].send_pyobj(req)
-        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
+    def round_robin_scheduler(self, req: Req):
+        if self.server_args.disaggregation_mode == "null":
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
+        else:
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
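Note: the `round_robin_scheduler` change above keeps plain round-robin dispatch for regular serving, but in disaggregated prefill/decode mode it pins each request to the worker selected by `bootstrap_room % len(workers)`, so both sides of a disaggregated pair resolve to the same rank. A minimal standalone sketch of that routing rule (toy class and names, not sglang's actual controller):

```python
from typing import List, Optional


class ToyDispatcher:
    """Illustrative only: mirrors the routing rule in round_robin_scheduler."""

    def __init__(self, num_workers: int, disaggregation_mode: str = "null") -> None:
        self.workers: List[int] = list(range(num_workers))  # stand-ins for worker sockets
        self.round_robin_counter = 0
        self.disaggregation_mode = disaggregation_mode

    def pick_worker(self, bootstrap_room: Optional[int] = None) -> int:
        if self.disaggregation_mode == "null":
            # Regular serving: rotate through workers.
            worker = self.workers[self.round_robin_counter]
            self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
        else:
            # Disaggregated mode: bootstrap_room deterministically selects the worker.
            worker = self.workers[bootstrap_room % len(self.workers)]
        return worker


assert ToyDispatcher(4, "decode").pick_worker(bootstrap_room=10) == 2  # 10 % 4
```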
sglang/srt/managers/io_struct.py
CHANGED
@@ -96,8 +96,9 @@ class GenerateReqInput:
     return_hidden_states: bool = False

     # For disaggregated inference
-    bootstrap_host: Optional[str] = None
-    bootstrap_room: Optional[int] = None
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_port: Optional[Union[List[int], int]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None

     def normalize_batch_and_arguments(self):
         """
@@ -397,6 +398,15 @@ class GenerateReqInput:
                 else None
             ),
             return_hidden_states=self.return_hidden_states,
+            bootstrap_host=(
+                self.bootstrap_host[i] if self.bootstrap_host is not None else None
+            ),
+            bootstrap_port=(
+                self.bootstrap_port[i] if self.bootstrap_port is not None else None
+            ),
+            bootstrap_room=(
+                self.bootstrap_room[i] if self.bootstrap_room is not None else None
+            ),
         )


@@ -441,6 +451,7 @@ class TokenizedGenerateReqInput:

     # For disaggregated inference
     bootstrap_host: Optional[str] = None
+    bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None


@@ -457,6 +468,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
@@ -665,10 +678,15 @@ class BatchEmbeddingOut:


 @dataclass
-class FlushCacheReq:
+class FlushCacheReqInput:
     pass


+@dataclass
+class FlushCacheReqOutput:
+    success: bool
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
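Note: `GenerateReqInput` now accepts `bootstrap_host` / `bootstrap_port` / `bootstrap_room` either as scalars or as per-item lists for batched requests; when a batch is split into individual requests, element `i` of each list is forwarded (the `self.bootstrap_*[i]` slicing in the hunk above). A hedged sketch of that per-item selection (helper name is illustrative, and it assumes list-valued fields for batched input, as the original code does):

```python
from typing import List, Optional, Union


def pick_for_item(field: Optional[Union[List[int], int]], i: int) -> Optional[int]:
    # Mirrors `self.bootstrap_room[i] if self.bootstrap_room is not None else None`.
    return field[i] if field is not None else None


rooms = [101, 102, 103]
assert pick_for_item(rooms, 1) == 102
assert pick_for_item(None, 1) is None
```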
sglang/srt/managers/mm_utils.py
CHANGED
@@ -10,12 +10,13 @@ import torch
 from torch import nn

 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import print_warning_once
+from sglang.srt.utils import flatten_nested_list, print_warning_once

 logger = logging.getLogger(__name__)

@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids


-class
+class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be represented as repetitions of a single token
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """

-    def __init__(self,
-        self.
+    def __init__(self, token_ids: List[int]) -> None:
+        self.token_ids = token_ids

-    def pad_input_tokens(
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
         """
-
+        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
+        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
-
+        if not pad_values:
+            # No multimodal items, return original input_ids
+            return input_ids
+        if not input_ids:
+            return []

         input_ids_tensor = torch.tensor(input_ids)
-
+        device = input_ids_tensor.device
+        token_ids_tensor = torch.tensor(self.token_ids, device=device)
+        mask = torch.isin(input_ids_tensor, token_ids_tensor)

-
-
-
-
+        if not mask.any():
+            # No tokens match token_ids, return original input_ids
+            return input_ids
+
+        # Find contiguous regions
+        padded_mask = torch.cat(
+            (
+                torch.tensor([False], device=device),
+                mask,
+                torch.tensor([False], device=device),
+            )
+        )
+        # Find indices where the mask value changes
+        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
+
+        # Start indices are where False changes to True
+        starts = diff_indices[::2]
+        # End indices are where True changes to False (exclusive index)
+        ends = diff_indices[1::2]
+
+        # Check if the number of regions matches the number of pad values
+        if len(starts) != len(pad_values):
+            # Maybe log a warning here?
+            num_regions = len(starts)
+            num_pad_values = len(pad_values)
+            if num_regions > 0 and num_pad_values > 0:
+                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
+                    :num_regions
+                ]
+            else:  # If no regions or no pad_values, this loop won't run anyway.
+                pad_values = []  # Ensure pad_values is empty if starts is empty
+
+        # Create a copy to modify
+        output_ids_tensor = input_ids_tensor.clone()
+
+        # Replace tokens in each region with the corresponding pad value
+        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
+        for i in range(min(len(starts), len(pad_values))):
+            start_idx = starts[i]
+            end_idx = ends[i]
+            pad_value = pad_values[i]
+            if pad_value is not None:  # Ensure pad_value is not None before assignment
+                output_ids_tensor[start_idx:end_idx] = pad_value
+            else:
+                logger.warning(f"Skipping region {i} due to None pad_value.")

-
-        return input_ids_tensor.tolist()
+        return output_ids_tensor.tolist()


 def get_embedding_and_mask(
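Note: the new `pad_input_tokens` above finds runs of multimodal placeholder tokens by padding a boolean mask with False on both sides and taking the indices where consecutive values differ; even positions mark run starts and odd positions mark exclusive run ends. A small self-contained illustration of that trick (function name and toy token ids are made up):

```python
import torch


def find_token_regions(input_ids, token_ids):
    """Return (start, end) pairs of contiguous runs of token_ids, end exclusive."""
    ids = torch.tensor(input_ids)
    mask = torch.isin(ids, torch.tensor(token_ids))
    padded = torch.cat((torch.tensor([False]), mask, torch.tensor([False])))
    change_points = torch.where(padded[1:] != padded[:-1])[0]
    starts, ends = change_points[::2], change_points[1::2]
    return list(zip(starts.tolist(), ends.tolist()))


# Two runs of the placeholder token 7 in a toy sequence:
assert find_token_regions([1, 7, 7, 2, 3, 7, 4], token_ids=[7]) == [(1, 3), (5, 6)]
```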
@@ -150,7 +200,6 @@ def get_embedding_and_mask(
     ).unsqueeze(-1)

     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
-
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -190,13 +239,13 @@ def embed_mm_inputs(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-
+    placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
     Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations

     Args:
-
+        placeholder_tokens: denoting the token of multimodal data in input_ids.
             If none, the pad_values of multimodal items are used

     Returns:
@@ -208,9 +257,17 @@ def embed_mm_inputs(

     # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
     # we assume that multimodal data are represented with its pad_values in input_ids
-
-
-
+    # See `pad_input_ids` for more detail
+
+    # if placeholder_tokens is specified
+    if placeholder_tokens is not None:
+        placeholder_token_ids = flatten_nested_list(
+            [placeholder_token for placeholder_token in placeholder_tokens.values()]
+        )
+    else:
+        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
+
+    assert isinstance(placeholder_token_ids[0], int)

     placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)

@@ -233,7 +290,7 @@ def embed_mm_inputs(
     using_all_items = False
     if len(appearing_items) == 0:
         # This happens mostly when arg placeholder_token_ids is passed
-        logger.
+        logger.warning(
             "No multimodal data item's pad value exist in placeholder ids. Using all items"
         )
         using_all_items = True
@@ -253,7 +310,8 @@ def embed_mm_inputs(
             data_embedding_func=image_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-
+                # use the specified modality token to identify the location to embed
+                placeholder_tokens[Modality.IMAGE]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],
@@ -275,7 +333,7 @@ def embed_mm_inputs(
             data_embedding_func=audio_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-
+                placeholder_tokens[Modality.AUDIO]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],
@@ -296,7 +354,7 @@ def embed_mm_inputs(
     input_ids.clamp_(min=0, max=vocab_size - 1)
     inputs_embeds = input_embedding(input_ids)

-    # 4.
+    # 4. Scatter embeddings into input embedding
     for embedding, mask in zip(embeddings, masks):
         mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
         inputs_embeds = inputs_embeds.masked_scatter(
@@ -316,7 +374,7 @@ def general_mm_embed_routine(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-
+    placeholder_tokens: dict[Modality, List[int]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -328,7 +386,6 @@ def general_mm_embed_routine(
         audio_data_embedding_func : the function returning the image embedding

     Returns:
-        inputs_embedding
         forwarded hidden states

     """
@@ -346,9 +403,9 @@ def general_mm_embed_routine(
             input_embedding=embed_tokens,
             image_data_embedding_func=image_data_embedding_func,
             audio_data_embedding_func=audio_data_embedding_func,
-
+            placeholder_tokens=placeholder_tokens,
         )
-        # once used, mm_inputs is useless
+        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
         # just being defensive here
         forward_batch.mm_inputs = None
     else:
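Note: step 4 of `embed_mm_inputs` scatters the precomputed multimodal embeddings into the text embeddings wherever the placeholder mask is set. A toy illustration of the masked_scatter pattern (shapes and values are made up):

```python
import torch

text_embeds = torch.zeros(6, 4)                       # [seq_len, hidden] from the embedding table
mm_embeds = torch.ones(3, 4)                          # embeddings for 3 multimodal tokens
mask = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool)
mask = mask.unsqueeze(-1).expand_as(text_embeds)      # mark the placeholder positions
merged = text_embeds.masked_scatter(mask, mm_embeds)  # fill masked slots from mm_embeds in order
assert merged[1].eq(1).all() and merged[3].eq(0).all()
```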
sglang/srt/managers/multimodal_processors/base_processor.py
CHANGED
@@ -8,6 +8,7 @@ from typing import List, Optional

 import numpy as np
 import PIL
+from PIL import Image
 from transformers import BaseImageProcessorFast

 from sglang.srt.managers.schedule_batch import Modality
@@ -92,7 +93,12 @@ class BaseMultimodalProcessor(ABC):

     @abstractmethod
     async def process_mm_data_async(
-        self,
+        self,
+        image_data,
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
     ):
         pass

@@ -104,6 +110,8 @@ class BaseMultimodalProcessor(ABC):
         from decord import VideoReader, cpu

         # Before processing inputs
+        if not image_data or len(image_data) == 0:
+            return []
         estimated_frames_list = []
         for image in image_data:
             if isinstance(image, str) and image.startswith("video:"):
@@ -215,6 +223,9 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images

         """
+
+        if image_data is None:
+            image_data = []
         if isinstance(multimodal_tokens.image_token, int):
             multimodal_tokens.image_token = (
                 self._processor.tokenizer.convert_ids_to_tokens(
@@ -229,6 +240,8 @@ class BaseMultimodalProcessor(ABC):
             prompt = self._processor.tokenizer.decode(prompt)
         else:
             prompt = prompt
+
+        assert isinstance(prompt, str)
         if return_text:
             import re

sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py
CHANGED
@@ -16,6 +16,7 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+from typing import List, Union

 import torch

@@ -35,7 +36,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         self.IMAGE_TOKEN = "<image>"

     async def process_mm_data_async(
-        self,
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs
     ):
         if not image_data:
             return None
@@ -45,7 +52,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-
+            input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
sglang/srt/managers/multimodal_processors/gemma3.py
CHANGED
@@ -1,7 +1,5 @@
 from typing import List, Union

-from transformers.utils import logging
-
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -13,7 +11,6 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
 # will be removed in the future
-logger = logging.get_logger(__name__)


 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
@@ -28,7 +25,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
@@ -41,7 +38,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
sglang/srt/managers/multimodal_processors/janus_pro.py
CHANGED
@@ -17,7 +17,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-
+        input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
@@ -31,7 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         processor = self._processor

         base_out = self.load_mm_data(
-            prompt=
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=processor.image_token
sglang/srt/managers/multimodal_processors/minicpm.py
CHANGED
@@ -51,9 +51,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-
+        input_text,
         request_obj,
         max_req_input_len,
+        **kwargs,
     ):
         audio_data = request_obj.audio_data
         if not image_data and not audio_data:
@@ -64,7 +65,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            prompt=
+            prompt=input_text,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
@@ -96,7 +97,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         audio_start_id = tokenizer.audio_start_id
         audio_end_id = tokenizer.audio_end_id

-        im_token_id = tokenizer.
+        im_token_id = tokenizer.unk_id
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]

sglang/srt/managers/multimodal_processors/qwen_vl.py
CHANGED
@@ -5,6 +5,7 @@ from typing import List, Union
 import torch
 from PIL import Image

+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -27,6 +28,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
         self.image_token_id = hf_config.image_token_id
         self.video_token_id = hf_config.video_token_id
+        self.vision_start_token_id = hf_config.vision_start_token_id
+        self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
         self.IMAGE_FACTOR = 28
         self.MIN_PIXELS = 4 * 28 * 28
@@ -36,20 +39,18 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
         if isinstance(image_data, str):
             image_data = [image_data]

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
@@ -116,29 +117,53 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)

-
-
+        if base_output.images:
+            resize_tasks = [resize_image_async(image) for image in base_output.images]
+            base_output.images = await asyncio.gather(*resize_tasks)

         ret = self.process_mm_data(
             input_text=base_output.input_text,
-            images=
+            images=base_output.images,
         )

-
-
-
-
+        items = []
+
+        input_ids = ret["input_ids"].flatten().tolist()
+        if "pixel_values" in ret:
+            items += [
                 MultimodalDataItem(
                     pixel_values=ret["pixel_values"],
-                    image_grid_thws=
+                    image_grid_thws=torch.concat([ret["image_grid_thw"]]),
                     # TODO
                     video_grid_thws=None,
                     second_per_grid_ts=ret.get("second_per_grid_ts", None),
                     modality=Modality.IMAGE,
                 )
-            ]
+            ]
+
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
+            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            model_type=self.hf_config.model_type,
+            tokens_per_second=getattr(
+                self.hf_config.vision_config, "tokens_per_second", None
+            ),
+            input_ids=torch.tensor(input_ids).unsqueeze(0),
+            image_grid_thw=ret.get("image_grid_thw", None),
+            video_grid_thw=ret.get("video_grid_thw", None),
+            second_per_grid_ts=ret.get("second_per_grid_ts", None),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
+        return {
+            "input_ids": input_ids,
+            "mm_items": items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
         }