sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0

sglang/srt/models/glm4_moe_nextn.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Inference-only GLM-4.5 NextN Speculative Decoding."""
+"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
 import logging
 from typing import Iterable, Optional, Tuple
 
@@ -48,7 +48,7 @@ class Glm4MoeModelNextN(nn.Module):
         super().__init__()
         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             logger.warning(
-                "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
+                "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 / GLM-4.6 model."
             )
             quant_config = None
 

sglang/srt/models/glm4v.py
@@ -7,7 +7,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.layernorm import RMSNorm
@@ -28,6 +27,7 @@ from sglang.srt.models.qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration,
 )
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -93,9 +93,8 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
             quant_config=quant_config,
             prefix=prefix,
             num_dummy_heads=config.num_dummy_heads,
+            rms_norm_eps=config.rms_norm_eps,
         )
-        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.mlp = Glm4vVisionMLP(
             config.hidden_size,
@@ -498,6 +497,9 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = torch.cat(
             [item.feature.squeeze(0) for item in items], dim=0

sglang/srt/models/glm4v_moe.py
@@ -10,7 +10,6 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -22,6 +21,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
 from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 _is_cuda = is_cuda()
 
@@ -74,6 +74,9 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def determine_num_fused_shared_experts(
         self, architecture: str = "Glm4MoeForCausalLM"
     ):

sglang/srt/models/gpt_oss.py
@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -121,7 +125,7 @@ class GptOssSparseMoeBlock(nn.Module):
         )
 
         self.top_k = config.num_experts_per_tok
-        experts_type = get_moe_impl_class()
+        experts_type = get_moe_impl_class(quant_config)
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
             quant_config_name = (
@@ -193,33 +197,6 @@
         return ans
 
 
-def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
-    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-# TODO maybe move to a model-common utils
-def _create_fused_set_kv_buffer_arg(
-    value: torch.Tensor,
-    layer: RadixAttention,
-    forward_batch: ForwardBatch,
-):
-    layer_id = layer.layer_id
-    token_to_kv_pool = forward_batch.token_to_kv_pool
-
-    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
-    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
-    return FusedSetKVBufferArg(
-        value=value,
-        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
-        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
-        k_scale=layer.k_scale,
-        v_scale=layer.v_scale,
-        cache_loc=forward_batch.out_cache_loc,
-    )
-
-
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
             q,
             k,
             fused_set_kv_buffer_arg=(
-                _create_fused_set_kv_buffer_arg(
+                create_fused_set_kv_buffer_arg(
                     value=v,
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer(forward_batch)
+                if enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output

sglang/srt/models/internvl.py
@@ -26,8 +26,10 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
+from sglang.srt.models.gpt_oss import GptOssForCausalLM
 from sglang.srt.models.internlm2 import InternLM2ForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM
 from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
 
@@ -445,6 +447,14 @@ class InternVLChatModel(nn.Module):
             self.language_model = Qwen3MoeForCausalLM(
                 config=config.llm_config, quant_config=quant_config
             )
+        elif config.llm_config.architectures[0] == "GptOssForCausalLM":
+            self.language_model = GptOssForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
+        elif config.llm_config.architectures[0] == "Qwen3ForCausalLM":
+            self.language_model = Qwen3ForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.llm_config.architectures[0]} is not implemented."
@@ -577,6 +587,15 @@ class InternVLChatModel(nn.Module):
                 ckpt_up_proj_name="up_proj",
                 num_experts=self.config.num_experts,
             )
+        elif "Qwen3ForCausalLM" in self.config.llm_config.architectures:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj", "q_proj", "q"),
+                ("qkv_proj", "k_proj", "k"),
+                ("qkv_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]
 
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
@@ -661,6 +680,15 @@ class InternVLChatModel(nn.Module):
 
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
+        # Skip params that are created by quantization wrappers and are not expected in the ckpt
+        _quant_only_fragments = (
+            "weight_scale",  # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
+        )
+        unloaded_params = {
+            n
+            for n in unloaded_params
+            if not any(frag in n for frag in _quant_only_fragments)
+        }
         if unloaded_params:
             raise RuntimeError(
                 f"Some weights are not initialized from checkpoints: {unloaded_params}"

sglang/srt/models/kimi_vl_moonvit.py
@@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers.activations import ACT2FN, PytorchGELUTanh
+from transformers.activations import ACT2FN, GELUTanh
 from transformers.modeling_utils import PreTrainedModel
 
 try:
@@ -614,7 +614,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
                 "num_heads": config.num_attention_heads,
                 "hidden_dim": config.hidden_size,
                 "mlp_dim": config.intermediate_size,
-                "activation": PytorchGELUTanh(),
+                "activation": GELUTanh(),
                 "attn_bias": True,
                 "attn_implementation": config._attn_implementation,
             },

sglang/srt/models/llama.py
@@ -385,6 +385,10 @@ class LlamaModel(nn.Module):
                     "Self attention has no KV cache scaling " "factor attribute!"
                 )
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        """Get input embeddings from the model."""
+        return self.embed_tokens
+
 
 class LlamaForCausalLM(nn.Module):
     # BitandBytes specific attributes

sglang/srt/models/llama4.py
@@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module):
             return self.config.num_local_experts > 0
         return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
 
+    def get_intermediate_size(self) -> int:
+        if isinstance(self.feed_forward, Llama4MoE):
+            return self.config.intermediate_size
+        else:
+            return self.config.intermediate_size_mlp
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
+    def get_layers(self):
+        return self.model.layers
+
     def _init_model(
         self,
         config: Llama4TextConfig,

sglang/srt/models/llama_eagle3.py
@@ -109,6 +109,16 @@ class LlamaModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+
+        self.is_mrope_enabled = (
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling is not None
+            and "mrope_section" in config.rope_scaling
+        )
+        # fix rope_scaling for qwen2.5-vl
+        if self.is_mrope_enabled:
+            config.rope_scaling["rope_type"] = "default"
+
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -144,6 +154,9 @@ class LlamaModel(nn.Module):
         else:
             embeds = input_embeds
 
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
         hidden_states = forward_batch.spec_info.hidden_states
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)

sglang/srt/models/longcat_flash.py
@@ -131,7 +131,7 @@ elif _is_hip:
         awq_dequantize_triton as awq_dequantize,
     )
 else:
-    from vllm._custom_ops import awq_dequantize
+    pass
 
 logger = logging.getLogger(__name__)
 
@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
         )
         self.topk.forward = self.topk.forward_native
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",

sglang/srt/models/longcat_flash_nextn.py
@@ -111,7 +111,7 @@ elif _is_hip:
         awq_dequantize_triton as awq_dequantize,
     )
 else:
-    from vllm._custom_ops import awq_dequantize
+    pass
 
 
 logger = logging.getLogger(__name__)

sglang/srt/models/minicpmv.py
@@ -54,6 +54,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
+from sglang.srt.models.llama import LlamaConfig, LlamaForCausalLM
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix, flatten_nested_list
 
@@ -581,7 +582,7 @@ class MiniCPMBaseModel(nn.Module):
 
     def init_llm(
         self,
-        config: Qwen2Config,
+        config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> nn.Module:
@@ -774,7 +775,168 @@ class MiniCPMV2_6(MiniCPMBaseModel):
         return pattern.pad_input_tokens(input_ids, image_inputs)
 
 
-_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
+class MiniCPMV4_0(MiniCPMBaseModel):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+        assert self.version == (4, 0)
+
+    def init_llm(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        return LlamaForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
+
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+    ) -> nn.Module:
+        model = Idefics2VisionTransformer(
+            config=config.vision_config, quant_config=quant_config, prefix=prefix
+        )
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+
+        setattr(model, "embed_dim", model.embeddings.embed_dim)
+        setattr(model, "patch_size", model.embeddings.patch_size)
+        return model
+
+    def init_resampler(
+        self,
+        embed_dim: int,
+        vision_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 2.6 remains consistent with the one in 2.5.
+            resampler = Resampler2_5(
+                num_queries=self.config.query_num,
+                embed_dim=embed_dim,
+                num_heads=embed_dim // 128,
+                kv_dim=vision_dim,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+
+    def get_vision_embedding(
+        self,
+        pixel_values: List[torch.Tensor],
+        patch_attn_mask: Optional[torch.Tensor] = None,
+        tgt_sizes: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        vision_embedding = self.vpm(
+            pixel_values,
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return vision_embedding
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.feature for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
+
+        device = self.vpm.embeddings.position_embedding.weight.device
+        dtype = self.vpm.embeddings.position_embedding.weight.dtype
+        all_pixel_values_lst = [
+            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+        ]
+
+        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+        assert isinstance(max_patches, int)
+        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+            all_pixel_values_lst, batch_first=True, padding_value=0.0
+        )
+
+        B, L, _ = all_pixel_values.shape
+        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+        patch_attn_mask = torch.zeros(
+            (B, 1, max_patches), dtype=torch.bool, device=device
+        )
+
+        tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+        mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+        patch_attn_mask[:, 0, :] = torch.arange(
+            patch_attn_mask.size(2), device=patch_attn_mask.device
+        ).unsqueeze(0) < mask_shapes.unsqueeze(1)
+
+        vision_embedding = self.vpm(
+            all_pixel_values.type(dtype),
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return self.resampler(vision_embedding, tgt_sizes)
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+        slice_start_id: int = image_inputs.slice_start_id
+        slice_end_id: int = image_inputs.slice_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
+
+_SUPPORT_VERSION = {
+    (2, 6): MiniCPMV2_6,
+    (4, 0): MiniCPMV4_0,
+}
 
 
 class MiniCPMV:
@@ -809,7 +971,7 @@ class MiniCPMV:
         # Dispatch class based on version
        instance_class = _SUPPORT_VERSION.get(version)
         if instance_class is None:
-            raise ValueError("Currently, MiniCPMV only supports versions 2.6")
+            raise ValueError("Currently, MiniCPMV only supports versions 2.6 and 4.0")
 
         try:
             minicpmv = instance_class(

sglang/srt/models/mllama4.py
@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 1)
+        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
         hidden_states, _ = self.linear(hidden_states)
         return hidden_states
 
@@ -446,9 +446,20 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         if self.has_vision:
+            # TODO: make this more general
+            ignore_quant_layers = getattr(config, "quantization_config", {}).get(
+                "ignore", {}
+            )
+            if (
+                "model.layers.vision_model*" in ignore_quant_layers
+                and "model.layers.multi_modal_projector*" in ignore_quant_layers
+            ):
+                vision_quant_config = None
+            else:
+                vision_quant_config = quant_config
             self.vision_model = Llama4VisionModel(
                 config.vision_config,
-                quant_config=quant_config,
+                quant_config=vision_quant_config,
                 prefix=add_prefix("vision_model", prefix),
             )
 
@@ -560,7 +571,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             language_model=self.language_model,
             data_embedding_funcs={
-                Modality.IMAGE: self.get_image_feature,
+                Modality.IMAGE: image_embedding_func,
             },
             positions=positions,
         )
@@ -689,7 +700,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         """Handle scale parameter remapping. Returns True if handled."""
         if "scale" in name and "expert" not in name:
             remapped_name = maybe_remap_kv_scale_name(name, params_dict)
-            return remapped_name is None
+            return remapped_name is not None and remapped_name != name
         return False
 
     def _handle_stacked_params(
@@ -961,5 +972,30 @@ class Llama4ForConditionalGeneration(nn.Module):
     def set_embed(self, embed):
         return self.language_model.set_embed(embed)
 
+    def get_hidden_dim(self, module_name, layer_idx):
+        # return input_dim, output_dim
+        if module_name == "qkv_proj":
+            return (
+                self.config.hidden_size,
+                self.config.head_dim
+                * (
+                    self.config.num_attention_heads
+                    + self.config.num_key_value_heads * 2
+                ),
+            )
+        elif module_name == "o_proj":
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size * 2
+        elif module_name == "down_proj":
+            decoder_layer = self.language_model.get_layers()[layer_idx]
+            intermediate_size = decoder_layer.get_intermediate_size()
+            return intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
 
 EntryClass = Llama4ForConditionalGeneration