sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
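
Selected hunks follow. The multimodal refactor in sglang/srt/models/minicpmv.py accounts for most of them; related changes in mllama.py, phi3_small.py, and qwen2.py close out the diff.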
sglang/srt/models/minicpmv.py
@@ -50,10 +50,11 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
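These import changes track the release-wide rename: ImageInputs becomes MultimodalInputs in sglang/srt/managers/schedule_batch.py, and the padding helpers move from the removed multi_modality_padding module into the new sglang/srt/managers/mm_utils.py, which also provides the general_mm_embed_routine helper used in the forward pass below.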
@@ -399,7 +400,7 @@ class Idefics2VisionTransformer(nn.Module):
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings
 
     def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
@@ -762,42 +763,6 @@ class MiniCPMVBaseModel(nn.Module):
         valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
         return valid_pairs_tensor
 
-    def get_embedding(
-        self,
-        input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImageInputs],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
-
-        if image_inputs is None:  # No image
-            vision_hidden_states = torch.tensor([], device=input_ids.device)
-        else:
-            if image_inputs["type"] == "image_embeds":
-                vision_hidden_states = (
-                    image_inputs["data"]
-                    .type(vlm_embedding.dtype)
-                    .to(vlm_embedding.device)
-                )
-            else:
-                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
-            # See NOTE in _parse_and_validate_inputs
-            image_bounds = image_inputs["image_bounds"]
-            if len(image_bounds) > 0:
-                image_indices = torch.stack(
-                    [
-                        torch.arange(start, end, dtype=torch.long)
-                        for start, end in image_bounds.tolist()
-                    ]
-                ).to(vlm_embedding.device)
-
-                vlm_embedding.scatter_(
-                    0,
-                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
-                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
-                )
-
-        return vlm_embedding, vision_hidden_states
-
     def _parse_and_validate_inputs(
         self,
         input_ids: torch.Tensor,
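Note that get_embedding is not deleted outright: the same body reappears below _parse_and_validate_inputs in the +811 hunk further down, so this pair of hunks is a pure reordering.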
@@ -828,7 +793,7 @@ class MiniCPMVBaseModel(nn.Module):
             )
 
             if isinstance(image_embeds, list):
-                image_embeds = torch.concat(image_embeds)
+                image_embeds = torch.cat(image_embeds)
 
             return MiniCPMVImageEmbeddingInputs(
                 image_bounds=image_bounds,
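torch.concat is simply an alias of torch.cat, so this is a cosmetic normalization rather than a behavior change. A quick illustration:

```python
import torch

# torch.concat is an alias of torch.cat: both join a sequence of tensors
# along an existing dimension (dim=0 by default).
chunks = [torch.ones(2, 4), torch.zeros(3, 4)]
assert torch.equal(torch.cat(chunks), torch.concat(chunks))  # shape (5, 4)
```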
@@ -836,46 +801,6 @@ class MiniCPMVBaseModel(nn.Module):
                 type="image_embeds",
             )
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
-            )
-
-        if not isinstance(tgt_sizes, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
-            )
-
-        if len(pixel_values) != len(tgt_sizes):
-            raise ValueError(
-                "Inconsistent batch lengths, found: "
-                f"{len(pixel_values)} vs. {len(tgt_sizes)}"
-            )
-
-        pixel_values_flat: List[torch.Tensor] = []
-        tgt_sizes_flat: List[torch.Tensor] = []
-        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
-            if len(pixel_b) != len(tgt_b):
-                raise ValueError(
-                    "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
-                )
-
-            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
-                pixel_values_flat += pixel_n
-                tgt_sizes_flat += tgt_n
-
-        # NOTE: Input IDs does not contain image tokens during memory profiling,
-        # so we allow it to be empty
-        if len(pixel_values_flat) != len(tgt_sizes_flat):
-            raise ValueError(
-                "Inconsistent flattened lengths, found: "
-                f"{len(pixel_values_flat)} vs. "
-                f"{len(tgt_sizes_flat)}"
-            )
-
-        if len(pixel_values_flat) == 0:
-            return None
-
         image_bounds = self._get_image_bounds(
             input_ids=input_ids,
             pad_values=pad_values,
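The type and batch-consistency checks removed here are presumably handled earlier now, in the multimodal processor stack added in this release (sglang/srt/managers/multimodal_processors/base_processor.py and minicpm.py in the file list above), which is why pixel_values and tgt_sizes can be passed through unchanged below.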
@@ -886,11 +811,50 @@ class MiniCPMVBaseModel(nn.Module):
         )
         return MiniCPMVImagePixelInputs(
             image_bounds=image_bounds.to(device=input_ids.device),
-            data=pixel_values_flat,
-            tgt_sizes=torch.stack(tgt_sizes_flat),
+            data=pixel_values,
+            tgt_sizes=tgt_sizes,
             type="pixel_values",
         )
 
+    def get_embedding(
+        self,
+        input_ids: torch.Tensor,
+        image_inputs: Optional[MiniCPMVImageInputs],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
+
+        if image_inputs is None:  # No image
+            vision_hidden_states = torch.tensor([], device=input_ids.device)
+        else:
+            if image_inputs["type"] == "image_embeds":
+                vision_hidden_states = (
+                    image_inputs["data"]
+                    .type(vlm_embedding.dtype)
+                    .to(vlm_embedding.device)
+                )
+            else:
+                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
+            # See NOTE in _parse_and_validate_inputs
+            image_bounds = image_inputs["image_bounds"]
+            if len(image_bounds) > 0:
+                image_indices = torch.stack(
+                    [
+                        torch.arange(start, end, dtype=torch.long)
+                        for start, end in image_bounds.tolist()
+                    ]
+                ).to(vlm_embedding.device)
+
+                vlm_embedding.scatter_(
+                    0,
+                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
+                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
+                )
+
+        return vlm_embedding, vision_hidden_states
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.llm.get_input_embedding()
+
     def forward(
         self,
         input_ids: torch.Tensor,
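The scatter_ call in the relocated get_embedding is the heart of the multimodal embedding: it overwrites the rows of the text embedding matrix that correspond to image tokens with the vision encoder's output. A self-contained toy reproduction of that pattern (shapes and values are made up for illustration):

```python
import torch

# Toy shapes: 8 tokens, hidden size 4, one image spanning tokens [2, 5).
seq_len, hidden = 8, 4
vlm_embedding = torch.zeros(seq_len, hidden)
image_bounds = torch.tensor([[2, 5]])
vision_hidden_states = torch.arange(3 * hidden, dtype=torch.float32).view(3, hidden)

# Expand each [start, end) span into per-token row indices, as above.
image_indices = torch.stack(
    [torch.arange(start, end, dtype=torch.long) for start, end in image_bounds.tolist()]
)

# scatter_ along dim 0 writes row i of the source into row image_indices[i]
# of the target, leaving all text-token rows untouched.
vlm_embedding.scatter_(
    0,
    image_indices.view(-1, 1).repeat(1, hidden),
    vision_hidden_states.view(-1, hidden),
)
assert torch.equal(vlm_embedding[2:5], vision_hidden_states)
```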
@@ -898,59 +862,18 @@ class MiniCPMVBaseModel(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        if (
-            forward_batch.image_inputs is not None
-            and len(forward_batch.image_inputs) > 0
-            and forward_batch.image_inputs[0] is not None
-        ):
-            # TODO: bath
-            kwargs.update(
-                {
-                    "pixel_values": (
-                        None
-                        if forward_batch.image_inputs is None
-                        else [
-                            i.pixel_values
-                            for i in forward_batch.image_inputs
-                            if i is not None
-                        ]
-                    ),
-                    "tgt_sizes": (
-                        None
-                        if forward_batch.image_inputs is None
-                        else [
-                            i.tgt_sizes
-                            for i in forward_batch.image_inputs
-                            if i is not None
-                        ]
-                    ),
-                    "im_start_id": forward_batch.image_inputs[0].im_start_id,
-                    "im_end_id": forward_batch.image_inputs[0].im_end_id,
-                    "slice_start_id": forward_batch.image_inputs[0].slice_start_id,
-                    "slice_end_id": forward_batch.image_inputs[0].slice_end_id,
-                    "pad_values": forward_batch.image_inputs[0].pad_values,
-                }
-            )
-
-        image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
-
-        # Clamp input ids. This is because the input_ids for the image tokens are
-        # filled with the hash values of the image for the prefix matching in the radix attention.
-        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-        input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
-
-        # always pass the input via `inputs_embeds`
-        # to make sure the computation graph is consistent
-        # for `torch.compile` integration
-        input_ids = None
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_features,
+        )
 
         hidden_states = self.llm.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
-            input_embeds=vlm_embeddings,
+            input_embeds=inputs_embeds,
         )
 
         return self.logits_processor(
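general_mm_embed_routine collapses roughly forty lines of per-model glue into a single call. Its real implementation lives in sglang/srt/managers/mm_utils.py; the sketch below only illustrates the flow implied by the code it replaces (clamp hashed image-token ids into the vocab range, embed the text, then overwrite the image-token rows with vision features). The name, simplified signature, and parameters here are illustrative assumptions, not the actual API, which takes forward_batch and a per-model mm_data_embedding_func as the call site above shows.

```python
import torch
from torch import nn


def mm_embed_routine_sketch(
    input_ids: torch.Tensor,
    image_bounds: torch.Tensor,     # (num_images, 2) [start, end) token spans
    vision_features: torch.Tensor,  # (total_image_tokens, hidden)
    embed_tokens: nn.Embedding,
) -> torch.Tensor:
    # Image-token ids carry image hashes for radix-cache prefix matching;
    # clamp them into the vocab range before the embedding lookup.
    input_ids = input_ids.clamp(min=0, max=embed_tokens.num_embeddings - 1)
    inputs_embeds = embed_tokens(input_ids)
    if image_bounds.numel() > 0:
        # Overwrite image-token rows with the vision encoder's output.
        indices = torch.cat(
            [torch.arange(s, e) for s, e in image_bounds.tolist()]
        )
        inputs_embeds[indices] = vision_features.to(inputs_embeds.dtype)
    return inputs_embeds
```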
@@ -990,7 +913,7 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
         raise NotImplementedError
 
 
@@ -1100,12 +1023,14 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return vision_embedding
 
-    def get_vision_hidden_states(
+    def get_image_features(
         self,
-        data: MiniCPMVImageInputs,
+        image_inputs: MultimodalInputs,
     ) -> torch.Tensor:
-        pixel_values = data["data"]
-        tgt_sizes = data["tgt_sizes"]
+        # list of tensors
+        pixel_values = image_inputs.pixel_values
+
+        tgt_sizes = image_inputs.tgt_sizes
 
         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
@@ -1138,7 +1063,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return self.resampler(vision_embedding, tgt_sizes)
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         # Get all special token IDs
         im_start_id: int = image_inputs.im_start_id
         im_end_id: int = image_inputs.im_end_id
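pad_input_ids is where MultiModalityDataPaddingPatternTokenPairs (imported in the first hunk) comes into play: tokens between each (im_start_id, im_end_id) pair are replaced by the image's pad value, the hash id that the removed comments above describe as enabling radix-cache prefix matching. A hypothetical, simplified rendering of that pattern, not the library's actual implementation (slice tokens are ignored here):

```python
from typing import List


def pad_between_token_pairs(
    input_ids: List[int], im_start_id: int, im_end_id: int, pad_values: List[int]
) -> List[int]:
    out: List[int] = []
    image_idx, inside = 0, False
    for tok in input_ids:
        if tok == im_start_id:
            inside = True
            out.append(tok)
        elif tok == im_end_id:
            inside = False
            # Move to the next image's pad value, clamping defensively.
            image_idx = min(image_idx + 1, len(pad_values) - 1)
            out.append(tok)
        else:
            # Inside a pair: substitute the image-hash pad value.
            out.append(pad_values[image_idx] if inside else tok)
    return out
```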
sglang/srt/models/mllama.py
@@ -32,7 +32,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
@@ -796,7 +796,7 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         pixel_values = image_inputs.pixel_values
         pad_values = image_inputs.pad_values
 
@@ -815,7 +815,7 @@
 
         # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
         max_num_images = max_num_tiles = bs = 0
-        for i, im in enumerate(forward_batch.image_inputs):
+        for i, im in enumerate(forward_batch.mm_inputs):
             if not forward_batch.encoder_cached[i] and im is not None:
                 max_num_images = max(max_num_images, im.pixel_values.shape[1])
                 max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
@@ -842,7 +842,7 @@
         )
         i = 0
         encoder_lens_need = []
-        for k, im in enumerate(forward_batch.image_inputs):
+        for k, im in enumerate(forward_batch.mm_inputs):
             if forward_batch.encoder_cached[k] or im is None:
                 continue
 
sglang/srt/models/phi3_small.py
@@ -301,7 +301,7 @@ class Phi3SmallModel(nn.Module):
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Phi3SmallDecoderLayer(
+            lambda idx, prefix: Phi3SmallDecoderLayer(
                 config,
                 int(prefix.split(".")[-1]),
                 quant_config,
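The lambda gains a parameter because the make_layers callback contract changed in this release: the factory now receives the layer index as well as the dotted prefix. A simplified, hypothetical sketch of that contract; the real helper lives in sglang/srt/utils.py (also changed in this diff) and additionally handles pipeline-parallel partitioning:

```python
from typing import Callable, Tuple

from torch import nn


def make_layers_sketch(
    num_hidden_layers: int,
    layer_fn: Callable[[int, str], nn.Module],
    prefix: str = "model.layers",
) -> Tuple[int, int, nn.ModuleList]:
    # The factory is called once per layer with (index, dotted prefix).
    layers = nn.ModuleList(
        layer_fn(idx, f"{prefix}.{idx}") for idx in range(num_hidden_layers)
    )
    return 0, num_hidden_layers, layers  # (start_layer, end_layer, layers)
```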
sglang/srt/models/qwen2.py
@@ -361,6 +361,9 @@ class Qwen2ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
+    def get_input_embedding(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
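
The new accessor (note the singular name) complements the existing one rather than replacing it: get_input_embeddings(input_ids) returns embedded hidden states, while get_input_embedding() exposes the underlying nn.Embedding module itself, which is exactly what MiniCPMVBaseModel.get_input_embeddings() calls on its wrapped LLM in the minicpmv.py hunk above.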