sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/minicpmv.py
CHANGED
@@ -54,6 +54,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
+from sglang.srt.models.llama import LlamaConfig, LlamaForCausalLM
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix, flatten_nested_list
 
@@ -581,7 +582,7 @@ class MiniCPMBaseModel(nn.Module):
 
     def init_llm(
         self,
-        config:
+        config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> nn.Module:
@@ -774,7 +775,168 @@ class MiniCPMV2_6(MiniCPMBaseModel):
         return pattern.pad_input_tokens(input_ids, image_inputs)
 
 
-
+class MiniCPMV4_0(MiniCPMBaseModel):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+        assert self.version == (4, 0)
+
+    def init_llm(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        return LlamaForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
+
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+    ) -> nn.Module:
+        model = Idefics2VisionTransformer(
+            config=config.vision_config, quant_config=quant_config, prefix=prefix
+        )
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+
+        setattr(model, "embed_dim", model.embeddings.embed_dim)
+        setattr(model, "patch_size", model.embeddings.patch_size)
+        return model
+
+    def init_resampler(
+        self,
+        embed_dim: int,
+        vision_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 2.6 remains consistent with the one in 2.5.
+            resampler = Resampler2_5(
+                num_queries=self.config.query_num,
+                embed_dim=embed_dim,
+                num_heads=embed_dim // 128,
+                kv_dim=vision_dim,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+
+    def get_vision_embedding(
+        self,
+        pixel_values: List[torch.Tensor],
+        patch_attn_mask: Optional[torch.Tensor] = None,
+        tgt_sizes: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        vision_embedding = self.vpm(
+            pixel_values,
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return vision_embedding
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.feature for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
+
+        device = self.vpm.embeddings.position_embedding.weight.device
+        dtype = self.vpm.embeddings.position_embedding.weight.dtype
+        all_pixel_values_lst = [
+            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+        ]
+
+        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+        assert isinstance(max_patches, int)
+        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+            all_pixel_values_lst, batch_first=True, padding_value=0.0
+        )
+
+        B, L, _ = all_pixel_values.shape
+        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+        patch_attn_mask = torch.zeros(
+            (B, 1, max_patches), dtype=torch.bool, device=device
+        )
+
+        tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+        mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+        patch_attn_mask[:, 0, :] = torch.arange(
+            patch_attn_mask.size(2), device=patch_attn_mask.device
+        ).unsqueeze(0) < mask_shapes.unsqueeze(1)
+
+        vision_embedding = self.vpm(
+            all_pixel_values.type(dtype),
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return self.resampler(vision_embedding, tgt_sizes)
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+        slice_start_id: int = image_inputs.slice_start_id
+        slice_end_id: int = image_inputs.slice_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
+
+_SUPPORT_VERSION = {
+    (2, 6): MiniCPMV2_6,
+    (4, 0): MiniCPMV4_0,
+}
 
 
 class MiniCPMV:
@@ -809,7 +971,7 @@ class MiniCPMV:
         # Dispatch class based on version
        instance_class = _SUPPORT_VERSION.get(version)
        if instance_class is None:
-            raise ValueError("Currently, MiniCPMV only supports versions 2.6")
+            raise ValueError("Currently, MiniCPMV only supports versions 2.6 and 4.0")
 
        try:
            minicpmv = instance_class(
sglang/srt/models/mllama4.py
CHANGED
@@ -961,5 +961,30 @@ class Llama4ForConditionalGeneration(nn.Module):
     def set_embed(self, embed):
         return self.language_model.set_embed(embed)
 
+    def get_hidden_dim(self, module_name, layer_idx):
+        # return input_dim, output_dim
+        if module_name == "qkv_proj":
+            return (
+                self.config.hidden_size,
+                self.config.head_dim
+                * (
+                    self.config.num_attention_heads
+                    + self.config.num_key_value_heads * 2
+                ),
+            )
+        elif module_name == "o_proj":
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size * 2
+        elif module_name == "down_proj":
+            decoder_layer = self.language_model.get_layers()[layer_idx]
+            intermediate_size = decoder_layer.get_intermediate_size()
+            return intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
 
 EntryClass = Llama4ForConditionalGeneration