sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,791 @@
+# Copyright 2025 Qwen Team
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
+import logging
+from functools import lru_cache, partial
+from typing import Callable, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from transformers.activations import ACT2FN
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionRotaryEmbedding,
+)
+
+from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen3 import Qwen3Model
+from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
+
+logger = logging.getLogger(__name__)
+
+
+# === Vision Encoder === #
+
+
+class Qwen3_VisionMLP(nn.Module):
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        bias: bool = True,
+        hidden_act="silu",
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.linear_fc1 = ColumnParallelLinear(
+            in_features,
+            hidden_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("linear_fc1", prefix),
+        )
+        self.linear_fc2 = RowParallelLinear(
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("linear_fc2", prefix),
+        )
+        self.act = ACT2FN[hidden_act]
+
+    def forward(self, x: torch.Tensor):
+        x_fc1, _ = self.linear_fc1(x)
+        mlp_output, _ = self.linear_fc2(self.act(x_fc1))
+        return mlp_output
+
+
+class Qwen3VLVisionPatchEmbed(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.in_channels = config.in_channels
+        self.embed_dim = config.hidden_size
+
+        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+        self.proj = nn.Conv3d(
+            self.in_channels,
+            self.embed_dim,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+            bias=True,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        target_dtype = self.proj.weight.dtype
+        hidden_states = hidden_states.view(
+            -1,
+            self.in_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        )
+        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(
+            -1, self.embed_dim
+        )
+        return hidden_states
+
+
+class Qwen3_VisionBlock(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        intermediate_dim: int,
+        hidden_act="silu",
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
+        attn_implementation: Optional[str] = "sdpa",
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.norm1 = norm_layer(dim)
+        self.norm2 = norm_layer(dim)
+
+        if attn_implementation == "sdpa":
+            softmax_in_single_precision = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_2":
+            softmax_in_single_precision = False
+            qkv_backend = "triton_attn"
+            flatten_batch = True
+        elif attn_implementation == "eager":
+            softmax_in_single_precision = True
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
+            flatten_batch = True
+
+        self.attn = VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            use_qkv_parallel=True,
+            rotary_embed="normal",
+            proj_bias=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=flatten_batch,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+        self.mlp = Qwen3_VisionMLP(
+            dim,
+            intermediate_dim,
+            hidden_act=hidden_act,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states = self.norm1(x)
+        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+        attn = self.attn(
+            hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
+        )
+        attn = rearrange(attn, "b s ... -> s b ...")
+        x += attn
+        norm2 = self.norm2(x)
+        mlp = self.mlp(norm2)
+        x += mlp
+        return x
+
+
+class Qwen3VLMoeVisionPatchMerger(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        context_dim: int,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
+        spatial_merge_size: int = 2,
+        use_postshuffle_norm: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+
+        self.use_postshuffle_norm = use_postshuffle_norm
+
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.norm = norm_layer(
+            self.hidden_size if use_postshuffle_norm else context_dim
+        )
+        self.linear_fc1 = ColumnParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=add_prefix("linear_fc1", prefix),
+        )
+        self.act_fn = nn.GELU()
+        self.linear_fc2 = RowParallelLinear(
+            self.hidden_size,
+            dim,
+            bias=True,
+            quant_config=quant_config,
+            prefix=add_prefix("linear_fc2", prefix),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.use_postshuffle_norm:
+            x = self.norm(x.view(-1, self.hidden_size))
+        else:
+            x = self.norm(x).view(-1, self.hidden_size)
+
+        x_parallel, _ = self.linear_fc1(x)
+        x_parallel = self.act_fn(x_parallel)
+        out, _ = self.linear_fc2(x_parallel)
+        return out
+
+
+class Qwen3VLMoeVisionModel(nn.Module):
+
+    def __init__(
+        self,
+        vision_config: Qwen3VLVisionConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = vision_config.hidden_size
+        self.num_heads = vision_config.num_heads
+        self.num_position_embeddings = vision_config.num_position_embeddings
+        self.patch_size = vision_config.patch_size
+        self.spatial_merge_size = vision_config.spatial_merge_size
+        self.spatial_merge_unit = self.spatial_merge_size**2
+        self.temporal_patch_size = vision_config.temporal_patch_size
+        # layer indexes of which layer's output should be deep-stacked
+        self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
+        self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
+        self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
+        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = nn.ModuleList(
+            [
+                Qwen3_VisionBlock(
+                    dim=self.hidden_size,
+                    num_heads=self.num_heads,
+                    intermediate_dim=vision_config.intermediate_size,
+                    hidden_act=vision_config.hidden_act,
+                    norm_layer=norm_layer,
+                    attn_implementation="flash_attention_3",
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
+                )
+                for layer_idx in range(vision_config.depth)
+            ]
+        )
+        self.merger = Qwen3VLMoeVisionPatchMerger(
+            dim=vision_config.out_hidden_size,
+            context_dim=self.hidden_size,
+            norm_layer=norm_layer,
+            spatial_merge_size=self.spatial_merge_size,
+            quant_config=quant_config,
+            prefix=add_prefix("merger", prefix),
+        )
+
+        self.deepstack_merger_list = nn.ModuleList(
+            [
+                Qwen3VLMoeVisionPatchMerger(
+                    dim=vision_config.out_hidden_size,
+                    context_dim=self.hidden_size,
+                    spatial_merge_size=self.spatial_merge_size,
+                    use_postshuffle_norm=True,
+                    norm_layer=norm_layer,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"deepstack_merger_list.{layer_idx}", prefix),
+                )
+                for layer_idx in range(len(self.deepstack_visual_indexes))
+            ]
+        )
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.patch_embed.proj.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.patch_embed.proj.weight.device
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def fast_pos_embed_interpolate(self, grid_thw):
+        num_grid_per_side = int(self.num_position_embeddings**0.5)
+
+        idx_list = [[] for _ in range(4)]
+        weight_list = [[] for _ in range(4)]
+
+        # TODO: use torch instead of np
+        for t, h, w in grid_thw:
+            h_idxs = np.linspace(0, num_grid_per_side - 1, h)
+            w_idxs = np.linspace(0, num_grid_per_side - 1, w)
+
+            h_idxs_floor = h_idxs.astype(int)
+            w_idxs_floor = w_idxs.astype(int)
+            h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1)
+            w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1)
+
+            dh = h_idxs - h_idxs_floor
+            dw = w_idxs - w_idxs_floor
+
+            idx_list[0].extend(
+                ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_floor[None])
+                .flatten()
+                .tolist()
+                * t
+            )
+            idx_list[1].extend(
+                ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_ceil[None])
+                .flatten()
+                .tolist()
+                * t
+            )
+            idx_list[2].extend(
+                ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_floor[None])
+                .flatten()
+                .tolist()
+                * t
+            )
+            idx_list[3].extend(
+                ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_ceil[None])
+                .flatten()
+                .tolist()
+                * t
+            )
+
+            weight_list[0].extend(
+                ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t
+            )
+            weight_list[1].extend(((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
+            weight_list[2].extend((dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
+            weight_list[3].extend((dh[None].T * dw[None]).flatten().tolist() * t)
+
+        device = self.pos_embed.weight.device
+        dtype = self.pos_embed.weight.dtype
+
+        p0 = (
+            self.pos_embed(torch.tensor(idx_list[0], dtype=torch.long, device=device))
+            * torch.tensor(weight_list[0], dtype=dtype, device=device)[:, None]
+        )
+        p1 = (
+            self.pos_embed(torch.tensor(idx_list[1], dtype=torch.long, device=device))
+            * torch.tensor(weight_list[1], dtype=dtype, device=device)[:, None]
+        )
+        p2 = (
+            self.pos_embed(torch.tensor(idx_list[2], dtype=torch.long, device=device))
+            * torch.tensor(weight_list[2], dtype=dtype, device=device)[:, None]
+        )
+        p3 = (
+            self.pos_embed(torch.tensor(idx_list[3], dtype=torch.long, device=device))
+            * torch.tensor(weight_list[3], dtype=dtype, device=device)[:, None]
+        )
+
+        patch_pos_embeds = p0 + p1 + p2 + p3
+        patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw])
+        patch_pos_embeds_permute = []
+        m_size = self.spatial_merge_size
+        for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
+            pos_embed = (
+                pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1)
+                .permute(0, 1, 3, 2, 4, 5)
+                .flatten(0, 4)
+            )
+            patch_pos_embeds_permute.append(pos_embed)
+        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+        return patch_pos_embeds
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        x += pos_embeds
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        seq_len, _ = x.size()
+        rotary_pos_emb = rotary_pos_emb.to(x.device)
+
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
+
+        # compute cu_seqlens
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0)
+        cu_seqlens = torch.cat(
+            [
+                torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
+                cu_seqlens.to(torch.int32),
+            ]
+        )
+
+        x = x.unsqueeze(1)
+
+        deepstack_feature_lists = []
+        num_deepstack_captured = 0
+        for layer_num, blk in enumerate(self.blocks):
+            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+            if layer_num in self.deepstack_visual_indexes:
+                deepstack_feature = self.deepstack_merger_list[num_deepstack_captured](
+                    x
+                )
+                deepstack_feature_lists.append(deepstack_feature)
+                num_deepstack_captured += 1
+        x = self.merger(x)
+        hidden_states = torch.cat(
+            [x] + deepstack_feature_lists, dim=1
+        )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("attn.qkv.", "attn.q.", "q"),
+            ("attn.qkv.", "attn.k.", "k"),
+            ("attn.qkv.", "attn.v.", "v"),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Qwen3LLMModel(Qwen3Model):
+
+    def __init__(
+        self,
+        *,
+        config: Qwen3VLConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+        if not self.pp_group.is_first_rank:
+            assert self.start_layer >= len(
+                config.vision_config.deepstack_visual_indexes
+            ), "start_layer should be greater than or equal to len(deepstack_visual_indexes)"
+
+        self.hidden_size = config.hidden_size
+        self.deepstack_embed_to_decoder_layer = range(
+            len(config.vision_config.deepstack_visual_indexes)
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        input_deepstack_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for layer_idx, layer in enumerate(
+            self.layers[self.start_layer : self.end_layer]
+        ):
+            layer_idx = layer_idx + self.start_layer
+            if layer_idx in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
+
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+            # process deepstack
+            if (
+                input_deepstack_embeds is not None
+                and layer_idx in self.deepstack_embed_to_decoder_layer
+            ):
+                sep = self.hidden_size * layer_idx
+                hidden_states += input_deepstack_embeds[:, sep : sep + self.hidden_size]
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+
+class Qwen3VLForConditionalGeneration(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3VLConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        language_model_cls=Qwen3LLMModel,
+    ) -> None:
+        super().__init__()
+
+        self.visual = Qwen3VLMoeVisionModel(
+            config.vision_config,
+            # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            prefix=add_prefix("visual", prefix),
+        )
+
+        # TODO: make it more elegant
+        if language_model_cls is Qwen3LLMModel:
+            self.config: Qwen3VLConfig = config  # for qwen3-vl
+        else:
+            self.config = config.text_config  # for qwen3-omni
+
+        self.model = language_model_cls(
+            config=self.config,
+            quant_config=quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+        self.logits_processor = LogitsProcessor(self.config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # like {8: 0, 16: 1, 24: 2}: deepstack features captured at vision layers
+        # 8, 16, 24 are merged into the output hidden_states of decoder layers 0, 1, 2
+
+        # deepstack
+        self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
+        self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
+        self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}
+
+    def separate_deepstack_embeds(self, embedding):
+        assert (
+            embedding.shape[-1] % (1 + self.num_deepstack_embeddings) == 0
+        ), f"hidden_state of {embedding.shape} should be divisible by ({1 + self.num_deepstack_embeddings})"
+
+        separate_index = self.config.hidden_size
+        input_embeds = embedding[:, :separate_index]
+        input_deepstack_embeds = embedding[:, separate_index:]
+        return input_embeds, input_deepstack_embeds
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+        return image_embeds
+
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
+        return video_embeds
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        """Run forward pass for Qwen3-VL.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+                (Use input_metadata.mrope_positions to replace it)
+        """
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
+        if not (
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
+        ):
+            if self.is_mrope_enabled:
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.model,
+            multimodal_model=self,
+            positions=positions,
+            use_deepstack=self.use_deepstack,
+        )
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "language_model" in name:
+                name = name.replace(r"model.language_model.", r"model.")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "visual" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "visual" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                    name = name.replace(r"model.visual.", r"visual.")
+
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen3VLForConditionalGeneration