sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py CHANGED
@@ -18,7 +18,7 @@
  """Inference-only Qwen3MoE model compatible with HuggingFace weights."""

  import logging
- from typing import Any, Dict, Iterable, Optional, Tuple
+ from typing import Any, Dict, Iterable, List, Optional, Tuple

  import torch
  from torch import nn
@@ -32,6 +32,9 @@ from sglang.srt.distributed import (
  tensor_model_parallel_all_gather,
  tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
  from sglang.srt.layers.dp_attention import (
@@ -63,12 +66,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
- from sglang.srt.managers.expert_distribution import (
- get_global_expert_distribution_recorder,
- )
- from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
- from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
  from sglang.srt.model_executor.forward_batch_info import (
  ForwardBatch,
  ForwardMode,
@@ -78,11 +77,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
  from sglang.srt.models.qwen2_moe import Qwen2MoeModel
  from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
- from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
+ from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty

  Qwen3MoeConfig = None

  logger = logging.getLogger(__name__)
+ _is_cuda = is_cuda()


  class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -117,6 +117,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  if global_server_args_dict["enable_deepep_moe"]
  else {}
  ),
+ # Additional args for FusedMoE
+ **(
+ dict(
+ enable_flashinfer_moe=True,
+ enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+ )
+ if global_server_args_dict["enable_flashinfer_moe"]
+ else {}
+ ),
  )

  self.gate = ReplicatedLinear(
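
Note: the hunk above threads the flashinfer/EP flags into the experts constructor via conditional keyword unpacking. A minimal standalone sketch of that idiom follows; the function and flag names are illustrative, not the FusedMoE API.

def build_expert_kwargs(server_args: dict) -> dict:
    # **-expanding an empty dict adds nothing, so the extra arguments are
    # passed only when the feature flag is enabled.
    return dict(
        num_experts=8,
        **(
            dict(
                enable_flashinfer_moe=True,
                enable_ep_moe=server_args.get("enable_ep_moe", False),
            )
            if server_args.get("enable_flashinfer_moe")
            else {}
        ),
    )

print(build_expert_kwargs({"enable_flashinfer_moe": True}))  # extra keys present
print(build_expert_kwargs({}))                               # only num_experts
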
@@ -220,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  hidden_states=hidden_states,
  topk_idx=topk_idx,
  topk_weights=topk_weights,
- forward_mode=forward_mode,
+ forward_batch=forward_batch,
  )
  final_hidden_states = self.experts(
  hidden_states=hidden_states,
@@ -231,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  masked_m=masked_m,
  expected_m=expected_m,
  num_recv_tokens_per_expert=num_recv_tokens_per_expert,
- forward_mode=forward_mode,
+ forward_batch=forward_batch,
  )
  if self.ep_size > 1:
  final_hidden_states = self.deepep_dispatcher.combine(
  hidden_states=final_hidden_states,
  topk_idx=topk_idx,
  topk_weights=topk_weights,
- forward_mode=forward_mode,
+ forward_batch=forward_batch,
  )
  return final_hidden_states

@@ -284,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  hidden_states=state.pop("hidden_states_mlp_input"),
  topk_idx=state.pop("topk_idx_local"),
  topk_weights=state.pop("topk_weights_local"),
- forward_mode=state.forward_batch.forward_mode,
+ forward_batch=state.forward_batch,
  tbo_subbatch_index=state.get("tbo_subbatch_index"),
  )

@@ -316,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  masked_m=state.pop("masked_m"),
  expected_m=state.pop("expected_m"),
  num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
- forward_mode=state.forward_batch.forward_mode,
+ forward_batch=state.forward_batch,
  )

  def op_combine_a(self, state):
@@ -325,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  hidden_states=state.pop("hidden_states_experts_output"),
  topk_idx=state.pop("topk_idx_dispatched"),
  topk_weights=state.pop("topk_weights_dispatched"),
- forward_mode=state.forward_batch.forward_mode,
+ forward_batch=state.forward_batch,
  tbo_subbatch_index=state.get("tbo_subbatch_index"),
  )

@@ -354,6 +363,7 @@ class Qwen3MoeAttention(nn.Module):
  attention_bias: bool = False,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
+ alt_stream: Optional[torch.cuda.Stream] = None,
  ) -> None:
  super().__init__()
  self.hidden_size = hidden_size
@@ -423,15 +433,27 @@ class Qwen3MoeAttention(nn.Module):

  self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
  self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+ self.alt_stream = alt_stream

  def _apply_qk_norm(
  self, q: torch.Tensor, k: torch.Tensor
  ) -> Tuple[torch.Tensor, torch.Tensor]:
- q_by_head = q.reshape(-1, self.head_dim)
- q_by_head = self.q_norm(q_by_head)
+ # overlap qk norm
+ if self.alt_stream is not None and get_is_capture_mode():
+ current_stream = torch.cuda.current_stream()
+ self.alt_stream.wait_stream(current_stream)
+ q_by_head = q.reshape(-1, self.head_dim)
+ q_by_head = self.q_norm(q_by_head)
+ with torch.cuda.stream(self.alt_stream):
+ k_by_head = k.reshape(-1, self.head_dim)
+ k_by_head = self.k_norm(k_by_head)
+ current_stream.wait_stream(self.alt_stream)
+ else:
+ q_by_head = q.reshape(-1, self.head_dim)
+ q_by_head = self.q_norm(q_by_head)
+ k_by_head = k.reshape(-1, self.head_dim)
+ k_by_head = self.k_norm(k_by_head)
  q = q_by_head.view(q.shape)
- k_by_head = k.reshape(-1, self.head_dim)
- k_by_head = self.k_norm(k_by_head)
  k = k_by_head.view(k.shape)
  return q, k

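Note: the new _apply_qk_norm runs the q-norm and k-norm on two CUDA streams during CUDA-graph capture (gated by get_is_capture_mode()) so the two small RMSNorm kernels can overlap. A minimal sketch of the same dual-stream pattern, with q_norm/k_norm standing in for the RMSNorm modules:

import torch

def overlapped_qk_norm(q, k, q_norm, k_norm, alt_stream=None):
    if alt_stream is not None and torch.cuda.is_available():
        main_stream = torch.cuda.current_stream()
        # Make the alternate stream wait until q and k are fully produced.
        alt_stream.wait_stream(main_stream)
        q = q_norm(q)                      # runs on the main stream
        with torch.cuda.stream(alt_stream):
            k = k_norm(k)                  # runs concurrently on alt_stream
        # The main stream must not consume k before the alternate stream finishes.
        main_stream.wait_stream(alt_stream)
    else:
        q, k = q_norm(q), k_norm(k)
    return q, k
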
@@ -491,6 +513,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
  layer_id: int,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
+ alt_stream: Optional[torch.cuda.Stream] = None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -516,6 +539,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
  attention_bias=attention_bias,
  quant_config=quant_config,
  prefix=add_prefix("self_attn", prefix),
+ alt_stream=alt_stream,
  )

  self.layer_id = layer_id
@@ -623,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):

  def op_mlp(self, state):
  hidden_states = state.pop("hidden_states_mlp_input")
- state.hidden_states_mlp_output = self.mlp(
- hidden_states, state.forward_batch.forward_mode
- )
+ state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)

  def op_comm_postprocess_layer(self, state):
  hidden_states, residual = self.layer_communicator.postprocess_layer(
@@ -659,11 +681,13 @@ class Qwen3MoeModel(Qwen2MoeModel):
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  ) -> None:
+ alt_stream = torch.cuda.Stream() if _is_cuda else None
  super().__init__(
  config=config,
  quant_config=quant_config,
  prefix=prefix,
  decoder_layer_type=Qwen3MoeDecoderLayer,
+ alt_stream=alt_stream,
  )


@@ -691,6 +715,7 @@ class Qwen3MoeForCausalLM(nn.Module):
  use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
  )
  self.logits_processor = LogitsProcessor(config)
+ self.capture_aux_hidden_states = False

  @torch.no_grad()
  def forward(
@@ -709,9 +734,13 @@ class Qwen3MoeForCausalLM(nn.Module):
  pp_proxy_tensors=pp_proxy_tensors,
  )

+ aux_hidden_states = None
+ if self.capture_aux_hidden_states:
+ hidden_states, aux_hidden_states = hidden_states
+
  if self.pp_group.is_last_rank:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
  )
  else:
  return hidden_states
@@ -724,6 +753,24 @@ class Qwen3MoeForCausalLM(nn.Module):
  def end_layer(self):
  return self.model.end_layer

+ def get_embed_and_head(self):
+ return self.model.embed_tokens.weight, self.lm_head.weight
+
+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+ if not self.pp_group.is_last_rank:
+ return
+
+ self.capture_aux_hidden_states = True
+ if layer_ids is None:
+ num_layers = self.config.num_hidden_layers
+ self.model.layers_to_capture = [
+ 2,
+ num_layers // 2,
+ num_layers - 3,
+ ] # Specific layers for EAGLE3 support
+ else:
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id)
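
Note: set_eagle3_layers_to_capture above marks a few decoder layers whose hidden states are captured and passed to the logits processor alongside the final hidden states, for use by an EAGLE3 draft model. A small sketch of the default layer selection, assuming a 48-layer config (the helper name here is illustrative):

def default_eagle3_capture_layers(num_hidden_layers: int) -> list:
    # Same default as above: one early, one middle, and one late layer
    # feed auxiliary hidden states to the EAGLE3 draft model.
    return [2, num_hidden_layers // 2, num_hidden_layers - 3]

assert default_eagle3_capture_layers(48) == [2, 24, 45]
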
sglang/srt/models/vila.py CHANGED
@@ -270,15 +270,10 @@ class VILAForConditionalGeneration(nn.Module):
  weight_loader(param, loaded_weight)

  def pad_input_ids(
- self,
- input_ids: List[int],
- image_inputs: MultimodalInputs,
+ self, input_ids: List[int], mm_inputs: MultimodalInputs
  ) -> List[int]:
- pattern = MultiModalityDataPaddingPatternMultimodalTokens(
- token_ids=[self.config.image_token_id],
- )
-
- return pattern.pad_input_tokens(input_ids, image_inputs)
+ pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+ return pattern.pad_input_tokens(input_ids, mm_inputs)

  ##### BEGIN COPY modeling_vila.py #####

sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py RENAMED
@@ -17,15 +17,6 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
  from sglang.srt.utils import encode_video, load_audio, load_image


- class MultimodalInputFormat(Enum):
- """Enum for different multimodal input formats."""
-
- RAW_IMAGES = "raw_images"
- PRECOMPUTED_FEATURES = "precomputed_features"
- PIXEL_VALUES = "pixel_values"
- AUDIO = "audio"
-
-
  @dataclasses.dataclass
  class BaseMultiModalProcessorOutput:
  # input_text, with each frame of video/image represented with a image_token
@@ -98,6 +89,7 @@ class BaseMultimodalProcessor(ABC):
  self._processor = _processor
  self.arch = hf_config.architectures[0]
  self.server_args = server_args
+
  # FIXME: not accurate, model and image specific
  self.NUM_TOKEN_PER_FRAME = 330

@@ -109,18 +101,45 @@ class BaseMultimodalProcessor(ABC):
  max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
  )

+ # Mapping from attribute names to modality types
+ self.ATTR_NAME_TO_MODALITY = {
+ # Image-related attributes
+ "pixel_values": Modality.IMAGE,
+ "image_sizes": Modality.IMAGE,
+ "image_grid_thw": Modality.IMAGE,
+ "image_emb_mask": Modality.IMAGE,
+ "image_spatial_crop": Modality.IMAGE,
+ "tgt_size": Modality.IMAGE,
+ "image_grid_hws": Modality.IMAGE,
+ "aspect_ratio_id": Modality.IMAGE,
+ "aspect_ratio_mask": Modality.IMAGE,
+ "second_per_grid_ts": Modality.IMAGE,
+ # Audio-related attributes
+ "audio_features": Modality.AUDIO,
+ "audio_feature_lens": Modality.AUDIO,
+ "input_features": Modality.AUDIO,
+ "input_features_mask": Modality.AUDIO,
+ # Video-related attributes
+ "video_grid_thws": Modality.VIDEO,
+ # Generic attributes that could apply to multiple modalities
+ # "precomputed_features" - handled specially as it can be any modality
+ }
+
  def process_mm_data(
  self, input_text, images=None, videos=None, audios=None, **kwargs
  ):
  """
  process multimodal data with transformers AutoProcessor
  """
- if images is not None:
+ if images:
  kwargs["images"] = images
- if videos is not None:
+ if videos:
  kwargs["videos"] = videos
- if audios is not None:
+ if audios:
  kwargs["audios"] = audios
+ if self.__class__.__name__ == "Gemma3nSGLangProcessor":
+ # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
+ kwargs["audio"] = audios

  processor = self._processor
  if hasattr(processor, "image_processor") and isinstance(
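
Note: the ATTR_NAME_TO_MODALITY table added above lets the base processor infer a modality from the keys a HuggingFace processor returns. A toy sketch of that lookup, with plain strings and dicts standing in for the Modality enum and MultimodalDataItem:

ATTR_NAME_TO_MODALITY = {
    "pixel_values": "IMAGE",
    "image_grid_thw": "IMAGE",
    "input_features": "AUDIO",
    "video_grid_thws": "VIDEO",
}

def group_by_modality(processor_output: dict) -> dict:
    buckets = {}
    for name, value in processor_output.items():
        if name == "input_ids":
            continue  # token ids are handled separately
        modality = ATTR_NAME_TO_MODALITY.get(name)
        if modality is not None:
            buckets.setdefault(modality, {})[name] = value
    return buckets

# group_by_modality({"input_ids": [1], "pixel_values": "px", "input_features": "af"})
# -> {"IMAGE": {"pixel_values": "px"}, "AUDIO": {"input_features": "af"}}
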
@@ -143,6 +162,7 @@ class BaseMultimodalProcessor(ABC):
  async def process_mm_data_async(
  self,
  image_data,
+ audio_data,
  input_text,
  request_obj,
  max_req_input_len,
@@ -417,175 +437,137 @@ class BaseMultimodalProcessor(ABC):
  values[k] = v
  return values

+ def collect_mm_items_from_processor_output(
+ self, data_dict: dict
+ ) -> List[MultimodalDataItem]:
+ """Create mm_items directly from processor output."""
+ items = {} # modality -> MultimodalDataItem
+
+ for attr_name, value in data_dict.items():
+ if attr_name == "input_ids":
+ continue
+
+ # Get modality for this attribute
+ modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
+
+ if not modality and attr_name == "precomputed_features":
+ modality_str = data_dict.get("modality")
+ try:
+ modality = (
+ Modality.from_str(modality_str)
+ if modality_str
+ else Modality.IMAGE
+ )
+ except ValueError:
+ modality = Modality.IMAGE
+
+ if modality:
+ # Create item if needed
+ if modality not in items:
+ items[modality] = MultimodalDataItem(modality=modality)
+
+ # Set attribute
+ if hasattr(items[modality], attr_name):
+ setattr(items[modality], attr_name, value)
+
+ return list(items.values())
+
+ def _process_and_collect_mm_items(
+ self, input_text: str, images=None, audios=None, videos=None, **kwargs
+ ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+ """
+ Helper method to process multimodal data and create mm_items in one step.
+
+ Returns:
+ Tuple of (created mm_items, input_ids)
+ """
+ ret = self.process_mm_data(
+ input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
+ )
+
+ input_ids = ret["input_ids"].flatten()
+ collected_items = self.collect_mm_items_from_processor_output(ret)
+
+ return collected_items, input_ids
+
  def process_and_combine_mm_data(
  self, base_output: BaseMultiModalProcessorOutput
- ) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
+ ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
  """
- Process multimodal data and return the combined multimodal item and input_ids.
- Handles all three input formats at the same abstraction level.
+ Process multimodal data and return the combined multimodal items and input_ids.
+ Supports mixed modalities (images and audio in the same request).

  Returns:
- Tuple of (combined_mm_item, input_ids)
+ Tuple of (list of mm_items, input_ids)
  """
+ # Collect all items and categorize them
+ all_items = (base_output.images or []) + (base_output.audios or [])

- def tokenize_text(input_text: str) -> torch.Tensor:
- """Tokenize input text."""
- return self._processor.tokenizer(
- input_text,
+ # Handle text-only case
+ if not all_items:
+ input_ids = self._processor.tokenizer(
+ base_output.input_text,
  return_tensors="pt",
  add_special_tokens=True,
  ).input_ids.flatten()
+ return [], input_ids
+
+ dict_items, raw_images, raw_audios = [], [], []
+ for item in all_items:
+ if isinstance(item, dict):
+ dict_items.append(item)
+ elif isinstance(item, Image.Image):
+ raw_images.append(item)
+ elif isinstance(item, np.ndarray):
+ raw_audios.append(item)
+ else:
+ raise ValueError(f"Unknown multimodal item type: {type(item)}")

- def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
- """Categorize multimodal inputs and validate consistency."""
- try:
- has_image = False
- has_pixel_values = False
- has_precomputed_features = False
- has_audio = False
-
- for mm_input in mm_inputs:
- if isinstance(mm_input, Image.Image):
- has_image = True
- elif isinstance(mm_input, np.ndarray):
- has_audio = True
- elif isinstance(mm_input, dict):
- if mm_input.get("precomputed_features", None) is not None:
- has_precomputed_features = True
- elif mm_input.get("pixel_values", None) is not None:
- has_pixel_values = True
- else:
- raise ValueError(
- f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
- )
- else:
- raise ValueError(
- f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
- )
+ # Process items and get input_ids
+ all_collected_items = []
+ input_ids = None

- # Validate format consistency
- format_count = sum(
- [has_image, has_pixel_values, has_precomputed_features, has_audio]
- )
- if format_count > 1:
- raise ValueError(
- "Unsupported: mixture of multimodal input formats. "
- f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
- f"precomputed_features={has_precomputed_features}, audio={has_audio}"
- )
-
- if has_image:
- return MultimodalInputFormat.RAW_IMAGES
- elif has_precomputed_features:
- return MultimodalInputFormat.PRECOMPUTED_FEATURES
- elif has_pixel_values:
- return MultimodalInputFormat.PIXEL_VALUES
- elif has_audio:
- return MultimodalInputFormat.AUDIO
- else:
- raise ValueError("No valid multimodal input format found")
- except Exception as e:
- raise ValueError(f"Failed to categorize inputs: {e}")
-
- def process_raw_images(
- base_output: BaseMultiModalProcessorOutput,
- ) -> Tuple[MultimodalDataItem, torch.Tensor]:
- """Process raw Image.Image objects using transformers processor."""
- ret = self.process_mm_data(
- input_text=base_output.input_text,
- images=base_output.images,
- )
- combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-
- # Copy all fields from processor output except input_ids
- for key, value in ret.items():
- if key != "input_ids" and hasattr(combined_mm_item, key):
- setattr(combined_mm_item, key, value)
-
- input_ids = ret["input_ids"].flatten()
- return combined_mm_item, input_ids
-
- def process_precomputed_features(
- base_output: BaseMultiModalProcessorOutput,
- ) -> Tuple[MultimodalDataItem, torch.Tensor]:
- """Process inputs with precomputed features."""
- combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
- combined_mm_item.precomputed_features = self._extract_processor_features(
- base_output.images, "precomputed_features"
+ # Handle dict items (already processed)
+ for dict_item in dict_items:
+ all_collected_items.extend(
+ self.collect_mm_items_from_processor_output(dict_item)
  )
- input_ids = tokenize_text(base_output.input_text)
- return combined_mm_item, input_ids
-
- def process_pixel_values(
- base_output: BaseMultiModalProcessorOutput,
- ) -> Tuple[MultimodalDataItem, torch.Tensor]:
- """Process inputs with pixel values."""
- values = self._extract_processor_features_from_all_attributes(
- base_output.images
- )
- combined_mm_item = MultimodalDataItem.from_dict(values)
- input_ids = tokenize_text(base_output.input_text)
- return combined_mm_item, input_ids
-
- def process_audio(
- base_output: BaseMultiModalProcessorOutput,
- ) -> Tuple[MultimodalDataItem, torch.Tensor]:
- """Process inputs with audio."""
- ret = self.process_mm_data(
+
+ # Handle raw items (need processing)
+ if raw_images or raw_audios:
+ collected_items, input_ids = self._process_and_collect_mm_items(
  input_text=base_output.input_text,
- audio=base_output.audios, # Note: "audio" is for gemma3n only
+ images=raw_images,
+ audios=raw_audios,
  )
- combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
- for key, value in ret.items():
- if key != "input_ids" and hasattr(combined_mm_item, key):
- setattr(combined_mm_item, key, value)
- input_ids = ret["input_ids"].flatten()
- return combined_mm_item, input_ids
-
- def finalize_mm_item(
- combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
- ) -> MultimodalDataItem:
- """Apply common post-processing to the multimodal item."""
- if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
- combined_mm_item.image_offsets = self.get_mm_items_offset(
+ all_collected_items.extend(collected_items)
+
+ # Fallback tokenization if no raw items were processed
+ if input_ids is None:
+ input_ids = self._processor.tokenizer(
+ base_output.input_text,
+ return_tensors="pt",
+ add_special_tokens=True,
+ ).input_ids.flatten()
+
+ # Add offsets to all items
+ for mm_item in all_collected_items:
+ if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+ mm_item.image_offsets = self.get_mm_items_offset(
  input_ids=input_ids,
  mm_token_id=self.IM_TOKEN_ID,
  )
- elif combined_mm_item.modality == Modality.AUDIO:
- combined_mm_item.audio_offsets = self.get_mm_items_offset(
+ elif mm_item.modality == Modality.AUDIO:
+ mm_item.audio_offsets = self.get_mm_items_offset(
  input_ids=input_ids,
  mm_token_id=self.AUDIO_TOKEN_ID,
  )
- elif combined_mm_item.modality == Modality.VIDEO:
- combined_mm_item.video_offsets = self.get_mm_items_offset(
+ elif mm_item.modality == Modality.VIDEO:
+ mm_item.video_offsets = self.get_mm_items_offset(
  input_ids=input_ids,
  mm_token_id=self.VIDEO_TOKEN_ID,
  )
  else:
- raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
- return combined_mm_item
-
- # Main logic - determine input type and handle text-only case
- mm_inputs = base_output.images or base_output.audios
- if not mm_inputs:
- input_ids = tokenize_text(base_output.input_text)
- return None, input_ids
-
- # Categorize input formats
- input_format = categorize_mm_inputs(mm_inputs)
-
- # Process based on format
- if input_format == MultimodalInputFormat.RAW_IMAGES:
- combined_mm_item, input_ids = process_raw_images(base_output)
- elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
- combined_mm_item, input_ids = process_precomputed_features(base_output)
- elif input_format == MultimodalInputFormat.PIXEL_VALUES:
- combined_mm_item, input_ids = process_pixel_values(base_output)
- elif input_format == MultimodalInputFormat.AUDIO:
- combined_mm_item, input_ids = process_audio(base_output)
- else:
- raise ValueError(f"Unknown input format: {input_format}")
+ raise ValueError(f"Unknown modality: {mm_item.modality}")

- # Finalize with common processing
- combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
- return combined_mm_item, input_ids
+ return all_collected_items, input_ids
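
Note: the rewritten process_and_combine_mm_data accepts mixed modalities and triages each incoming item by type before processing. A self-contained sketch of that triage step, mirroring the isinstance checks above:

from typing import Any, List, Tuple

import numpy as np
from PIL import Image


def split_mm_items(items: List[Any]) -> Tuple[list, list, list]:
    # dicts carry already-processed features; PIL images and numpy arrays are
    # raw image / audio inputs that still need the HuggingFace processor.
    dict_items, raw_images, raw_audios = [], [], []
    for item in items:
        if isinstance(item, dict):
            dict_items.append(item)
        elif isinstance(item, Image.Image):
            raw_images.append(item)
        elif isinstance(item, np.ndarray):
            raw_audios.append(item)
        else:
            raise ValueError(f"Unknown multimodal item type: {type(item)}")
    return dict_items, raw_images, raw_audios
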
sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py RENAMED
@@ -1,10 +1,8 @@
  from typing import List, Union

- from sglang.srt.managers.multimodal_processors.base_processor import (
- BaseMultimodalProcessor,
- )
  from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
  from sglang.srt.models.clip import CLIPModel
+ from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
  from sglang.srt.utils import load_image


@@ -17,20 +15,11 @@ class ClipImageProcessor(BaseMultimodalProcessor):
  async def process_mm_data_async(
  self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
  ):
- if not image_data:
- return None
-
  if isinstance(input_text, list):
  assert len(input_text) and isinstance(input_text[0], int)
  input_text = self._processor.tokenizer.decode(input_text)

- if not isinstance(image_data, list):
- image_data = [image_data]
-
- if len(image_data) > 0:
- images = [load_image(image)[0] for image in image_data]
- else:
- images = load_image(image_data[0])[0]
+ images = [load_image(image)[0] for image in image_data]

  image_inputs = self.process_mm_data(input_text=input_text, images=images)
  image_inputs["data_hashes"] = [hash(str(image_data))]