sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,3 @@
- # Copyright 2023-2024 SGLang Team
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
  from typing import Iterable, List, Optional, Set, Tuple, Union

  import torch
@@ -23,7 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
  from transformers.activations import ACT2FN
  from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling

+ from sglang.srt.distributed import parallel_state
  from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
+ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.managers.mm_utils import (
  MultiModalityDataPaddingPatternTokenPairs,
@@ -39,6 +28,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.deepseek_janus_pro import DropPath
  from sglang.srt.models.internlm2 import InternLM2ForCausalLM
  from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+ from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
  from sglang.utils import logger


@@ -53,7 +43,6 @@ class InternAttention(nn.Module):
  self.embed_dim = config.hidden_size
  self.num_heads = config.num_attention_heads
  self.head_dim = self.embed_dim // self.num_heads
-
  self.scale = self.head_dim**-0.5

  self.attn = VisionAttention(
@@ -64,18 +53,16 @@ class InternAttention(nn.Module):
  use_qkv_parallel=True,
  quant_config=quant_config,
  dropout=getattr(config, "dropout", 0.0),
- proj_bias=getattr(config, "qkv_bias", True),
+ qkv_bias=getattr(config, "qkv_bias", False)
+ or getattr(config, "attention_bias", False),
+ num_dummy_heads=getattr(config, "num_dummy_heads", 0),
+ qk_normalization=getattr(config, "qk_normalization", False)
+ or getattr(config, "use_qk_norm", False),
  flatten_batch=False,
  )

  self.proj_drop = nn.Dropout(config.dropout)

- self.qk_normalization = config.qk_normalization
-
- if self.qk_normalization:
- self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
- self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-
  def forward(
  self,
  hidden_states: torch.Tensor,
@@ -91,8 +78,16 @@ class InternVisionEmbeddings(nn.Module):
  super().__init__()
  self.config = config
  self.embed_dim = config.hidden_size
- self.image_size = config.image_size
- self.patch_size = config.patch_size
+ self.image_size = (
+ config.image_size
+ if isinstance(config.image_size, int)
+ else config.image_size[0]
+ )
+ self.patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, int)
+ else config.patch_size[0]
+ )

  self.class_embedding = nn.Parameter(
  torch.randn(1, 1, self.embed_dim),
@@ -199,7 +194,7 @@ class InternVisionEncoderLayer(nn.Module):
  self.embed_dim = config.hidden_size
  self.intermediate_size = config.intermediate_size
  self.norm_type = config.norm_type
- self.attn = InternAttention(config)
+ self.attn = InternAttention(config=config, quant_config=quant_config)
  self.mlp = InternMLP(config)
  self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
  self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
@@ -417,7 +412,7 @@ class InternVLChatModel(nn.Module):
  super().__init__()
  self.config = config
  self.quant_config = quant_config
-
+ self._update_vision_config()
  image_size = config.force_image_size or config.vision_config.image_size
  patch_size = config.vision_config.patch_size
  self.patch_size = patch_size
@@ -446,6 +441,10 @@ class InternVLChatModel(nn.Module):
  self.language_model = InternLM2ForCausalLM(
  config=config.llm_config, quant_config=quant_config
  )
+ elif config.llm_config.architectures[0] == "Qwen3MoeForCausalLM":
+ self.language_model = Qwen3MoeForCausalLM(
+ config=config.llm_config, quant_config=quant_config
+ )
  else:
  raise NotImplementedError(
  f"{config.llm_config.architectures[0]} is not implemented."
@@ -463,6 +462,21 @@ class InternVLChatModel(nn.Module):
  nn.Linear(llm_hidden_size, llm_hidden_size),
  )

+ def _update_vision_config(self):
+ """update vision config to support tp"""
+ world_size = parallel_state.get_tensor_model_parallel_world_size()
+ num_heads = self.config.vision_config.num_attention_heads
+ head_dim = self.config.vision_config.hidden_size // num_heads
+ num_dummy_heads = 0
+
+ if num_heads % world_size != 0:
+ num_dummy_heads = (
+ (num_heads + world_size) // world_size
+ ) * world_size - num_heads
+
+ setattr(self.config.vision_config, "head_dim", head_dim)
+ setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
  def pixel_shuffle(self, x, scale_factor=0.5):
  n, w, h, c = x.size()
  # N, W, H, C --> N, W, H * scale, C // scale
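Note on _update_vision_config above: the dummy-head count simply rounds the ViT head count up to the next multiple of the tensor-parallel world size so every rank gets an equal share. A minimal standalone sketch of that arithmetic (the helper name is illustrative, not part of sglang):

def dummy_heads_needed(num_heads: int, world_size: int) -> int:
    # How many zero heads must be appended so num_heads divides evenly
    # across world_size tensor-parallel ranks.
    if num_heads % world_size == 0:
        return 0
    return ((num_heads + world_size) // world_size) * world_size - num_heads

# e.g. a 25-head vision tower on 8-way tensor parallelism is padded to 32 heads.
assert dummy_heads_needed(25, 8) == 7
assert dummy_heads_needed(16, 8) == 0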
@@ -545,7 +559,38 @@ class InternVLChatModel(nn.Module):

  return helper.pad_input_tokens(input_ids, mm_inputs)

+ def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+ """pad attn qkv weights for dummy heads"""
+ num_dummy_heads = self.config.vision_config.num_dummy_heads
+ if num_dummy_heads == 0:
+ return loaded_weight
+ head_dim = self.config.vision_config.head_dim
+
+ if "attn.qkv_proj" in name:
+ wq, wk, wv = loaded_weight.chunk(3, dim=0)
+ if name.endswith(".weight"):
+ dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+ elif name.endswith(".bias"):
+ dummy_shape = [num_dummy_heads, head_dim]
+ else:
+ raise RuntimeError(f"Unsupported weight with name={name}")
+ pad_func = lambda x: torch.cat(
+ [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+ ).flatten(0, 1)
+ wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+ loaded_weight = torch.cat([wq, wk, wv], dim=0)
+ if "attn.proj.weight" in name:
+ padded_weight = loaded_weight.new_zeros(
+ loaded_weight.shape[0], head_dim * num_dummy_heads
+ )
+ loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+ if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+ padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+ loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+ return loaded_weight
+
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ expert_params_mapping = []
  if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id)
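_pad_vit_attn_dummy_heads above appends all-zero heads to each of the fused Q/K/V chunks (and widens the output projection and QK-norm weights) so the padded head count splits evenly across ranks. A self-contained sketch of the per-chunk padding, with illustrative shapes:

import torch

def pad_heads(w: torch.Tensor, head_dim: int, num_dummy_heads: int) -> torch.Tensor:
    # w has shape (num_heads * head_dim, in_features); append num_dummy_heads
    # all-zero heads along the head axis, then flatten back to 2D.
    zeros = w.new_zeros(num_dummy_heads, head_dim, w.shape[-1])
    return torch.cat([w.unflatten(0, (-1, head_dim)), zeros], dim=0).flatten(0, 1)

wq = torch.randn(25 * 128, 3200)       # 25 real heads, head_dim 128
wq_padded = pad_heads(wq, 128, 7)      # 32 heads' worth of rows, 7 of them zero
assert wq_padded.shape == (32 * 128, 3200)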
@@ -561,15 +606,41 @@
  ("gate_up_proj", "gate_proj", 0),
  ("gate_up_proj", "up_proj", 1),
  ]
+ elif "Qwen3MoeForCausalLM" in self.config.llm_config.architectures:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.num_experts,
+ )
+
  params_dict = dict(self.named_parameters())
  loaded_params: Set[str] = set()

  for name, loaded_weight in weights:
  if "rotary_emb.inv_freq" in name:
  continue
+
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if "mlp.experts" in name:
+ continue
  name = name.replace(weight_name, param_name)
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
@@ -584,30 +655,55 @@ class InternVLChatModel(nn.Module):
  name = name.replace(r"attn.", r"attn.attn.")
  name = name.replace(r"qkv.", r"qkv_proj.")

- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- param = params_dict[name]
- if "wqkv" in name:
- config = self.config
- kv_groups = config.num_attention_heads // config.num_key_value_heads
- head_dim = config.hidden_size // config.num_attention_heads
- loaded_weight = loaded_weight.view(
- -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
- )
- wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
- wq = wq.reshape(-1, wq.shape[-1])
- wk = wk.reshape(-1, wk.shape[-1])
- wv = wv.reshape(-1, wv.shape[-1])
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
  weight_loader = param.weight_loader
- weight_loader(param, wq, "q")
- weight_loader(param, wk, "k")
- weight_loader(param, wv, "v")
- else:
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
+ weight_loader(
+ param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id,
  )
- weight_loader(param, loaded_weight)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ param = params_dict[name]
+ if "wqkv" in name:
+ config = self.config
+ kv_groups = (
+ config.num_attention_heads // config.num_key_value_heads
+ )
+ head_dim = config.hidden_size // config.num_attention_heads
+ loaded_weight = loaded_weight.view(
+ -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
+ )
+ wq, wk, wv = torch.split(
+ loaded_weight, [kv_groups, 1, 1], dim=1
+ )
+ wq = wq.reshape(-1, wq.shape[-1])
+ wk = wk.reshape(-1, wk.shape[-1])
+ wv = wv.reshape(-1, wv.shape[-1])
+ weight_loader = param.weight_loader
+ weight_loader(param, wq, "q")
+ weight_loader(param, wk, "k")
+ weight_loader(param, wv, "v")
+ else:
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ if "vision_model" in name:
+ loaded_weight = self._pad_vit_attn_dummy_heads(
+ name, loaded_weight
+ )
+ weight_loader(param, loaded_weight)
+
  loaded_params.add(name)
  unloaded_params = params_dict.keys() - loaded_params
  if unloaded_params:
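The restructured loader above leans on Python's for/else: the generic branch (GPTQ bias skipping, wqkv splitting, ViT dummy-head padding) runs only when no entry in expert_params_mapping matched and the inner break never fired. A minimal illustration of that control flow, with made-up mapping tuples:

def route(name: str, mappings) -> str:
    for param_name, weight_name in mappings:
        if weight_name not in name:
            continue
        result = f"expert path: {name.replace(weight_name, param_name)}"
        break
    else:
        # Reached only when the loop completed without break,
        # i.e. no expert mapping matched this weight name.
        result = f"generic path: {name}"
    return result

mappings = [("w13_weight", "gate_proj"), ("w13_weight", "up_proj")]
print(route("mlp.experts.0.gate_proj.weight", mappings))  # expert path
print(route("self_attn.o_proj.weight", mappings))         # generic path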
@@ -656,11 +656,15 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
  self, auto_model_type: Type[AutoModel]
  ) -> Dict[str, str]:
  mapping = {}
- for config_cls, archs in auto_model_type._model_mapping.items():
- if isinstance(archs, tuple):
- mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
- else:
- mapping[config_cls.__name__] = archs.__name__
+ for config_cls in auto_model_type._model_mapping.keys():
+ archs = auto_model_type._model_mapping.get(config_cls, None)
+ if archs is not None:
+ if isinstance(archs, tuple):
+ mapping[config_cls.__name__] = tuple(
+ arch.__name__ for arch in archs
+ )
+ else:
+ mapping[config_cls.__name__] = archs.__name__
  return mapping

  def __init__(
@@ -1134,7 +1134,10 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
  """
  residual = hidden_states
  hidden_states = self.self_attn_layer_norm(hidden_states)
- hidden_states, attn_weights, past_key_values = self.self_attn(
+ # TODO (lifuhuang): confirmed with Mick that the logic for past_key_values is copied from minicpmo official code,
+ # currently we are not using past_key_values at all. We need to redesign the caching logic when we support streaming
+ # in the future.
+ hidden_states, attn_weights = self.self_attn(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
  layer_head_mask=layer_head_mask,
@@ -23,6 +23,7 @@ from sglang.srt.managers.schedule_batch import (
  Modality,
  MultimodalDataItem,
  MultimodalInputs,
+ global_server_args_dict,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -55,13 +56,17 @@ class Llama4ForConditionalGeneration(nn.Module):
  self.quant_config = quant_config

  # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
- self.has_vision = self._has_vision_weights(config)
- if not self.has_vision:
+ self.has_vision_weights = self._has_vision_weights(config)
+ if not self.has_vision_weights:
  logger.warning(
  "No vision weights found in checkpoint. Model will run in text-only mode. "
  "Multimodal capabilities (image processing) will be unavailable."
  )

+ self.has_vision = (
+ self.has_vision_weights and global_server_args_dict["enable_multimodal"]
+ )
+
  if self.has_vision:
  self.vision_model = Llama4VisionModel(config.vision_config)
  self.multi_modal_projector = Llama4MultiModalProjector(config)
@@ -269,7 +274,9 @@ class Llama4ForConditionalGeneration(nn.Module):

  def _should_skip_weight(self, name: str) -> bool:
  """Check if we should skip loading this weight."""
- return "vision" in name and not self.has_vision
+ return not self.has_vision and (
+ "vision" in name or "multi_modal_projector" in name
+ )

  def _transform_weight_name(self, name: str) -> str:
  """Transform weight name by adding language_model prefix if needed."""
@@ -43,10 +43,6 @@ from sglang.srt.layers.communicator import (
  ScatterMode,
  )
  from sglang.srt.layers.dp_attention import (
- attn_tp_all_gather,
- attn_tp_reduce_scatter,
- dp_gather_partial,
- dp_scatter,
  get_attention_tp_rank,
  get_attention_tp_size,
  get_local_attention_dp_size,
@@ -151,10 +147,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
  # Additional args for FusedMoE
  **(
  dict(
- enable_flashinfer_moe=True,
+ enable_flashinfer_cutlass_moe=True,
  enable_ep_moe=global_server_args_dict["enable_ep_moe"],
  )
- if global_server_args_dict["enable_flashinfer_moe"]
+ if global_server_args_dict["enable_flashinfer_cutlass_moe"]
  else {}
  ),
  )
@@ -38,10 +38,6 @@ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
  from sglang.srt.layers.dp_attention import (
- attn_tp_all_gather,
- attn_tp_reduce_scatter,
- dp_gather_partial,
- dp_scatter,
  get_attention_tp_rank,
  get_attention_tp_size,
  get_local_attention_dp_size,
@@ -124,10 +120,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  # Additional args for FusedMoE
  **(
  dict(
- enable_flashinfer_moe=True,
+ enable_flashinfer_cutlass_moe=True,
  enable_ep_moe=global_server_args_dict["enable_ep_moe"],
  )
- if global_server_args_dict["enable_flashinfer_moe"]
+ if global_server_args_dict["enable_flashinfer_cutlass_moe"]
  else {}
  ),
  )
@@ -193,8 +189,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  def forward_deepep(
  self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
  ) -> torch.Tensor:
- forward_mode = forward_batch.forward_mode
- if is_non_idle_and_non_empty(forward_mode, hidden_states):
+ if hidden_states.shape[0] > 0:
  # router_logits: (num_tokens, n_experts)
  router_logits, _ = self.gate(hidden_states)
  topk_weights, topk_idx, _ = self.topk(
@@ -712,6 +707,9 @@ class Qwen3MoeForCausalLM(nn.Module):
  self.logits_processor = LogitsProcessor(config)
  self.capture_aux_hidden_states = False

+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.model.embed_tokens
+
  @torch.no_grad()
  def forward(
  self,
@@ -12,6 +12,7 @@ import torch
  from PIL import Image
  from transformers import BaseImageProcessorFast

+ from sglang.srt.managers.mm_utils import TransportProxyTensor
  from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
  from sglang.srt.utils import load_audio, load_image, load_video, logger

@@ -142,11 +143,14 @@ class MultimodalSpecialTokens:
  class BaseMultimodalProcessor(ABC):
  models = []

- def __init__(self, hf_config, server_args, _processor):
+ def __init__(
+ self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
+ ):
  self.hf_config = hf_config
  self._processor = _processor
  self.arch = hf_config.architectures[0]
  self.server_args = server_args
+ self.transport_mode = transport_mode

  # FIXME: not accurate, model and image specific
  self.NUM_TOKEN_PER_FRAME = 330
@@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC):
  return_tensors="pt",
  **kwargs,
  )
- if "pixel_values" in result and isinstance(
- result["pixel_values"], torch.Tensor
- ):
- result["pixel_values"] = result["pixel_values"].to("cpu")
  return result

  @abstractmethod
@@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC):
  ) -> List[MultimodalDataItem]:
  """Create mm_items directly from processor output."""
  items: dict[Modality, MultimodalDataItem] = {}
-
  for attr_name, value in data_dict.items():
  if attr_name == "input_ids":
  continue
@@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC):
  mm_token_id=mm_token_id,
  )

+ # post-process
+ for item in all_collected_items:
+ # replace the feature tensor with a proxy
+ if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
+ item.feature = TransportProxyTensor(
+ transport_mode=self.transport_mode, data=item.feature
+ )
+ elif (
+ isinstance(item.precomputed_embeddings, torch.Tensor)
+ and item.precomputed_embeddings.is_cuda
+ ):
+ item.precomputed_embeddings = TransportProxyTensor(
+ transport_mode=self.transport_mode, data=item.precomputed_embeddings
+ )
+
  return all_collected_items, input_ids, ret
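The post-processing pass above wraps only GPU-resident tensors; CPU features and non-tensor fields keep their existing serialization path. A small sketch of that dispatch, using a stand-in class in place of sglang's TransportProxyTensor and a caller-supplied transport_mode:

import torch

class _ProxyTensor:  # stand-in for TransportProxyTensor, illustrative only
    def __init__(self, transport_mode, data: torch.Tensor):
        self.transport_mode = transport_mode
        self.data = data

def maybe_wrap(value, transport_mode):
    # Wrap only CUDA tensors; everything else passes through untouched.
    if isinstance(value, torch.Tensor) and value.is_cuda:
        return _ProxyTensor(transport_mode=transport_mode, data=value)
    return value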
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
  class ClipImageProcessor(BaseMultimodalProcessor):
  models = [CLIPModel]

- def __init__(self, hf_config, server_args, _processor):
- super().__init__(hf_config, server_args, _processor)
+ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+ super().__init__(hf_config, server_args, _processor, *args, **kwargs)
  self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
  _processor
  )
@@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import (
  class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
  models = [DeepseekVL2ForCausalLM]

- def __init__(self, hf_config, server_args, _processor):
- super().__init__(hf_config, server_args, _processor)
+ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+ super().__init__(hf_config, server_args, _processor, *args, **kwargs)
  self.mm_tokens = MultimodalSpecialTokens(
  image_token="<image>", image_token_id=self._processor.image_token_id
  ).build(_processor)
@@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
  class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
  models = [Gemma3ForConditionalGeneration]

- def __init__(self, hf_config, server_args, _processor):
- super().__init__(hf_config, server_args, _processor)
+ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+ super().__init__(hf_config, server_args, _processor, *args, **kwargs)
  self.IM_START_TOKEN_ID = hf_config.boi_token_index
  self.IM_END_TOKEN_ID = hf_config.eoi_token_index
  self.mm_tokens = MultimodalSpecialTokens(
@@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):

  models = [Gemma3nForConditionalGeneration]

- def __init__(self, hf_config, server_args, _processor):
- super().__init__(hf_config, server_args, _processor)
+ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+ super().__init__(hf_config, server_args, _processor, *args, **kwargs)

  self.IM_START_TOKEN_ID = hf_config.boi_token_id
  self.IM_END_TOKEN_ID = hf_config.eoi_token_id
@@ -6,6 +6,7 @@ from decord import VideoReader, cpu
  from PIL import Image

  from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+ from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
  from sglang.srt.models.internvl import InternVLChatModel
  from sglang.srt.multimodal.processors.base_processor import (
  BaseMultimodalProcessor,
@@ -14,12 +15,19 @@ from sglang.srt.multimodal.processors.base_processor import (


  class InternVLImageProcessor(BaseMultimodalProcessor):
- models = [InternVLChatModel]
+ models = [InternVLChatModel, InternS1ForConditionalGeneration]

- def __init__(self, hf_config, server_args, _image_processor):
- super().__init__(hf_config, server_args, _image_processor)
- image_size = hf_config.force_image_size or hf_config.vision_config.image_size
+ def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
+ super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
+ image_size = (
+ getattr(hf_config, "force_image_size", None)
+ or hf_config.vision_config.image_size
+ )
  patch_size = hf_config.vision_config.patch_size
+ if isinstance(image_size, list):
+ image_size = image_size[0]
+ if isinstance(patch_size, list):
+ patch_size = patch_size[0]

  self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
  self.IMG_START_TOKEN = "<img>"
@@ -27,8 +35,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
  self.num_image_token = int(
  (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
  )
+ if hasattr(self._processor, "tokenizer"):
+ tokenizer = self._processor.tokenizer
+ else:
+ tokenizer = self._processor
+ self.tokenizer = tokenizer

- tokenizer = self._processor
  self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
  self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
  self.mm_tokens = MultimodalSpecialTokens(
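For reference, num_image_token computed above is (image_size // patch_size) ** 2 scaled by downsample_ratio ** 2. A worked example with common InternVL-style values (448-pixel tiles, 14-pixel patches, downsample_ratio 0.5; treat the numbers as illustrative):

image_size, patch_size, downsample_ratio = 448, 14, 0.5
num_image_token = int((image_size // patch_size) ** 2 * downsample_ratio**2)
assert num_image_token == 256  # 32 x 32 patches, reduced 4x by pixel shuffle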
@@ -195,7 +207,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
  try:
  # TODO: video input
  raw_image = process_image_internvl(image)
- pixel_value = [raw_image.to(torch.bfloat16).cuda()]
+ pixel_value = [raw_image.to(torch.bfloat16)]
  pixel_values += pixel_value
  num_patches = raw_image.shape[0]
  num_patches_list += [num_patches]
@@ -214,8 +226,9 @@
  )
  input_text = input_text.replace("<image>", image_tokens, 1)

- tokenizer = self._processor
- input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
+ input_ids = self.tokenizer(input_text, return_tensors="pt")[
+ "input_ids"
+ ].flatten()
  image_offsets = self.get_mm_items_offset(
  input_ids=input_ids,
  mm_token_id=self.mm_tokens.image_token_id,
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
  class JanusProImageProcessor(BaseMultimodalProcessor):
  models = [MultiModalityCausalLM]

- def __init__(self, hf_config, server_args, _processor):
- super().__init__(hf_config, server_args, _processor)
+ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+ super().__init__(hf_config, server_args, _processor, *args, **kwargs)

  self.mm_tokens = MultimodalSpecialTokens(
  image_token=_processor.image_token,
@@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
  class KimiVLImageProcessor(SGLangBaseProcessor):
  models = [KimiVLForConditionalGeneration]

- def __init__(self, hf_config, server_args, _processor):
- super().__init__(hf_config, server_args, _processor)
+ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+ super().__init__(hf_config, server_args, _processor, *args, **kwargs)
  self.mm_tokens = MultimodalSpecialTokens(
  image_token="<|media_pad|>",
  # TODO: could we convert in MultimodalSpecialTokens?
@@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
  LlavaMistralForCausalLM,
  ]

- def __init__(self, hf_config, server_args, _processor):
- super().__init__(hf_config, server_args, _processor)
+ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+ super().__init__(hf_config, server_args, _processor, *args, **kwargs)

  @staticmethod
  def _process_single_image_task(
@@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
  f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
  )

- def __init__(self, hf_config, server_args, _processor):
+ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
  assert hasattr(hf_config, "vision_config")
  assert hasattr(hf_config, "text_config")
  self.vision_config = hf_config.vision_config
@@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):

  if vision_type := getattr(self.vision_config, "model_type"):
  self.inner = self._get_sgl_processor_cls(vision_type)(
- hf_config, server_args, _processor
+ hf_config, server_args, _processor, *args, **kwargs
  )
  else:
  raise ValueError(