sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -54,6 +54,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
+from sglang.srt.models.llama import LlamaConfig, LlamaForCausalLM
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix, flatten_nested_list
 
@@ -581,7 +582,7 @@ class MiniCPMBaseModel(nn.Module):
 
     def init_llm(
         self,
-        config: Qwen2Config,
+        config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> nn.Module:
@@ -774,7 +775,168 @@ class MiniCPMV2_6(MiniCPMBaseModel):
         return pattern.pad_input_tokens(input_ids, image_inputs)
 
 
-_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
+class MiniCPMV4_0(MiniCPMBaseModel):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+        assert self.version == (4, 0)
+
+    def init_llm(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        return LlamaForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
+
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+    ) -> nn.Module:
+        model = Idefics2VisionTransformer(
+            config=config.vision_config, quant_config=quant_config, prefix=prefix
+        )
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+
+        setattr(model, "embed_dim", model.embeddings.embed_dim)
+        setattr(model, "patch_size", model.embeddings.patch_size)
+        return model
+
+    def init_resampler(
+        self,
+        embed_dim: int,
+        vision_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 2.6 remains consistent with the one in 2.5.
+            resampler = Resampler2_5(
+                num_queries=self.config.query_num,
+                embed_dim=embed_dim,
+                num_heads=embed_dim // 128,
+                kv_dim=vision_dim,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+
+    def get_vision_embedding(
+        self,
+        pixel_values: List[torch.Tensor],
+        patch_attn_mask: Optional[torch.Tensor] = None,
+        tgt_sizes: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        vision_embedding = self.vpm(
+            pixel_values,
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return vision_embedding
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.feature for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
+
+        device = self.vpm.embeddings.position_embedding.weight.device
+        dtype = self.vpm.embeddings.position_embedding.weight.dtype
+        all_pixel_values_lst = [
+            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+        ]
+
+        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+        assert isinstance(max_patches, int)
+        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+            all_pixel_values_lst, batch_first=True, padding_value=0.0
+        )
+
+        B, L, _ = all_pixel_values.shape
+        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+        patch_attn_mask = torch.zeros(
+            (B, 1, max_patches), dtype=torch.bool, device=device
+        )
+
+        tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+        mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+        patch_attn_mask[:, 0, :] = torch.arange(
+            patch_attn_mask.size(2), device=patch_attn_mask.device
+        ).unsqueeze(0) < mask_shapes.unsqueeze(1)
+
+        vision_embedding = self.vpm(
+            all_pixel_values.type(dtype),
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return self.resampler(vision_embedding, tgt_sizes)
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+        slice_start_id: int = image_inputs.slice_start_id
+        slice_end_id: int = image_inputs.slice_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
+
+_SUPPORT_VERSION = {
+    (2, 6): MiniCPMV2_6,
+    (4, 0): MiniCPMV4_0,
+}
 
 
 class MiniCPMV:
@@ -809,7 +971,7 @@ class MiniCPMV:
         # Dispatch class based on version
         instance_class = _SUPPORT_VERSION.get(version)
         if instance_class is None:
-            raise ValueError("Currently, MiniCPMV only supports versions 2.6")
+            raise ValueError("Currently, MiniCPMV only supports versions 2.6 and 4.0")
 
         try:
            minicpmv = instance_class(
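
With the hunks above, MiniCPM-V checkpoints reporting version (4, 0) now dispatch to the new MiniCPMV4_0 class, which pairs the Idefics2 vision tower with a Llama language model instead of Qwen2. A minimal offline-engine sketch of loading such a checkpoint follows; the model path and sampling parameters are illustrative assumptions, not part of this diff:

# Illustrative sketch, not from the diff: serve a MiniCPM-V 4.0 checkpoint
# with sglang's offline Engine API. Substitute the checkpoint you actually use.
import sglang as sgl

llm = sgl.Engine(model_path="openbmb/MiniCPM-V-4")  # hypothetical model path
outputs = llm.generate(
    ["Describe the image in one sentence."],
    {"temperature": 0.0, "max_new_tokens": 32},
)
print(outputs[0]["text"])
llm.shutdown()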
@@ -961,5 +961,30 @@ class Llama4ForConditionalGeneration(nn.Module):
     def set_embed(self, embed):
         return self.language_model.set_embed(embed)
 
+    def get_hidden_dim(self, module_name, layer_idx):
+        # return input_dim, output_dim
+        if module_name == "qkv_proj":
+            return (
+                self.config.hidden_size,
+                self.config.head_dim
+                * (
+                    self.config.num_attention_heads
+                    + self.config.num_key_value_heads * 2
+                ),
+            )
+        elif module_name == "o_proj":
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size * 2
+        elif module_name == "down_proj":
+            decoder_layer = self.language_model.get_layers()[layer_idx]
+            intermediate_size = decoder_layer.get_intermediate_size()
+            return intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
 
 EntryClass = Llama4ForConditionalGeneration
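
The get_hidden_dim hook added above reports an (input_dim, output_dim) pair per projection, presumably so the LoRA machinery can size adapter buffers without introspecting module weights. A rough sketch of that sizing; the model instance and rank are placeholders, not from the diff:

# Illustrative sketch: derive LoRA A/B shapes from get_hidden_dim.
# `model` stands in for a loaded Llama4ForConditionalGeneration; rank is arbitrary.
import torch

rank = 16
input_dim, output_dim = model.get_hidden_dim("qkv_proj", layer_idx=0)

# Standard LoRA decomposition W + B @ A, with A: (rank, in) and B: (out, rank).
lora_A = torch.zeros(rank, input_dim)
lora_B = torch.zeros(output_dim, rank)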