sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py
CHANGED
```diff
@@ -18,7 +18,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -32,6 +32,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
@@ -63,12 +66,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -78,11 +77,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty
 
 Qwen3MoeConfig = None
 
 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
 
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -117,6 +117,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         self.gate = ReplicatedLinear(
@@ -220,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -231,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )
         return final_hidden_states
 
@@ -284,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_mlp_input"),
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
 
@@ -316,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-
+            forward_batch=state.forward_batch,
         )
 
     def op_combine_a(self, state):
@@ -325,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
 
@@ -354,6 +363,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -423,15 +433,27 @@ class Qwen3MoeAttention(nn.Module):
 
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.alt_stream = alt_stream
 
     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
 
```
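For context on the `_apply_qk_norm` change above: when an alternate stream is available and the CUDA graph runner is in capture mode, the k-norm is issued on a second stream so it can overlap with the q-norm on the current stream. A minimal, self-contained sketch of the same `wait_stream` pattern, assuming a plain functional RMS norm and made-up shapes rather than sglang's `RMSNorm` module:

```python
import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMS norm used as a stand-in for sglang's RMSNorm layer.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def overlapped_qk_norm(q, k, head_dim, alt_stream):
    # Same ordering as _apply_qk_norm: the alternate stream waits for the
    # producer of q/k, k-norm runs on the alternate stream while q-norm runs
    # on the current stream, and the current stream rejoins before k is read.
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)
    q_by_head = rms_norm(q.reshape(-1, head_dim))
    with torch.cuda.stream(alt_stream):
        k_by_head = rms_norm(k.reshape(-1, head_dim))
    current.wait_stream(alt_stream)
    return q_by_head.view(q.shape), k_by_head.view(k.shape)


if torch.cuda.is_available():
    head_dim, num_heads, num_tokens = 128, 8, 4  # made-up shapes
    q = torch.randn(num_tokens, num_heads * head_dim, device="cuda")
    k = torch.randn(num_tokens, num_heads * head_dim, device="cuda")
    q_out, k_out = overlapped_qk_norm(q, k, head_dim, torch.cuda.Stream())
```

The two `wait_stream` calls are what keep this correct: the alternate stream first waits for the work that produced `q` and `k`, and the current stream waits for the alternate stream before the normalized `k` is consumed.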
```diff
@@ -491,6 +513,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -516,6 +539,7 @@
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
 
         self.layer_id = layer_id
@@ -623,9 +647,7 @@
 
     def op_mlp(self, state):
         hidden_states = state.pop("hidden_states_mlp_input")
-        state.hidden_states_mlp_output = self.mlp(
-            hidden_states, state.forward_batch.forward_mode
-        )
+        state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)
 
     def op_comm_postprocess_layer(self, state):
         hidden_states, residual = self.layer_communicator.postprocess_layer(
@@ -659,11 +681,13 @@ class Qwen3MoeModel(Qwen2MoeModel):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3MoeDecoderLayer,
+            alt_stream=alt_stream,
         )
 
 
@@ -691,6 +715,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -709,9 +734,13 @@
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -724,6 +753,24 @@
     def end_layer(self):
         return self.model.end_layer
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
```
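The new `set_eagle3_layers_to_capture` hook decides which decoder layers feed auxiliary hidden states to an EAGLE3 draft model. A small helper that restates just that selection logic; the 48-layer figure below is only an illustration, not a real Qwen3-MoE configuration:

```python
from typing import List, Optional


def eagle3_layers_to_capture(
    num_hidden_layers: int, layer_ids: Optional[List[int]] = None
) -> List[int]:
    # Default mirrors Qwen3MoeForCausalLM.set_eagle3_layers_to_capture:
    # one early, one middle, and one late decoder layer are captured.
    if layer_ids is None:
        return [2, num_hidden_layers // 2, num_hidden_layers - 3]
    # Explicit layer ids are shifted by one, as in the diff above.
    return [layer_id + 1 for layer_id in layer_ids]


# Hypothetical 48-layer model: layers 2, 24 and 45 are captured by default.
assert eagle3_layers_to_capture(48) == [2, 24, 45]
assert eagle3_layers_to_capture(48, [1, 23, 44]) == [2, 24, 45]
```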
sglang/srt/models/vila.py
CHANGED
```diff
@@ -270,15 +270,10 @@ class VILAForConditionalGeneration(nn.Module):
         weight_loader(param, loaded_weight)
 
     def pad_input_ids(
-        self,
-        input_ids: List[int],
-        image_inputs: MultimodalInputs,
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
-
-        )
-
-        return pattern.pad_input_tokens(input_ids, image_inputs)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     ##### BEGIN COPY modeling_vila.py #####
 
```
sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py
CHANGED
```diff
@@ -17,14 +17,6 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import encode_video, load_audio, load_image
 
 
-class MultimodalInputFormat(Enum):
-    """Enum for different multimodal input formats."""
-
-    RAW_IMAGES = "raw_images"
-    PRECOMPUTED_FEATURES = "precomputed_features"
-    PIXEL_VALUES = "pixel_values"
-
-
 @dataclasses.dataclass
 class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
@@ -97,6 +89,7 @@ class BaseMultimodalProcessor(ABC):
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
+
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
 
@@ -108,18 +101,45 @@ class BaseMultimodalProcessor(ABC):
             max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
         )
 
+        # Mapping from attribute names to modality types
+        self.ATTR_NAME_TO_MODALITY = {
+            # Image-related attributes
+            "pixel_values": Modality.IMAGE,
+            "image_sizes": Modality.IMAGE,
+            "image_grid_thw": Modality.IMAGE,
+            "image_emb_mask": Modality.IMAGE,
+            "image_spatial_crop": Modality.IMAGE,
+            "tgt_size": Modality.IMAGE,
+            "image_grid_hws": Modality.IMAGE,
+            "aspect_ratio_id": Modality.IMAGE,
+            "aspect_ratio_mask": Modality.IMAGE,
+            "second_per_grid_ts": Modality.IMAGE,
+            # Audio-related attributes
+            "audio_features": Modality.AUDIO,
+            "audio_feature_lens": Modality.AUDIO,
+            "input_features": Modality.AUDIO,
+            "input_features_mask": Modality.AUDIO,
+            # Video-related attributes
+            "video_grid_thws": Modality.VIDEO,
+            # Generic attributes that could apply to multiple modalities
+            # "precomputed_features" - handled specially as it can be any modality
+        }
+
     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
     ):
         """
         process multimodal data with transformers AutoProcessor
         """
-        if images
+        if images:
             kwargs["images"] = images
-        if videos
+        if videos:
             kwargs["videos"] = videos
-        if audios
+        if audios:
             kwargs["audios"] = audios
+            if self.__class__.__name__ == "Gemma3nSGLangProcessor":
+                # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
+                kwargs["audio"] = audios
 
         processor = self._processor
         if hasattr(processor, "image_processor") and isinstance(
@@ -142,6 +162,7 @@ class BaseMultimodalProcessor(ABC):
     async def process_mm_data_async(
         self,
         image_data,
+        audio_data,
         input_text,
         request_obj,
         max_req_input_len,
@@ -416,141 +437,137 @@ class BaseMultimodalProcessor(ABC):
                 values[k] = v
         return values
 
+    def collect_mm_items_from_processor_output(
+        self, data_dict: dict
+    ) -> List[MultimodalDataItem]:
+        """Create mm_items directly from processor output."""
+        items = {}  # modality -> MultimodalDataItem
+
+        for attr_name, value in data_dict.items():
+            if attr_name == "input_ids":
+                continue
+
+            # Get modality for this attribute
+            modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
+
+            if not modality and attr_name == "precomputed_features":
+                modality_str = data_dict.get("modality")
+                try:
+                    modality = (
+                        Modality.from_str(modality_str)
+                        if modality_str
+                        else Modality.IMAGE
+                    )
+                except ValueError:
+                    modality = Modality.IMAGE
+
+            if modality:
+                # Create item if needed
+                if modality not in items:
+                    items[modality] = MultimodalDataItem(modality=modality)
+
+                # Set attribute
+                if hasattr(items[modality], attr_name):
+                    setattr(items[modality], attr_name, value)
+
+        return list(items.values())
+
+    def _process_and_collect_mm_items(
+        self, input_text: str, images=None, audios=None, videos=None, **kwargs
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+        """
+        Helper method to process multimodal data and create mm_items in one step.
+
+        Returns:
+            Tuple of (created mm_items, input_ids)
+        """
+        ret = self.process_mm_data(
+            input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
+        )
+
+        input_ids = ret["input_ids"].flatten()
+        collected_items = self.collect_mm_items_from_processor_output(ret)
+
+        return collected_items, input_ids
+
     def process_and_combine_mm_data(
         self, base_output: BaseMultiModalProcessorOutput
-    ) -> Tuple[
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
         """
-        Process multimodal data and return the combined multimodal
-
+        Process multimodal data and return the combined multimodal items and input_ids.
+        Supports mixed modalities (images and audio in the same request).
 
         Returns:
-            Tuple of (
+            Tuple of (list of mm_items, input_ids)
         """
+        # Collect all items and categorize them
+        all_items = (base_output.images or []) + (base_output.audios or [])
 
-
-
-
-                input_text,
+        # Handle text-only case
+        if not all_items:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
                 return_tensors="pt",
                 add_special_tokens=True,
             ).input_ids.flatten()
+            return [], input_ids
+
+        dict_items, raw_images, raw_audios = [], [], []
+        for item in all_items:
+            if isinstance(item, dict):
+                dict_items.append(item)
+            elif isinstance(item, Image.Image):
+                raw_images.append(item)
+            elif isinstance(item, np.ndarray):
+                raw_audios.append(item)
+            else:
+                raise ValueError(f"Unknown multimodal item type: {type(item)}")
 
-
-
-
-                has_image = False
-                has_pixel_values = False
-                has_precomputed_features = False
-
-                for mm_input in mm_inputs:
-                    if isinstance(mm_input, Image.Image):
-                        has_image = True
-                    elif isinstance(mm_input, dict):
-                        if mm_input.get("precomputed_features", None) is not None:
-                            has_precomputed_features = True
-                        elif mm_input.get("pixel_values", None) is not None:
-                            has_pixel_values = True
-                        else:
-                            raise ValueError(
-                                f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
-                            )
-                    else:
-                        raise ValueError(
-                            f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
-                        )
+        # Process items and get input_ids
+        all_collected_items = []
+        input_ids = None
 
-
-
-
-                )
-
-                raise ValueError(
-                    "Unsupported: mixture of multimodal input formats. "
-                    f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                    f"precomputed_features={has_precomputed_features}"
-                )
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
+            )
 
-
-
-
-                    return MultimodalInputFormat.PRECOMPUTED_FEATURES
-                elif has_pixel_values:
-                    return MultimodalInputFormat.PIXEL_VALUES
-                else:
-                    raise ValueError("No valid multimodal input format found")
-            except Exception as e:
-                raise ValueError(f"Failed to categorize inputs: {e}")
-
-        def process_raw_images(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process raw Image.Image objects using transformers processor."""
-            ret = self.process_mm_data(
+        # Handle raw items (need processing)
+        if raw_images or raw_audios:
+            collected_items, input_ids = self._process_and_collect_mm_items(
                 input_text=base_output.input_text,
-                images=
+                images=raw_images,
+                audios=raw_audios,
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def finalize_mm_item(
-            combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
-        ) -> MultimodalDataItem:
-            """Apply common post-processing to the multimodal item."""
-            combined_mm_item.image_offsets = self.get_mm_items_offset(
-                input_ids=input_ids,
-                mm_token_id=self.IM_TOKEN_ID,
-            )
-            return combined_mm_item
-
-        # Main logic
-        mm_inputs = base_output.images
-        if not mm_inputs:
-            # Return text-only case
-            input_ids = tokenize_text(base_output.input_text)
-            return None, input_ids
-
-        # Categorize input formats
-        input_format = categorize_mm_inputs(mm_inputs)
-
-        # Process based on format
-        if input_format == MultimodalInputFormat.RAW_IMAGES:
-            combined_mm_item, input_ids = process_raw_images(base_output)
-        elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
-            combined_mm_item, input_ids = process_precomputed_features(base_output)
-        elif input_format == MultimodalInputFormat.PIXEL_VALUES:
-            combined_mm_item, input_ids = process_pixel_values(base_output)
-        else:
-            raise ValueError(f"Unknown input format: {input_format}")
+            all_collected_items.extend(collected_items)
+
+        # Fallback tokenization if no raw items were processed
+        if input_ids is None:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+
+        # Add offsets to all items
+        for mm_item in all_collected_items:
+            if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                mm_item.image_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.IM_TOKEN_ID,
+                )
+            elif mm_item.modality == Modality.AUDIO:
+                mm_item.audio_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.AUDIO_TOKEN_ID,
                )
+            elif mm_item.modality == Modality.VIDEO:
+                mm_item.video_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.VIDEO_TOKEN_ID,
+                )
+            else:
+                raise ValueError(f"Unknown modality: {mm_item.modality}")
 
-
-        combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
-        return combined_mm_item, input_ids
+        return all_collected_items, input_ids
```
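The rewritten `process_and_combine_mm_data` no longer forces every request into a single input format; instead, processor outputs are grouped into one `MultimodalDataItem` per modality by looking up each output key in `ATTR_NAME_TO_MODALITY`, and offsets are assigned per modality afterwards. A simplified, self-contained sketch of that grouping step; the `Item` dataclass and the trimmed key table are stand-ins for illustration, not sglang classes:

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List


class Modality(Enum):
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"


# Trimmed-down version of the ATTR_NAME_TO_MODALITY table in base_processor.py.
ATTR_NAME_TO_MODALITY = {
    "pixel_values": Modality.IMAGE,
    "image_grid_thw": Modality.IMAGE,
    "input_features": Modality.AUDIO,
    "audio_feature_lens": Modality.AUDIO,
    "video_grid_thws": Modality.VIDEO,
}


@dataclass
class Item:  # stand-in for MultimodalDataItem
    modality: Modality
    attrs: Dict[str, Any] = field(default_factory=dict)


def collect_items(processor_output: Dict[str, Any]) -> List[Item]:
    """Group processor outputs into at most one item per modality."""
    items: Dict[Modality, Item] = {}
    for name, value in processor_output.items():
        if name == "input_ids":
            continue  # token ids are handled separately from feature tensors
        modality = ATTR_NAME_TO_MODALITY.get(name)
        if modality is None and name == "precomputed_features":
            modality = Modality.IMAGE  # default when no modality hint is given
        if modality is not None:
            items.setdefault(modality, Item(modality)).attrs[name] = value
    return list(items.values())


# A request mixing an image and an audio clip yields two separate items.
out = collect_items(
    {"input_ids": [1, 2, 3], "pixel_values": "<tensor>", "input_features": "<tensor>"}
)
assert {item.modality for item in out} == {Modality.IMAGE, Modality.AUDIO}
```

In the real implementation each collected item is additionally given token offsets computed from `input_ids` using the corresponding `IM_TOKEN_ID`, `AUDIO_TOKEN_ID`, or `VIDEO_TOKEN_ID`, which is what lets images and audio coexist in one request.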
sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py
CHANGED
```diff
@@ -1,10 +1,8 @@
 from typing import List, Union
 
-from sglang.srt.managers.multimodal_processors.base_processor import (
-    BaseMultimodalProcessor,
-)
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.clip import CLIPModel
+from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.utils import load_image
 
 
@@ -17,20 +15,11 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)
 
-
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
+        images = [load_image(image)[0] for image in image_data]
 
         image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["data_hashes"] = [hash(str(image_data))]
```
sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py
CHANGED
```diff
@@ -20,12 +20,12 @@ from typing import List, Union
 
 import torch
 
-from sglang.srt.managers.
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
+from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
 )
-from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
 
 
 class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
@@ -44,17 +44,10 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs
     ):
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
             max_req_input_len=max_req_input_len,
         )
         res = self.process_mm_data(
```