sglang-0.4.4-py3-none-any.whl → sglang-0.4.4.post2-py3-none-any.whl
This diff shows the changes between package versions that have been publicly released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/minicpmv.py
CHANGED
@@ -50,10 +50,11 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -399,7 +400,7 @@ class Idefics2VisionTransformer(nn.Module):
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings

     def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
@@ -762,42 +763,6 @@ class MiniCPMVBaseModel(nn.Module):
         valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
         return valid_pairs_tensor

-    def get_embedding(
-        self,
-        input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImageInputs],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
-
-        if image_inputs is None:  # No image
-            vision_hidden_states = torch.tensor([], device=input_ids.device)
-        else:
-            if image_inputs["type"] == "image_embeds":
-                vision_hidden_states = (
-                    image_inputs["data"]
-                    .type(vlm_embedding.dtype)
-                    .to(vlm_embedding.device)
-                )
-            else:
-                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
-            # See NOTE in _parse_and_validate_inputs
-            image_bounds = image_inputs["image_bounds"]
-            if len(image_bounds) > 0:
-                image_indices = torch.stack(
-                    [
-                        torch.arange(start, end, dtype=torch.long)
-                        for start, end in image_bounds.tolist()
-                    ]
-                ).to(vlm_embedding.device)
-
-                vlm_embedding.scatter_(
-                    0,
-                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
-                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
-                )
-
-        return vlm_embedding, vision_hidden_states
-
     def _parse_and_validate_inputs(
         self,
         input_ids: torch.Tensor,
@@ -828,7 +793,7 @@ class MiniCPMVBaseModel(nn.Module):
         )

         if isinstance(image_embeds, list):
-            image_embeds = torch.
+            image_embeds = torch.cat(image_embeds)

         return MiniCPMVImageEmbeddingInputs(
             image_bounds=image_bounds,
@@ -836,46 +801,6 @@ class MiniCPMVBaseModel(nn.Module):
             type="image_embeds",
         )

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
-            )
-
-        if not isinstance(tgt_sizes, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
-            )
-
-        if len(pixel_values) != len(tgt_sizes):
-            raise ValueError(
-                "Inconsistent batch lengths, found: "
-                f"{len(pixel_values)} vs. {len(tgt_sizes)}"
-            )
-
-        pixel_values_flat: List[torch.Tensor] = []
-        tgt_sizes_flat: List[torch.Tensor] = []
-        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
-            if len(pixel_b) != len(tgt_b):
-                raise ValueError(
-                    "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
-                )
-
-            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
-                pixel_values_flat += pixel_n
-                tgt_sizes_flat += tgt_n
-
-        # NOTE: Input IDs does not contain image tokens during memory profiling,
-        # so we allow it to be empty
-        if len(pixel_values_flat) != len(tgt_sizes_flat):
-            raise ValueError(
-                "Inconsistent flattened lengths, found: "
-                f"{len(pixel_values_flat)} vs. "
-                f"{len(tgt_sizes_flat)}"
-            )
-
-        if len(pixel_values_flat) == 0:
-            return None
-
         image_bounds = self._get_image_bounds(
             input_ids=input_ids,
             pad_values=pad_values,
@@ -886,11 +811,50 @@ class MiniCPMVBaseModel(nn.Module):
         )
         return MiniCPMVImagePixelInputs(
             image_bounds=image_bounds.to(device=input_ids.device),
-            data=
-            tgt_sizes=
+            data=pixel_values,
+            tgt_sizes=tgt_sizes,
             type="pixel_values",
         )

+    def get_embedding(
+        self,
+        input_ids: torch.Tensor,
+        image_inputs: Optional[MiniCPMVImageInputs],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
+
+        if image_inputs is None:  # No image
+            vision_hidden_states = torch.tensor([], device=input_ids.device)
+        else:
+            if image_inputs["type"] == "image_embeds":
+                vision_hidden_states = (
+                    image_inputs["data"]
+                    .type(vlm_embedding.dtype)
+                    .to(vlm_embedding.device)
+                )
+            else:
+                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
+            # See NOTE in _parse_and_validate_inputs
+            image_bounds = image_inputs["image_bounds"]
+            if len(image_bounds) > 0:
+                image_indices = torch.stack(
+                    [
+                        torch.arange(start, end, dtype=torch.long)
+                        for start, end in image_bounds.tolist()
+                    ]
+                ).to(vlm_embedding.device)
+
+                vlm_embedding.scatter_(
+                    0,
+                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
+                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
+                )
+
+        return vlm_embedding, vision_hidden_states
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.llm.get_input_embedding()
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -898,59 +862,18 @@ class MiniCPMVBaseModel(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-
-
-
-
-
-
-        kwargs.update(
-            {
-                "pixel_values": (
-                    None
-                    if forward_batch.image_inputs is None
-                    else [
-                        i.pixel_values
-                        for i in forward_batch.image_inputs
-                        if i is not None
-                    ]
-                ),
-                "tgt_sizes": (
-                    None
-                    if forward_batch.image_inputs is None
-                    else [
-                        i.tgt_sizes
-                        for i in forward_batch.image_inputs
-                        if i is not None
-                    ]
-                ),
-                "im_start_id": forward_batch.image_inputs[0].im_start_id,
-                "im_end_id": forward_batch.image_inputs[0].im_end_id,
-                "slice_start_id": forward_batch.image_inputs[0].slice_start_id,
-                "slice_end_id": forward_batch.image_inputs[0].slice_end_id,
-                "pad_values": forward_batch.image_inputs[0].pad_values,
-            }
-        )
-
-        image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
-
-        # Clamp input ids. This is because the input_ids for the image tokens are
-        # filled with the hash values of the image for the prefix matching in the radix attention.
-        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-        input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
-
-        # always pass the input via `inputs_embeds`
-        # to make sure the computation graph is consistent
-        # for `torch.compile` integration
-        input_ids = None
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_features,
+        )

         hidden_states = self.llm.model(
-            input_ids=
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
-            input_embeds=
+            input_embeds=inputs_embeds,
         )

         return self.logits_processor(
@@ -990,7 +913,7 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError

-    def
+    def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
         raise NotImplementedError

@@ -1100,12 +1023,14 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return vision_embedding

-    def
+    def get_image_features(
         self,
-
+        image_inputs: MultimodalInputs,
     ) -> torch.Tensor:
-
-
+        # list of tensors
+        pixel_values = image_inputs.pixel_values
+
+        tgt_sizes = image_inputs.tgt_sizes

         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
@@ -1138,7 +1063,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return self.resampler(vision_embedding, tgt_sizes)

-    def pad_input_ids(self, input_ids: List[int], image_inputs:
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         # Get all special token IDs
         im_start_id: int = image_inputs.im_start_id
         im_end_id: int = image_inputs.im_end_id
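The minicpmv.py changes above drop the model-specific `get_embedding` plumbing from `forward` in favour of a single call to `general_mm_embed_routine`. The sketch below illustrates that pattern under stated assumptions: `MMInputs`, `mm_embed_routine`, and the exact clamping and splicing details are simplified stand-ins, not the real `sglang.srt.managers.mm_utils` API.

```python
# Illustrative sketch only: embed text tokens, then splice vision features over
# the image-token spans. Names and shapes are assumptions, not sglang's code.
from dataclasses import dataclass
from typing import Callable, Optional

import torch
import torch.nn as nn


@dataclass
class MMInputs:
    # (num_spans, 2) start/end offsets of image-token runs inside input_ids
    image_bounds: torch.Tensor
    pixel_values: Optional[torch.Tensor] = None


def mm_embed_routine(
    input_ids: torch.Tensor,                      # (seq_len,)
    embed_tokens: nn.Embedding,
    mm_data_embedding_func: Callable[[MMInputs], torch.Tensor],
    mm_inputs: Optional[MMInputs],
) -> torch.Tensor:
    # Image placeholder ids may hold out-of-vocab hash values (used for radix
    # prefix matching), so clamp before the embedding lookup.
    safe_ids = input_ids.clamp(min=0, max=embed_tokens.num_embeddings - 1)
    inputs_embeds = embed_tokens(safe_ids)

    if mm_inputs is None or mm_inputs.image_bounds.numel() == 0:
        return inputs_embeds

    # (total_image_tokens, hidden); assumed to line up with the bounds below.
    vision_embeds = mm_data_embedding_func(mm_inputs)
    indices = torch.cat(
        [torch.arange(s, e) for s, e in mm_inputs.image_bounds.tolist()]
    ).to(inputs_embeds.device)
    inputs_embeds[indices] = vision_embeds.to(dtype=inputs_embeds.dtype)
    return inputs_embeds
```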
sglang/srt/models/mllama.py
CHANGED
@@ -32,7 +32,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
@@ -796,7 +796,7 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False

-    def pad_input_ids(self, input_ids: List[int], image_inputs:
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         pixel_values = image_inputs.pixel_values
         pad_values = image_inputs.pad_values

@@ -815,7 +815,7 @@ class MllamaForConditionalGeneration(nn.Module):

         # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
         max_num_images = max_num_tiles = bs = 0
-        for i, im in enumerate(forward_batch.
+        for i, im in enumerate(forward_batch.mm_inputs):
             if not forward_batch.encoder_cached[i] and im is not None:
                 max_num_images = max(max_num_images, im.pixel_values.shape[1])
                 max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
@@ -842,7 +842,7 @@ class MllamaForConditionalGeneration(nn.Module):
         )
         i = 0
         encoder_lens_need = []
-        for k, im in enumerate(forward_batch.
+        for k, im in enumerate(forward_batch.mm_inputs):
             if forward_batch.encoder_cached[k] or im is None:
                 continue

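The `forward_batch.mm_inputs` loops above size a single batched tensor to the largest image and tile counts present in the batch before copying each request's pixel values into it. Below is a minimal sketch of that zero-padding step; the helper name and the fixed `(1, num_images, num_tiles, 3, H, W)` per-request layout are illustrative assumptions, not mllama's exact code.

```python
# Illustrative zero-padding of variable-sized per-request pixel values into one
# batch tensor, mirroring the max_num_images / max_num_tiles sizing above.
from typing import List, Optional

import torch


def batch_pixel_values(per_request: List[Optional[torch.Tensor]]) -> torch.Tensor:
    # Each entry: (1, num_images, num_tiles, 3, H, W); None means "no images".
    present = [pv for pv in per_request if pv is not None]
    if not present:
        raise ValueError("no multimodal inputs to batch")

    max_images = max(pv.shape[1] for pv in present)
    max_tiles = max(pv.shape[2] for pv in present)
    _, _, _, channels, height, width = present[0].shape

    out = torch.zeros(
        len(present), max_images, max_tiles, channels, height, width,
        dtype=present[0].dtype,
    )
    for i, pv in enumerate(present):
        # Copy the real images/tiles; the rest of the slot stays zero-padded.
        out[i, : pv.shape[1], : pv.shape[2]] = pv[0]
    return out


# Example: two requests with different image/tile counts padded to (2, 2, 4, 3, 8, 8).
batched = batch_pixel_values(
    [torch.randn(1, 1, 4, 3, 8, 8), torch.randn(1, 2, 2, 3, 8, 8)]
)
```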
sglang/srt/models/phi3_small.py
CHANGED
@@ -301,7 +301,7 @@ class Phi3SmallModel(nn.Module):
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Phi3SmallDecoderLayer(
+            lambda idx, prefix: Phi3SmallDecoderLayer(
                 config,
                 int(prefix.split(".")[-1]),
                 quant_config,
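The one-line fix above adapts the layer factory to a `make_layers` helper that now invokes its callback with both a layer index and a dotted prefix. The sketch below shows that calling convention; it is a simplified stand-in, not the actual sglang utility.

```python
# Simplified stand-in for a make_layers-style helper whose factory callback now
# receives (idx, prefix) instead of just (prefix). Not the actual sglang utility.
from typing import Callable, Tuple

import torch.nn as nn


def make_layers(
    num_layers: int,
    layer_fn: Callable[[int, str], nn.Module],
    prefix: str = "model.layers",
) -> Tuple[int, int, nn.ModuleList]:
    layers = nn.ModuleList(
        layer_fn(idx, f"{prefix}.{idx}") for idx in range(num_layers)
    )
    # Without pipeline parallelism, this rank owns every layer.
    return 0, num_layers, layers


class DummyLayer(nn.Module):
    def __init__(self, layer_id: int):
        super().__init__()
        self.layer_id = layer_id


# The factory may ignore idx and, as in the diff, still derive the layer id
# from the prefix suffix.
start, end, layers = make_layers(
    4, lambda idx, prefix: DummyLayer(int(prefix.split(".")[-1]))
)
assert [layer.layer_id for layer in layers] == [0, 1, 2, 3]
```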
sglang/srt/models/qwen2.py
CHANGED
@@ -361,6 +361,9 @@ class Qwen2ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

+    def get_input_embedding(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,