sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,787 @@
1
+ # Copyright 2025 Qwen Team
2
+ # Copyright 2025 SGLang Team
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Inference-only Qwen3-VL model compatible with HuggingFace weights."""
16
+ import logging
17
+ from functools import lru_cache, partial
18
+ from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from einops import rearrange
25
+ from transformers.activations import ACT2FN
26
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
27
+ Qwen2_5_VisionRotaryEmbedding,
28
+ )
29
+
30
+ from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
31
+ from sglang.srt.layers.attention.vision import VisionAttention
32
+ from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
33
+ from sglang.srt.layers.logits_processor import LogitsProcessor
34
+ from sglang.srt.layers.pooler import Pooler, PoolingType
35
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
36
+ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
37
+ from sglang.srt.managers.mm_utils import (
38
+ MultiModalityDataPaddingPatternMultimodalTokens,
39
+ general_mm_embed_routine,
40
+ )
41
+ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
42
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
43
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
44
+ from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
45
+ from sglang.srt.models.qwen3 import Qwen3Model
46
+ from sglang.srt.utils import add_prefix
47
+ from sglang.srt.utils.hf_transformers_utils import get_processor
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+ # === Vision Encoder === #
52
+
53
+
54
class Qwen3_VisionMLP(nn.Module):
    """Feed-forward block of the vision encoder.

    Applies ``linear_fc1`` (column-parallel), the configured activation, and
    ``linear_fc2`` (row-parallel), mapping ``in_features`` to
    ``hidden_features`` and back to ``in_features``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = True,
        hidden_act="silu",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the two projections and look up the activation.

        Args:
            in_features: Width of the block's input and output.
            hidden_features: Width of the intermediate projection.
            bias: Whether both linear layers carry a bias term.
            hidden_act: Key into ``ACT2FN`` selecting the activation.
            quant_config: Quantization config forwarded to both linear layers.
            prefix: Name prefix under which the sub-layers are registered.
        """
        super().__init__()
        # Both projections share the bias and quantization settings.
        common = dict(bias=bias, quant_config=quant_config)
        self.linear_fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=add_prefix("linear_fc1", prefix),
            **common,
        )
        self.linear_fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            prefix=add_prefix("linear_fc2", prefix),
            **common,
        )
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor):
        """Return ``linear_fc2(act(linear_fc1(x)))``.

        Each linear layer returns a pair; only the first element is used.
        """
        expanded, _ = self.linear_fc1(x)
        activated = self.act(expanded)
        projected, _ = self.linear_fc2(activated)
        return projected
86
+
87
+
88
class Qwen3VLVisionPatchEmbed(nn.Module):
    """Embed flattened vision patches with a single 3-D convolution.

    The convolution's kernel and stride both equal the full patch shape
    ``(temporal_patch_size, patch_size, patch_size)``, so every patch is
    reduced to exactly one ``embed_dim``-wide vector.
    """

    def __init__(self, config) -> None:
        """Read patch geometry from ``config`` and create the projection.

        ``config`` must expose ``patch_size``, ``temporal_patch_size``,
        ``in_channels`` and ``hidden_size``.
        """
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # Kernel == stride: one output position per (T, P, P) patch.
        patch_shape = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(
            self.in_channels,
            self.embed_dim,
            kernel_size=patch_shape,
            stride=patch_shape,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project patches to embeddings.

        The input is viewed as
        ``(-1, in_channels, temporal_patch_size, patch_size, patch_size)``,
        cast to the projection weight's dtype, convolved, and returned with
        shape ``(-1, embed_dim)``.
        """
        weight_dtype = self.proj.weight.dtype
        patches = hidden_states.view(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        embedded = self.proj(patches.to(dtype=weight_dtype))
        return embedded.view(-1, self.embed_dim)
118
+
119
+
120
class Qwen3_VisionBlock(nn.Module):
    """One pre-norm block of the vision encoder: attention then MLP.

    Each sub-layer is applied to a normalized copy of the input and added
    back residually (``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``).
    """

    # attn_implementation -> (qkv_backend, softmax_in_single_precision).
    # Every supported implementation runs with flatten_batch=True.
    _ATTN_IMPLEMENTATIONS = {
        "sdpa": ("sdpa", False),
        "flash_attention_2": ("triton_attn", False),
        "eager": ("sdpa", True),
        "flash_attention_3": ("fa3", False),
    }

    def __init__(
        self,
        dim: int,
        num_heads: int,
        intermediate_dim: int,
        hidden_act="silu",
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the norms, attention and MLP sub-layers.

        Args:
            dim: Hidden width of the block.
            num_heads: Number of attention heads.
            intermediate_dim: Hidden width of the MLP.
            hidden_act: Activation key forwarded to ``Qwen3_VisionMLP``.
            norm_layer: Factory taking a width and returning a norm module;
                defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
            attn_implementation: One of ``"sdpa"``, ``"flash_attention_2"``,
                ``"eager"`` or ``"flash_attention_3"``.
            quant_config: Quantization config forwarded to the sub-layers.
            prefix: Name prefix under which the sub-layers are registered.

        Raises:
            ValueError: If ``attn_implementation`` is not a supported value.
        """
        super().__init__()

        # Validate up front: the original if/elif chain had no fallback, so an
        # unsupported value (including None) only surfaced later as an
        # UnboundLocalError on the backend variables.
        try:
            qkv_backend, softmax_in_single_precision = self._ATTN_IMPLEMENTATIONS[
                attn_implementation
            ]
        except KeyError:
            raise ValueError(
                f"Unsupported attn_implementation {attn_implementation!r}; "
                f"expected one of {sorted(self._ATTN_IMPLEMENTATIONS)}"
            ) from None
        flatten_batch = True

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=flatten_batch,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.mlp = Qwen3_VisionMLP(
            dim,
            intermediate_dim,
            hidden_act=hidden_act,
            bias=True,
            quant_config=quant_config,
            # Use add_prefix like the attention sub-layer; a hand-built
            # f"{prefix}.mlp" yields ".mlp" when prefix is empty.
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and MLP with residual connections.

        Args:
            x: Input hidden states. The first two axes are swapped before
                attention and swapped back afterwards (presumably
                sequence-first input — confirm against the caller).
            cu_seqlens: Forwarded unchanged to the attention layer.
            position_embeddings: Forwarded unchanged to the attention layer.

        Returns:
            The block output, laid out like ``x``.
        """
        hidden_states = self.norm1(x)
        # The attention layer is fed with the first two axes swapped.
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
        x = x + self.mlp(self.norm2(x))
        return x
197
+
198
+
199
class Qwen3_VisionPatchMerger(nn.Module):
    """Merge groups of patch features and project them to ``dim``.

    Features are viewed as rows of width
    ``context_dim * spatial_merge_size**2`` and passed through
    ``linear_fc1`` -> GELU -> ``linear_fc2``. The norm runs either before the
    reshape (on ``context_dim``-wide rows) or after it (on the merged rows),
    depending on ``use_postshuffle_norm``.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the norm and the two projections.

        Args:
            dim: Output width of the merger.
            context_dim: Width of each incoming patch feature.
            norm_layer: Factory taking a width and returning a norm module;
                defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
            spatial_merge_size: Side length of the merge window; the merged
                width is ``context_dim * spatial_merge_size**2``.
            use_postshuffle_norm: Normalize the merged rows instead of the
                individual patch features.
            quant_config: Quantization config forwarded to both linear layers.
            prefix: Name prefix under which the sub-layers are registered.
        """
        super().__init__()
        merged_width = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_width
        self.use_postshuffle_norm = use_postshuffle_norm

        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        # The norm width follows wherever the norm is applied in forward().
        norm_width = merged_width if use_postshuffle_norm else context_dim
        self.norm = make_norm(norm_width)

        self.linear_fc1 = ColumnParallelLinear(
            merged_width,
            merged_width,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("linear_fc1", prefix),
        )
        self.act_fn = nn.GELU()
        self.linear_fc2 = RowParallelLinear(
            merged_width,
            dim,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("linear_fc2", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, merge and project ``x``.

        Each linear layer returns a pair; only the first element is used.
        """
        if not self.use_postshuffle_norm:
            # Normalize per patch feature, then group into merged rows.
            merged = self.norm(x).view(-1, self.hidden_size)
        else:
            # Group into merged rows first, then normalize each row.
            merged = self.norm(x.view(-1, self.hidden_size))

        hidden, _ = self.linear_fc1(merged)
        out, _ = self.linear_fc2(self.act_fn(hidden))
        return out
247
+
248
+
249
class Qwen3_VisionTransformer(nn.Module):
    """Vision tower used by the Qwen3-VL model in this file.

    ``forward`` embeds the patches, adds bilinearly interpolated learned
    position embeddings, runs ``vision_config.depth`` transformer blocks with
    rotary position embeddings, and applies the final patch merger. The
    outputs of the blocks whose index is listed in
    ``deepstack_visual_indexes`` are additionally sent through their own
    mergers and concatenated to the final features along dim 1.
    """

    def __init__(
        self,
        vision_config: Qwen3VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the sub-modules from ``vision_config``.

        Args:
            vision_config: Source of sizes (hidden size, heads, depth, merge
                size, deepstack indexes, ...).
            norm_eps: Epsilon for every ``nn.LayerNorm`` created here.
            quant_config: Forwarded to the blocks and the mergers.
            prefix: Parameter-name prefix forwarded through ``add_prefix``.
        """
        super().__init__()
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.num_position_embeddings = vision_config.num_position_embeddings
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.spatial_merge_unit = self.spatial_merge_size**2
        self.temporal_patch_size = vision_config.temporal_patch_size
        self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
        self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
        self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Qwen3_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    intermediate_dim=vision_config.intermediate_size,
                    hidden_act=vision_config.hidden_act,
                    norm_layer=norm_layer,
                    attn_implementation="flash_attention_3",
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
                )
                for layer_idx in range(vision_config.depth)
            ]
        )
        self.merger = Qwen3_VisionPatchMerger(
            dim=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            prefix=add_prefix("merger", prefix),
        )

        # One extra merger per captured (deepstack) block output.
        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3_VisionPatchMerger(
                    dim=vision_config.out_hidden_size,
                    context_dim=self.hidden_size,
                    spatial_merge_size=self.spatial_merge_size,
                    use_postshuffle_norm=True,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=add_prefix(f"deepstack_merger_list.{layer_idx}", prefix),
                )
                for layer_idx in range(len(self.deepstack_visual_indexes))
            ]
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw):
        """Return rotary embeddings for every patch described by ``grid_thw``.

        Args:
            grid_thw: 2-D integer tensor whose rows are ``(t, h, w)``.

        Returns:
            Tensor with one row per patch: the rotary table rows selected by
            the patch's (row, column) ids, flattened from dim 1 onwards.
            Patches are ordered merge-window by merge-window.
        """
        merge = self.spatial_merge_size

        def in_merge_order(ids, h, w):
            # Regroup an (h, w) id grid so each merge x merge window is contiguous.
            return (
                ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
            )

        per_item = []
        for t, h, w in grid_thw:
            row_ids = in_merge_order(torch.arange(h).unsqueeze(1).expand(-1, w), h, w)
            col_ids = in_merge_order(torch.arange(w).unsqueeze(0).expand(h, -1), h, w)
            # The same spatial ids are reused for each of the t temporal slices.
            per_item.append(torch.stack([row_ids, col_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(per_item, dim=0)

        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        return rotary_pos_emb_full[pos_ids].flatten(1)

    def fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly interpolate the learned position table to each grid.

        The ``num_position_embeddings`` table is treated as a square grid.
        For every ``(t, h, w)`` row, ``h`` x ``w`` sample points are spread
        evenly over that square and each sample is the weighted sum of its
        four surrounding table entries.

        Args:
            grid_thw: 2-D integer tensor whose rows are ``(t, h, w)``.

        Returns:
            Tensor with one row per patch, ordered merge-window by
            merge-window (the same ordering ``rot_pos_emb`` uses).
        """
        num_grid_per_side = int(self.num_position_embeddings**0.5)

        # Slot k holds the table indices / weights of the k-th corner:
        # 0 = (floor, floor), 1 = (floor, ceil), 2 = (ceil, floor), 3 = (ceil, ceil).
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        # TODO: use torch instead of np
        for t, h, w in grid_thw:
            h_idxs = np.linspace(0, num_grid_per_side - 1, h)
            w_idxs = np.linspace(0, num_grid_per_side - 1, w)

            h_floor = h_idxs.astype(int)
            w_floor = w_idxs.astype(int)
            h_ceil = (h_floor + 1).clip(max=num_grid_per_side - 1)
            w_ceil = (w_floor + 1).clip(max=num_grid_per_side - 1)

            dh = h_idxs - h_floor
            dw = w_idxs - w_floor

            corners = (
                (h_floor, w_floor, 1 - dh, 1 - dw),
                (h_floor, w_ceil, 1 - dh, dw),
                (h_ceil, w_floor, dh, 1 - dw),
                (h_ceil, w_ceil, dh, dw),
            )
            for slot, (rows, cols, row_w, col_w) in enumerate(corners):
                # Outer sum / product over the (h, w) grid, repeated t times.
                idx_list[slot].extend(
                    ((rows * num_grid_per_side)[None].T + cols[None])
                    .flatten()
                    .tolist()
                    * t
                )
                weight_list[slot].extend(
                    (row_w[None].T * col_w[None]).flatten().tolist() * t
                )

        device = self.pos_embed.weight.device
        dtype = self.pos_embed.weight.dtype

        # Accumulate corner 0..3 in order (same summation order as p0+p1+p2+p3).
        patch_pos_embeds = None
        for corner_idx, corner_weight in zip(idx_list, weight_list):
            term = (
                self.pos_embed(
                    torch.tensor(corner_idx, dtype=torch.long, device=device)
                )
                * torch.tensor(corner_weight, dtype=dtype, device=device)[:, None]
            )
            patch_pos_embeds = term if patch_pos_embeds is None else patch_pos_embeds + term

        patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw])
        m_size = self.spatial_merge_size
        reordered = []
        for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
            # Move each m_size x m_size window's patches next to each other.
            reordered.append(
                pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
        return torch.cat(reordered)

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode patches into merged features plus deepstack features.

        Args:
            x: Patch input; it must be 2-D after ``self.patch_embed``.
            grid_thw: 2-D integer tensor whose rows are ``(t, h, w)``.

        Returns:
            The final merger output concatenated along dim 1 with one merged
            feature tensor per captured deepstack block, in capture order.
        """
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        x = x + self.fast_pos_embed_interpolate(grid_thw)

        seq_len, _ = x.size()
        rotary_pos_emb = self.rot_pos_emb(grid_thw).to(x.device)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Cumulative sequence boundaries: [0, n_1, n_1 + n_2, ...] with
        # n_i = t_i * h_i * w_i. Exactly one leading zero is prepended; the
        # previous code prepended two, which described an extra empty sequence.
        seq_ends = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0)
        cu_seqlens = F.pad(seq_ends, (1, 0), "constant", 0)

        x = x.unsqueeze(1)

        deepstack_feature_lists = []
        num_deepstack_captured = 0
        for layer_num, blk in enumerate(self.blocks):
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
            if layer_num in self.deepstack_visual_indexes:
                # Mergers are consumed in the order the layers are reached.
                deepstack_merger = self.deepstack_merger_list[num_deepstack_captured]
                deepstack_feature_lists.append(deepstack_merger(x))
                num_deepstack_captured += 1
        x = self.merger(x)
        hidden_states = torch.cat(
            [x] + deepstack_feature_lists, dim=1
        )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``attn.q.`` / ``attn.k.`` / ``attn.v.`` tensors are routed
        into the fused ``attn.qkv.`` parameter with the matching shard id;
        every other tensor is loaded under its own name.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names (after renaming) that were loaded.

        Raises:
            KeyError: If a (renamed) name is not a parameter of this module.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load under the original name.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
508
+
509
+
510
+ cached_get_processor = lru_cache(get_processor)  # memoizing wrapper around get_processor (presumably functools.lru_cache with its default size — confirm against the imports)
511
+
512
+
513
class Qwen3LLMModel(Qwen3Model):
    """``Qwen3Model`` variant that can add deepstack visual embeddings.

    After decoder layer ``i`` runs (for ``i`` below the number of deepstack
    indexes), the ``i``-th ``hidden_size``-wide column slice of
    ``input_deepstack_embeds`` is added to the hidden states.
    """

    def __init__(
        self,
        *,
        config: Qwen3VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the base model and record which layers receive deepstack input."""
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
        num_deepstack = len(config.vision_config.deepstack_visual_indexes)
        if not self.pp_group.is_first_rank:
            # Every deepstack-consuming layer must live on the first pipeline rank.
            assert (
                self.start_layer >= num_deepstack
            ), "start_layer should be greater than or equal to len(deepstack_visual_indexes)"

        self.hidden_size = config.hidden_size
        self.deepstack_embed_to_decoder_layer = range(num_deepstack)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_deepstack_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this rank's decoder layers.

        Args:
            input_ids: Token ids; embedded only on the first pipeline rank
                and only when ``input_embeds`` is ``None``.
            positions: Passed through to every layer.
            forward_batch: Passed through to every layer.
            input_embeds: Optional precomputed input embeddings.
            pp_proxy_tensors: Required on non-first ranks; supplies
                ``hidden_states`` and ``residual``.
            input_deepstack_embeds: Optional tensor whose dim-1 slices of
                width ``hidden_size`` are added after the matching layers.

        Returns:
            ``PPProxyTensors`` on non-last ranks. On the last rank, the
            normed hidden states, or a ``(hidden_states, aux_hidden_states)``
            tuple when any layer was captured.
        """
        if self.pp_group.is_first_rank:
            hidden_states = (
                self.embed_tokens(input_ids) if input_embeds is None else input_embeds
            )
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        aux_hidden_states = []
        local_layers = self.layers[self.start_layer : self.end_layer]
        for layer_idx, layer in enumerate(local_layers, start=self.start_layer):
            if layer_idx in self.layers_to_capture:
                if residual is not None:
                    captured = hidden_states + residual
                else:
                    captured = hidden_states
                aux_hidden_states.append(captured)

            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )

            # process deepstack
            if input_deepstack_embeds is None:
                continue
            if layer_idx not in self.deepstack_embed_to_decoder_layer:
                continue
            sep = self.hidden_size * layer_idx
            deepstack_slice = input_deepstack_embeds[:, sep : sep + self.hidden_size]
            hidden_states = hidden_states + deepstack_slice

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )

        # The final norm is skipped for empty batches.
        if hidden_states.shape[0] != 0:
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)

        if not aux_hidden_states:
            return hidden_states

        return hidden_states, aux_hidden_states
600
+
601
+
602
class Qwen3VLForConditionalGeneration(nn.Module):
    """Qwen3-VL: a vision tower (``visual``) feeding a Qwen3 language model.

    Image and video features produced by ``visual`` carry the regular
    embeddings plus the deepstack embeddings side by side on the last
    dimension; ``separate_deepstack_embeds`` splits them again.
    """

    def __init__(
        self,
        config: Qwen3VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the vision tower, the language model, and the output heads.

        Args:
            config: Model configuration (including ``vision_config``).
            quant_config: Forwarded to the vision tower, the language model
                and the LM head.
            prefix: Parameter-name prefix forwarded through ``add_prefix``.
        """
        super().__init__()

        self.config = config
        self.visual = Qwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )

        self.model = Qwen3LLMModel(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("model", prefix),
        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        # A missing or None rope_scaling simply means mrope is disabled.
        rope_scaling = getattr(config, "rope_scaling", None) or {}
        self.is_mrope_enabled = "mrope_section" in rope_scaling

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

        # deepstack
        # e.g. visual indexes [8, 16, 24]: the features captured at vision layers
        # 8, 16, 24 are merged into the hidden states of decoder layers 0, 1, 2.
        self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
        self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)

    @property
    def use_deepstack(self) -> bool:
        """Whether ``deepstack_visual_indexes`` has been set on this instance."""
        return hasattr(self, "deepstack_visual_indexes")

    def separate_deepstack_embeds(self, embedding):
        """Split a combined embedding into regular and deepstack parts.

        Args:
            embedding: 2-D tensor whose last dimension is divisible by
                ``1 + num_deepstack_embeddings``.

        Returns:
            ``(input_embeds, input_deepstack_embeds)``: the first
            ``config.hidden_size`` columns and the remaining columns.
        """
        assert (
            embedding.shape[-1] % (1 + self.num_deepstack_embeddings) == 0
        ), f"hidden_state of {embedding.shape} should be divisible by ({1 + self.num_deepstack_embeddings})"

        separate_index = self.config.hidden_size
        input_embeds = embedding[:, :separate_index]
        input_deepstack_embeds = embedding[:, separate_index:]
        return input_embeds, input_deepstack_embeds

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad ``input_ids`` using the multimodal-token padding pattern."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def _encode_visual_items(
        self, items: List[MultimodalDataItem], grid_attr: str
    ) -> torch.Tensor:
        """Concatenate item features and grids, then run the vision tower.

        ``grid_attr`` names the per-item attribute that holds the
        ``(t, h, w)`` grid tensor.
        """
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        grid_thw = torch.concat([getattr(item, grid_attr) for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert grid_thw.dim() == 2, grid_thw.dim()
        return self.visual(pixel_values, grid_thw=grid_thw)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Encode image items (grids read from ``item.image_grid_thw``)."""
        return self._encode_visual_items(items, "image_grid_thw")

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Encode video items (grids read from ``item.video_grid_thw``)."""
        return self._encode_visual_items(items, "video_grid_thw")

    def get_input_embeddings(self):
        """Return the language model's token-embedding module."""
        return self.model.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for Qwen3-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled, this argument is replaced by
                ``forward_batch.mrope_positions``, which must have shape
                `(3, seq_len)` for non-decode batches that contain image
                inputs.
            forward_batch: Batch metadata for this pass.
            get_embedding: If True, return the pooled embedding instead of
                logits.

        Returns:
            The logits-processor output, or the pooler output when
            ``get_embedding`` is True.
        """
        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions

        is_multimodal_prefill = (
            not forward_batch.forward_mode.is_decode()
            and forward_batch.contains_image_inputs()
        )
        if is_multimodal_prefill and self.is_mrope_enabled:
            assert positions.ndim == 2 and positions.size(0) == 3, (
                "multimodal section rotary embedding requires "
                f"(3, seq_len) positions, but got {positions.size()}"
            )

        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            multimodal_model=self,
            positions=positions,
            use_deepstack=self.use_deepstack,
        )

        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        ``rotary_emb.inv_freq`` entries are skipped. For non-visual weights,
        separate q/k/v and gate/up projections are routed into the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters. ``.bias`` entries with no
        matching parameter are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Raises:
            KeyError: If a non-bias weight has no matching parameter. The
                missing name is logged before the error propagates.
        """
        import logging

        log = logging.getLogger(__name__)

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "language_model" in name:
                name = name.replace(r"model.language_model.", r"model.")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Vision weights are never fused here; they take the else-branch.
                if "visual" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
                    name = name.replace(r"model.visual.", r"visual.")

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                try:
                    param = params_dict[name]
                except KeyError:
                    # Report through logging instead of dumping to stdout.
                    log.error("No parameter named %r while loading weights", name)
                    log.debug("Known parameters: %s", sorted(params_dict))
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
785
+
786
+
787
+ EntryClass = Qwen3VLForConditionalGeneration  # module-level alias for the model class; presumably looked up by the model loader/registry — confirm