sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +48 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +34 -0
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +36 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +11 -7
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +50 -13
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +77 -84
- sglang/srt/managers/scheduler.py +113 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +181 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +69 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +200 -27
- sglang/srt/utils.py +306 -146
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -18,7 +18,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""

 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple

 import torch
 from torch import nn
@@ -32,6 +32,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
@@ -63,12 +66,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -78,11 +77,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty

 Qwen3MoeConfig = None

 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()


 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -117,6 +117,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )

         self.gate = ReplicatedLinear(
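
Note: the extra FusedMoE arguments above use Python's conditional keyword-unpacking idiom, where a dict of keyword arguments is spliced in only when a feature flag is set. A minimal standalone sketch of that pattern (the function and flag names here are illustrative, not sglang APIs):

    def build_kwargs(enable_flashinfer_moe: bool, enable_ep_moe: bool) -> dict:
        # Splice in the optional keyword arguments only when the flag is on,
        # mirroring the **( dict(...) if cond else {} ) pattern in the hunk above.
        return dict(
            num_experts=8,
            **(
                dict(enable_flashinfer_moe=True, enable_ep_moe=enable_ep_moe)
                if enable_flashinfer_moe
                else {}
            ),
        )

    print(build_kwargs(True, False))   # includes the FlashInfer flags
    print(build_kwargs(False, False))  # plain kwargs only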
@@ -220,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -231,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )
         return final_hidden_states

@@ -284,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_mlp_input"),
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )

@@ -316,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-
+            forward_batch=state.forward_batch,
         )

     def op_combine_a(self, state):
@@ -325,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )

@@ -354,6 +363,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -423,15 +433,27 @@ class Qwen3MoeAttention(nn.Module):

         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.alt_stream = alt_stream

     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k

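
Note: the rewritten _apply_qk_norm above overlaps the q-norm and k-norm on two CUDA streams during graph capture. A minimal sketch of the same dual-stream pattern outside sglang, assuming plain LayerNorm in place of RMSNorm and omitting the capture-mode gate:

    import torch

    def overlapped_qk_norm(q, k, q_norm, k_norm, alt_stream=None):
        # Same structure as the hunk above: q-norm on the current stream,
        # k-norm on a side stream, with wait_stream fences on both sides.
        if alt_stream is not None:
            current = torch.cuda.current_stream()
            alt_stream.wait_stream(current)      # side stream sees all prior work on q/k
            q = q_norm(q)                        # runs on the current stream
            with torch.cuda.stream(alt_stream):
                k = k_norm(k)                    # runs concurrently on the side stream
            current.wait_stream(alt_stream)      # rejoin before q/k are consumed
        else:
            q, k = q_norm(q), k_norm(k)
        return q, k

    if torch.cuda.is_available():
        dim = 128
        q = torch.randn(4, dim, device="cuda")
        k = torch.randn(4, dim, device="cuda")
        q_norm = torch.nn.LayerNorm(dim).cuda()
        k_norm = torch.nn.LayerNorm(dim).cuda()
        q, k = overlapped_qk_norm(q, k, q_norm, k_norm, alt_stream=torch.cuda.Stream())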
@@ -491,6 +513,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -516,6 +539,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )

         self.layer_id = layer_id
@@ -623,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):

     def op_mlp(self, state):
         hidden_states = state.pop("hidden_states_mlp_input")
-        state.hidden_states_mlp_output = self.mlp(
-            hidden_states, state.forward_batch.forward_mode
-        )
+        state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)

     def op_comm_postprocess_layer(self, state):
         hidden_states, residual = self.layer_communicator.postprocess_layer(
@@ -659,11 +681,13 @@ class Qwen3MoeModel(Qwen2MoeModel):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3MoeDecoderLayer,
+            alt_stream=alt_stream,
         )


@@ -691,6 +715,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False

     @torch.no_grad()
     def forward(
@@ -709,9 +734,13 @@ class Qwen3MoeForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )

+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -724,6 +753,24 @@ class Qwen3MoeForCausalLM(nn.Module):
     def end_layer(self):
         return self.model.end_layer

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
sglang/srt/models/vila.py
CHANGED
@@ -270,15 +270,10 @@ class VILAForConditionalGeneration(nn.Module):
             weight_loader(param, loaded_weight)

     def pad_input_ids(
-        self,
-        input_ids: List[int],
-        image_inputs: MultimodalInputs,
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
-
-        )
-
-        return pattern.pad_input_tokens(input_ids, image_inputs)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     ##### BEGIN COPY modeling_vila.py #####

sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py
CHANGED
@@ -17,15 +17,6 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import encode_video, load_audio, load_image


-class MultimodalInputFormat(Enum):
-    """Enum for different multimodal input formats."""
-
-    RAW_IMAGES = "raw_images"
-    PRECOMPUTED_FEATURES = "precomputed_features"
-    PIXEL_VALUES = "pixel_values"
-    AUDIO = "audio"
-
-
 @dataclasses.dataclass
 class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
@@ -98,6 +89,7 @@ class BaseMultimodalProcessor(ABC):
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
+
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330

@@ -109,18 +101,45 @@ class BaseMultimodalProcessor(ABC):
             max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
         )

+        # Mapping from attribute names to modality types
+        self.ATTR_NAME_TO_MODALITY = {
+            # Image-related attributes
+            "pixel_values": Modality.IMAGE,
+            "image_sizes": Modality.IMAGE,
+            "image_grid_thw": Modality.IMAGE,
+            "image_emb_mask": Modality.IMAGE,
+            "image_spatial_crop": Modality.IMAGE,
+            "tgt_size": Modality.IMAGE,
+            "image_grid_hws": Modality.IMAGE,
+            "aspect_ratio_id": Modality.IMAGE,
+            "aspect_ratio_mask": Modality.IMAGE,
+            "second_per_grid_ts": Modality.IMAGE,
+            # Audio-related attributes
+            "audio_features": Modality.AUDIO,
+            "audio_feature_lens": Modality.AUDIO,
+            "input_features": Modality.AUDIO,
+            "input_features_mask": Modality.AUDIO,
+            # Video-related attributes
+            "video_grid_thws": Modality.VIDEO,
+            # Generic attributes that could apply to multiple modalities
+            # "precomputed_features" - handled specially as it can be any modality
+        }
+
     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
     ):
         """
         process multimodal data with transformers AutoProcessor
         """
-        if images
+        if images:
             kwargs["images"] = images
-        if videos
+        if videos:
             kwargs["videos"] = videos
-        if audios
+        if audios:
             kwargs["audios"] = audios
+            if self.__class__.__name__ == "Gemma3nSGLangProcessor":
+                # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
+                kwargs["audio"] = audios

         processor = self._processor
         if hasattr(processor, "image_processor") and isinstance(
@@ -143,6 +162,7 @@ class BaseMultimodalProcessor(ABC):
     async def process_mm_data_async(
         self,
         image_data,
+        audio_data,
         input_text,
         request_obj,
         max_req_input_len,
@@ -417,175 +437,137 @@ class BaseMultimodalProcessor(ABC):
                 values[k] = v
         return values

+    def collect_mm_items_from_processor_output(
+        self, data_dict: dict
+    ) -> List[MultimodalDataItem]:
+        """Create mm_items directly from processor output."""
+        items = {}  # modality -> MultimodalDataItem
+
+        for attr_name, value in data_dict.items():
+            if attr_name == "input_ids":
+                continue
+
+            # Get modality for this attribute
+            modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
+
+            if not modality and attr_name == "precomputed_features":
+                modality_str = data_dict.get("modality")
+                try:
+                    modality = (
+                        Modality.from_str(modality_str)
+                        if modality_str
+                        else Modality.IMAGE
+                    )
+                except ValueError:
+                    modality = Modality.IMAGE
+
+            if modality:
+                # Create item if needed
+                if modality not in items:
+                    items[modality] = MultimodalDataItem(modality=modality)
+
+                # Set attribute
+                if hasattr(items[modality], attr_name):
+                    setattr(items[modality], attr_name, value)
+
+        return list(items.values())
+
+    def _process_and_collect_mm_items(
+        self, input_text: str, images=None, audios=None, videos=None, **kwargs
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+        """
+        Helper method to process multimodal data and create mm_items in one step.
+
+        Returns:
+            Tuple of (created mm_items, input_ids)
+        """
+        ret = self.process_mm_data(
+            input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
+        )
+
+        input_ids = ret["input_ids"].flatten()
+        collected_items = self.collect_mm_items_from_processor_output(ret)
+
+        return collected_items, input_ids
+
     def process_and_combine_mm_data(
         self, base_output: BaseMultiModalProcessorOutput
-    ) -> Tuple[
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
         """
-        Process multimodal data and return the combined multimodal
-
+        Process multimodal data and return the combined multimodal items and input_ids.
+        Supports mixed modalities (images and audio in the same request).

         Returns:
-            Tuple of (
+            Tuple of (list of mm_items, input_ids)
         """
+        # Collect all items and categorize them
+        all_items = (base_output.images or []) + (base_output.audios or [])

-
-
-
-                input_text,
+        # Handle text-only case
+        if not all_items:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
                 return_tensors="pt",
                 add_special_tokens=True,
             ).input_ids.flatten()
+            return [], input_ids
+
+        dict_items, raw_images, raw_audios = [], [], []
+        for item in all_items:
+            if isinstance(item, dict):
+                dict_items.append(item)
+            elif isinstance(item, Image.Image):
+                raw_images.append(item)
+            elif isinstance(item, np.ndarray):
+                raw_audios.append(item)
+            else:
+                raise ValueError(f"Unknown multimodal item type: {type(item)}")

-
-
-
-                has_image = False
-                has_pixel_values = False
-                has_precomputed_features = False
-                has_audio = False
-
-                for mm_input in mm_inputs:
-                    if isinstance(mm_input, Image.Image):
-                        has_image = True
-                    elif isinstance(mm_input, np.ndarray):
-                        has_audio = True
-                    elif isinstance(mm_input, dict):
-                        if mm_input.get("precomputed_features", None) is not None:
-                            has_precomputed_features = True
-                        elif mm_input.get("pixel_values", None) is not None:
-                            has_pixel_values = True
-                        else:
-                            raise ValueError(
-                                f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
-                            )
-                    else:
-                        raise ValueError(
-                            f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
-                        )
+        # Process items and get input_ids
+        all_collected_items = []
+        input_ids = None

-
-
-
-                )
-                if format_count > 1:
-                    raise ValueError(
-                        "Unsupported: mixture of multimodal input formats. "
-                        f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                        f"precomputed_features={has_precomputed_features}, audio={has_audio}"
-                    )
-
-                if has_image:
-                    return MultimodalInputFormat.RAW_IMAGES
-                elif has_precomputed_features:
-                    return MultimodalInputFormat.PRECOMPUTED_FEATURES
-                elif has_pixel_values:
-                    return MultimodalInputFormat.PIXEL_VALUES
-                elif has_audio:
-                    return MultimodalInputFormat.AUDIO
-                else:
-                    raise ValueError("No valid multimodal input format found")
-            except Exception as e:
-                raise ValueError(f"Failed to categorize inputs: {e}")
-
-        def process_raw_images(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process raw Image.Image objects using transformers processor."""
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                images=base_output.images,
-            )
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-
-            # Copy all fields from processor output except input_ids
-            for key, value in ret.items():
-                if key != "input_ids" and hasattr(combined_mm_item, key):
-                    setattr(combined_mm_item, key, value)
-
-            input_ids = ret["input_ids"].flatten()
-            return combined_mm_item, input_ids
-
-        def process_precomputed_features(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with precomputed features."""
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-            combined_mm_item.precomputed_features = self._extract_processor_features(
-                base_output.images, "precomputed_features"
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
             )
-
-
-
-
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with pixel values."""
-            values = self._extract_processor_features_from_all_attributes(
-                base_output.images
-            )
-            combined_mm_item = MultimodalDataItem.from_dict(values)
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def process_audio(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with audio."""
-            ret = self.process_mm_data(
+
+        # Handle raw items (need processing)
+        if raw_images or raw_audios:
+            collected_items, input_ids = self._process_and_collect_mm_items(
                input_text=base_output.input_text,
-
+                images=raw_images,
+                audios=raw_audios,
            )
-
-
-
-
-            input_ids =
-
-
-
-
-
-
-
+            all_collected_items.extend(collected_items)
+
+        # Fallback tokenization if no raw items were processed
+        if input_ids is None:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+
+        # Add offsets to all items
+        for mm_item in all_collected_items:
+            if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                mm_item.image_offsets = self.get_mm_items_offset(
                    input_ids=input_ids,
                    mm_token_id=self.IM_TOKEN_ID,
                )
-            elif
-
+            elif mm_item.modality == Modality.AUDIO:
+                mm_item.audio_offsets = self.get_mm_items_offset(
                    input_ids=input_ids,
                    mm_token_id=self.AUDIO_TOKEN_ID,
                )
-            elif
-
+            elif mm_item.modality == Modality.VIDEO:
+                mm_item.video_offsets = self.get_mm_items_offset(
                    input_ids=input_ids,
                    mm_token_id=self.VIDEO_TOKEN_ID,
                )
            else:
-                raise ValueError(f"Unknown modality: {
-            return combined_mm_item
-
-        # Main logic - determine input type and handle text-only case
-        mm_inputs = base_output.images or base_output.audios
-        if not mm_inputs:
-            input_ids = tokenize_text(base_output.input_text)
-            return None, input_ids
-
-        # Categorize input formats
-        input_format = categorize_mm_inputs(mm_inputs)
-
-        # Process based on format
-        if input_format == MultimodalInputFormat.RAW_IMAGES:
-            combined_mm_item, input_ids = process_raw_images(base_output)
-        elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
-            combined_mm_item, input_ids = process_precomputed_features(base_output)
-        elif input_format == MultimodalInputFormat.PIXEL_VALUES:
-            combined_mm_item, input_ids = process_pixel_values(base_output)
-        elif input_format == MultimodalInputFormat.AUDIO:
-            combined_mm_item, input_ids = process_audio(base_output)
-        else:
-            raise ValueError(f"Unknown input format: {input_format}")
+                raise ValueError(f"Unknown modality: {mm_item.modality}")

-
-        combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
-        return combined_mm_item, input_ids
+        return all_collected_items, input_ids
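
Note: the rewritten process_and_combine_mm_data above groups processor outputs into one MultimodalDataItem per modality via the new ATTR_NAME_TO_MODALITY table, then assigns token offsets per modality. A simplified, self-contained sketch of that grouping step (the Modality enum and attribute names here are stand-ins, not the sglang classes):

    from enum import Enum

    class Modality(Enum):
        IMAGE = "image"
        AUDIO = "audio"
        VIDEO = "video"

    ATTR_NAME_TO_MODALITY = {
        "pixel_values": Modality.IMAGE,
        "image_grid_thw": Modality.IMAGE,
        "input_features": Modality.AUDIO,
        "video_grid_thws": Modality.VIDEO,
    }

    def group_by_modality(processor_output: dict) -> dict:
        """Bucket processor outputs by modality, one item per modality."""
        items: dict = {}
        for name, value in processor_output.items():
            if name == "input_ids":
                continue  # token ids are handled separately
            modality = ATTR_NAME_TO_MODALITY.get(name)
            if modality is not None:
                items.setdefault(modality, {})[name] = value
        return items

    out = {"input_ids": [1, 2, 3], "pixel_values": "img-tensor", "input_features": "audio-tensor"}
    print(group_by_modality(out))  # one IMAGE bucket and one AUDIO bucket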
sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py
CHANGED
@@ -1,10 +1,8 @@
 from typing import List, Union

-from sglang.srt.managers.multimodal_processors.base_processor import (
-    BaseMultimodalProcessor,
-)
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.clip import CLIPModel
+from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.utils import load_image


@@ -17,20 +15,11 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)

-
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
+        images = [load_image(image)[0] for image in image_data]

         image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["data_hashes"] = [hash(str(image_data))]