sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0

sglang/srt/models/llama4.py
@@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module):
             return self.config.num_local_experts > 0
         return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
 
+    def get_intermediate_size(self) -> int:
+        if isinstance(self.feed_forward, Llama4MoE):
+            return self.config.intermediate_size
+        else:
+            return self.config.intermediate_size_mlp
+
     def forward(
         self,
         positions: torch.Tensor,
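
A note on the hunk above: the surrounding logic (visible in the context lines) makes every interleave_moe_layer_step-th decoder layer a MoE layer, so the feed-forward width genuinely differs per layer, which is why the new get_intermediate_size branches on the layer type. A minimal sketch of that interleaving arithmetic, assuming a step of 2 (the real value comes from the model's config):

    # Which decoder layers host MoE experts under interleaving (illustrative).
    interleave_moe_layer_step = 2  # assumed value for the sketch
    moe_layers = [
        layer_id
        for layer_id in range(8)
        if (layer_id + 1) % interleave_moe_layer_step == 0
    ]
    print(moe_layers)  # [1, 3, 5, 7] -> these layers use the Llama4MoE feed_forward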

@@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
+    def get_layers(self):
+        return self.model.layers
+
     def _init_model(
         self,
         config: Llama4TextConfig,

sglang/srt/models/llama_eagle3.py
@@ -109,6 +109,16 @@ class LlamaModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+
+        self.is_mrope_enabled = (
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling is not None
+            and "mrope_section" in config.rope_scaling
+        )
+        # fix rope_scaling for qwen2.5-vl
+        if self.is_mrope_enabled:
+            config.rope_scaling["rope_type"] = "default"
+
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,

@@ -144,6 +154,9 @@ class LlamaModel(nn.Module):
         else:
             embeds = input_embeds
 
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
         hidden_states = forward_batch.spec_info.hidden_states
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)
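
The two llama_eagle3.py hunks above teach the EAGLE3 draft model about multimodal rotary embeddings (M-RoPE): the constructor detects an mrope_section entry in the rope_scaling config, and the forward pass then takes positions from forward_batch.mrope_positions instead of the plain positions tensor. A minimal sketch of the detection predicate on a bare config object; the rope_scaling contents here are illustrative, not copied from a real checkpoint:

    # Sketch: detect M-RoPE from a config the way the patch does.
    from types import SimpleNamespace

    config = SimpleNamespace(
        rope_scaling={"rope_type": "mrope", "mrope_section": [16, 24, 24]}
    )
    is_mrope_enabled = (
        hasattr(config, "rope_scaling")
        and config.rope_scaling is not None
        and "mrope_section" in config.rope_scaling
    )
    if is_mrope_enabled:
        # Same normalization the patch applies for qwen2.5-vl checkpoints.
        config.rope_scaling["rope_type"] = "default"
    print(is_mrope_enabled)  # True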

sglang/srt/models/longcat_flash.py
@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
         )
         self.topk.forward = self.topk.forward_native
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,

@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
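
Both longcat_flash.py hunks track a signature change elsewhere in this release: get_moe_impl_class now receives the quantization config, so the MoE implementation class can be picked per quantization scheme rather than from global state, and the static make_expert_params_mapping is now called on FusedMoE directly. A hedged sketch of the factory shape this implies, with invented stand-in classes (the real selection logic in sglang.srt.layers.moe is more involved):

    # Illustrative factory only, not the actual sglang implementation.
    from typing import Optional

    class FusedMoE:                  # stand-in for the unquantized MoE layer
        pass

    class QuantFusedMoE(FusedMoE):   # hypothetical quantized variant
        pass

    def get_moe_impl_class(quant_config: Optional[object]) -> type:
        # Choosing from the passed-in config is what the new parameter enables.
        return QuantFusedMoE if quant_config is not None else FusedMoE

    print(get_moe_impl_class(None).__name__)  # FusedMoE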

sglang/srt/models/minicpmv.py
@@ -54,6 +54,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
+from sglang.srt.models.llama import LlamaConfig, LlamaForCausalLM
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix, flatten_nested_list
 

@@ -581,7 +582,7 @@ class MiniCPMBaseModel(nn.Module):
 
     def init_llm(
         self,
-        config: Qwen2Config,
+        config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> nn.Module:

@@ -774,7 +775,168 @@ class MiniCPMV2_6(MiniCPMBaseModel):
         return pattern.pad_input_tokens(input_ids, image_inputs)
 
 
-_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
+class MiniCPMV4_0(MiniCPMBaseModel):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+        assert self.version == (4, 0)
+
+    def init_llm(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        return LlamaForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
+
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+    ) -> nn.Module:
+        model = Idefics2VisionTransformer(
+            config=config.vision_config, quant_config=quant_config, prefix=prefix
+        )
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+
+        setattr(model, "embed_dim", model.embeddings.embed_dim)
+        setattr(model, "patch_size", model.embeddings.patch_size)
+        return model
+
+    def init_resampler(
+        self,
+        embed_dim: int,
+        vision_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 2.6 remains consistent with the one in 2.5.
+            resampler = Resampler2_5(
+                num_queries=self.config.query_num,
+                embed_dim=embed_dim,
+                num_heads=embed_dim // 128,
+                kv_dim=vision_dim,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+
+    def get_vision_embedding(
+        self,
+        pixel_values: List[torch.Tensor],
+        patch_attn_mask: Optional[torch.Tensor] = None,
+        tgt_sizes: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        vision_embedding = self.vpm(
+            pixel_values,
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return vision_embedding
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.feature for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
+
+        device = self.vpm.embeddings.position_embedding.weight.device
+        dtype = self.vpm.embeddings.position_embedding.weight.dtype
+        all_pixel_values_lst = [
+            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+        ]
+
+        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+        assert isinstance(max_patches, int)
+        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+            all_pixel_values_lst, batch_first=True, padding_value=0.0
+        )
+
+        B, L, _ = all_pixel_values.shape
+        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+        patch_attn_mask = torch.zeros(
+            (B, 1, max_patches), dtype=torch.bool, device=device
+        )
+
+        tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+        mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+        patch_attn_mask[:, 0, :] = torch.arange(
+            patch_attn_mask.size(2), device=patch_attn_mask.device
+        ).unsqueeze(0) < mask_shapes.unsqueeze(1)
+
+        vision_embedding = self.vpm(
+            all_pixel_values.type(dtype),
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return self.resampler(vision_embedding, tgt_sizes)
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+        slice_start_id: int = image_inputs.slice_start_id
+        slice_end_id: int = image_inputs.slice_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
+
+_SUPPORT_VERSION = {
+    (2, 6): MiniCPMV2_6,
+    (4, 0): MiniCPMV4_0,
+}
 
 
 class MiniCPMV:

@@ -809,7 +971,7 @@ class MiniCPMV:
         # Dispatch class based on version
         instance_class = _SUPPORT_VERSION.get(version)
         if instance_class is None:
-            raise ValueError("Currently, MiniCPMV only supports versions 2.6")
+            raise ValueError("Currently, MiniCPMV only supports versions 2.6 and 4.0")
 
         try:
             minicpmv = instance_class(
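
With the new (4, 0) entry, version dispatch stays a plain dict lookup keyed by a version tuple. A minimal sketch of that resolution step; parsing the tuple from a version string is an assumption for illustration (the real code derives it from the HuggingFace config):

    # Sketch: resolve a MiniCPM-V implementation from a version tuple.
    _SUPPORT_VERSION = {(2, 6): "MiniCPMV2_6", (4, 0): "MiniCPMV4_0"}  # names only

    def resolve(version_str: str) -> str:
        version = tuple(int(x) for x in version_str.split("."))
        instance_class = _SUPPORT_VERSION.get(version)
        if instance_class is None:
            raise ValueError("Currently, MiniCPMV only supports versions 2.6 and 4.0")
        return instance_class

    print(resolve("4.0"))  # MiniCPMV4_0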

sglang/srt/models/mllama4.py
@@ -961,5 +961,30 @@ class Llama4ForConditionalGeneration(nn.Module):
     def set_embed(self, embed):
         return self.language_model.set_embed(embed)
 
+    def get_hidden_dim(self, module_name, layer_idx):
+        # return input_dim, output_dim
+        if module_name == "qkv_proj":
+            return (
+                self.config.hidden_size,
+                self.config.head_dim
+                * (
+                    self.config.num_attention_heads
+                    + self.config.num_key_value_heads * 2
+                ),
+            )
+        elif module_name == "o_proj":
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size * 2
+        elif module_name == "down_proj":
+            decoder_layer = self.language_model.get_layers()[layer_idx]
+            intermediate_size = decoder_layer.get_intermediate_size()
+            return intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
 
 EntryClass = Llama4ForConditionalGeneration
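
This last hunk is what the earlier llama4.py additions exist for: get_layers and get_intermediate_size let get_hidden_dim report the correct down_proj input width for each layer, since interleaved MoE layers and dense layers use different intermediate sizes. A self-contained toy with the same contract, using made-up widths, to show why down_proj needs the per-layer lookup:

    # Toy stand-ins only; real values come from the Llama 4 config.
    class ToyLayer:
        def __init__(self, moe: bool):
            self.moe = moe

        def get_intermediate_size(self) -> int:
            return 8192 if self.moe else 16384  # made-up MoE/dense widths

    hidden_size = 4096  # made-up
    layers = [ToyLayer(moe=(i % 2 == 1)) for i in range(4)]

    def get_hidden_dim(module_name: str, layer_idx: int):
        # return input_dim, output_dim, mirroring the method added above
        if module_name == "down_proj":
            return layers[layer_idx].get_intermediate_size(), hidden_size
        raise NotImplementedError()

    print([get_hidden_dim("down_proj", i) for i in range(4)])
    # [(16384, 4096), (8192, 4096), (16384, 4096), (8192, 4096)]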