sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -15,12 +15,11 @@
15
15
  """Inference-only Qwen3-VL model compatible with HuggingFace weights."""
16
16
  import logging
17
17
  from functools import lru_cache, partial
18
- from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
18
+ from typing import Callable, Iterable, List, Optional, Tuple, Union
19
19
 
20
20
  import numpy as np
21
21
  import torch
22
22
  import torch.nn as nn
23
- import torch.nn.functional as F
24
23
  from einops import rearrange
25
24
  from transformers.activations import ACT2FN
26
25
  from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
@@ -38,16 +37,20 @@ from sglang.srt.managers.mm_utils import (
38
37
  MultiModalityDataPaddingPatternMultimodalTokens,
39
38
  general_mm_embed_routine,
40
39
  )
41
- from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
40
+ from sglang.srt.managers.schedule_batch import (
41
+ Modality,
42
+ MultimodalDataItem,
43
+ MultimodalInputs,
44
+ )
42
45
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
43
46
  from sglang.srt.model_loader.weight_utils import default_weight_loader
44
- from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
45
47
  from sglang.srt.models.qwen3 import Qwen3Model
46
48
  from sglang.srt.utils import add_prefix
47
49
  from sglang.srt.utils.hf_transformers_utils import get_processor
48
50
 
49
51
  logger = logging.getLogger(__name__)
50
52
 
53
+
51
54
  # === Vision Encoder === #
52
55
 
53
56
 
@@ -189,14 +192,14 @@ class Qwen3_VisionBlock(nn.Module):
189
192
  position_embeddings=position_embeddings,
190
193
  )
191
194
  attn = rearrange(attn, "b s ... -> s b ...")
192
- x = x + attn
195
+ x += attn
193
196
  norm2 = self.norm2(x)
194
197
  mlp = self.mlp(norm2)
195
- x = x + mlp
198
+ x += mlp
196
199
  return x
197
200
 
198
201
 
199
- class Qwen3_VisionPatchMerger(nn.Module):
202
+ class Qwen3VLMoeVisionPatchMerger(nn.Module):
200
203
 
201
204
  def __init__(
202
205
  self,
@@ -246,7 +249,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
246
249
  return out
247
250
 
248
251
 
249
- class Qwen3_VisionTransformer(nn.Module):
252
+ class Qwen3VLMoeVisionModel(nn.Module):
250
253
 
251
254
  def __init__(
252
255
  self,
@@ -263,10 +266,10 @@ class Qwen3_VisionTransformer(nn.Module):
263
266
  self.spatial_merge_size = vision_config.spatial_merge_size
264
267
  self.spatial_merge_unit = self.spatial_merge_size**2
265
268
  self.temporal_patch_size = vision_config.temporal_patch_size
269
+ # layer indexes of which layer's output should be deep-stacked
266
270
  self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
267
271
  self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
268
272
  self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
269
-
270
273
  norm_layer = partial(nn.LayerNorm, eps=norm_eps)
271
274
  head_dim = self.hidden_size // self.num_heads
272
275
  self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
@@ -286,7 +289,7 @@ class Qwen3_VisionTransformer(nn.Module):
286
289
  for layer_idx in range(vision_config.depth)
287
290
  ]
288
291
  )
289
- self.merger = Qwen3_VisionPatchMerger(
292
+ self.merger = Qwen3VLMoeVisionPatchMerger(
290
293
  dim=vision_config.out_hidden_size,
291
294
  context_dim=self.hidden_size,
292
295
  norm_layer=norm_layer,
@@ -297,7 +300,7 @@ class Qwen3_VisionTransformer(nn.Module):
297
300
 
298
301
  self.deepstack_merger_list = nn.ModuleList(
299
302
  [
300
- Qwen3_VisionPatchMerger(
303
+ Qwen3VLMoeVisionPatchMerger(
301
304
  dim=vision_config.out_hidden_size,
302
305
  context_dim=self.hidden_size,
303
306
  spatial_merge_size=self.spatial_merge_size,
@@ -441,7 +444,7 @@ class Qwen3_VisionTransformer(nn.Module):
441
444
  x = self.patch_embed(x)
442
445
 
443
446
  pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
444
- x = x + pos_embeds
447
+ x += pos_embeds
445
448
  rotary_pos_emb = self.rot_pos_emb(grid_thw)
446
449
 
447
450
  seq_len, _ = x.size()
@@ -452,15 +455,16 @@ class Qwen3_VisionTransformer(nn.Module):
452
455
  position_embeddings = (emb.cos(), emb.sin())
453
456
 
454
457
  # compute cu_seqlens
458
+ cu_seqlens = torch.repeat_interleave(
459
+ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
460
+ ).cumsum(dim=0)
455
461
  cu_seqlens = torch.cat(
456
462
  [
457
- torch.tensor([0], device=grid_thw.device),
458
- (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
463
+ torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
464
+ cu_seqlens.to(torch.int32),
459
465
  ]
460
466
  )
461
- cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
462
467
 
463
- # max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
464
468
  x = x.unsqueeze(1)
465
469
 
466
470
  deepstack_feature_lists = []
@@ -574,10 +578,7 @@ class Qwen3LLMModel(Qwen3Model):
574
578
  and layer_idx in self.deepstack_embed_to_decoder_layer
575
579
  ):
576
580
  sep = self.hidden_size * layer_idx
577
- hidden_states = (
578
- hidden_states
579
- + input_deepstack_embeds[:, sep : sep + self.hidden_size]
580
- )
581
+ hidden_states += input_deepstack_embeds[:, sep : sep + self.hidden_size]
581
582
 
582
583
  if not self.pp_group.is_last_rank:
583
584
  return PPProxyTensors(
@@ -605,37 +606,43 @@ class Qwen3VLForConditionalGeneration(nn.Module):
605
606
  config: Qwen3VLConfig,
606
607
  quant_config: Optional[QuantizationConfig] = None,
607
608
  prefix: str = "",
609
+ language_model_cls=Qwen3LLMModel,
608
610
  ) -> None:
609
611
  super().__init__()
610
612
 
611
- self.config = config
612
- self.visual = Qwen3_VisionTransformer(
613
+ self.visual = Qwen3VLMoeVisionModel(
613
614
  config.vision_config,
614
- norm_eps=getattr(config, "rms_norm_eps", 1e-6),
615
615
  # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
616
616
  # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
617
617
  quant_config=quant_config,
618
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
618
619
  prefix=add_prefix("visual", prefix),
619
620
  )
620
621
 
621
- self.model = Qwen3LLMModel(
622
- config=config,
622
+ # TODO: make it more elegant
623
+ if language_model_cls is Qwen3LLMModel:
624
+ self.config: Qwen3VLConfig = config # for qwen3-vl
625
+ else:
626
+ self.config = config.text_config # for qwen3-omni
627
+
628
+ self.model = language_model_cls(
629
+ config=self.config,
623
630
  quant_config=quant_config,
624
631
  prefix=add_prefix("model", prefix),
625
632
  )
626
633
 
627
- if config.tie_word_embeddings:
634
+ if self.config.tie_word_embeddings:
628
635
  self.lm_head = self.model.embed_tokens
629
636
  else:
630
637
  self.lm_head = ParallelLMHead(
631
- config.vocab_size,
632
- config.hidden_size,
638
+ self.config.vocab_size,
639
+ self.config.hidden_size,
633
640
  quant_config=quant_config,
634
641
  prefix=add_prefix("lm_head", prefix),
635
642
  )
636
643
  self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
637
644
 
638
- self.logits_processor = LogitsProcessor(config)
645
+ self.logits_processor = LogitsProcessor(self.config)
639
646
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
640
647
  # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
641
648
  # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
@@ -643,10 +650,7 @@ class Qwen3VLForConditionalGeneration(nn.Module):
643
650
  # deepstack
644
651
  self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
645
652
  self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
646
-
647
- @property
648
- def use_deepstack(self) -> bool:
649
- return hasattr(self, "deepstack_visual_indexes")
653
+ self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}
650
654
 
651
655
  def separate_deepstack_embeds(self, embedding):
652
656
  assert (
@@ -14,49 +14,23 @@
14
14
  # ==============================================================================
15
15
  """Inference-only Qwen3-VL model compatible with HuggingFace weights."""
16
16
  import logging
17
- from functools import lru_cache, partial
18
- from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
17
+ from functools import lru_cache
18
+ from typing import Iterable, Optional, Tuple, Union
19
19
 
20
- import numpy as np
21
20
  import torch
22
21
  import torch.nn as nn
23
- import torch.nn.functional as F
24
- from einops import rearrange
25
- from transformers import BatchFeature
26
- from transformers.activations import ACT2FN
27
- from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
28
- Qwen2_5_VisionRotaryEmbedding,
29
- )
30
22
 
31
- from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
23
+ from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
32
24
  from sglang.srt.distributed import (
33
25
  get_moe_expert_parallel_world_size,
34
- get_pp_group,
35
26
  get_tensor_model_parallel_rank,
36
27
  )
37
- from sglang.srt.layers.logits_processor import LogitsProcessor
38
28
  from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
39
- from sglang.srt.layers.pooler import Pooler, PoolingType
40
29
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
41
- from sglang.srt.layers.utils import get_layer_id
42
- from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
43
- from sglang.srt.managers.mm_utils import (
44
- MultiModalityDataPaddingPatternMultimodalTokens,
45
- general_mm_embed_routine,
46
- )
47
- from sglang.srt.managers.schedule_batch import (
48
- MultimodalDataItem,
49
- MultimodalInputs,
50
- global_server_args_dict,
51
- )
52
30
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
53
31
  from sglang.srt.model_loader.weight_utils import default_weight_loader
54
- from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
55
- from sglang.srt.models.qwen3_vl import (
56
- Qwen3_VisionTransformer,
57
- Qwen3VLForConditionalGeneration,
58
- )
59
- from sglang.srt.utils import add_prefix
32
+ from sglang.srt.models.qwen3_moe import Qwen3MoeModel
33
+ from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
60
34
  from sglang.srt.utils.hf_transformers_utils import get_processor
61
35
 
62
36
  logger = logging.getLogger(__name__)
@@ -68,28 +42,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
68
42
  def __init__(
69
43
  self,
70
44
  *,
71
- config: Qwen3VLMoeConfig,
45
+ config: Qwen3VLMoeTextConfig,
72
46
  quant_config: Optional[QuantizationConfig] = None,
73
47
  prefix: str = "",
74
48
  ):
75
49
  super().__init__(config=config, quant_config=quant_config, prefix=prefix)
76
-
77
50
  self.hidden_size = config.hidden_size
78
51
 
79
52
  def get_input_embeddings(self) -> nn.Embedding:
80
53
  return self.embed_tokens
81
54
 
82
- def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
83
- # in qwen-vl, last dim is the same
84
- pixel_values = torch.cat([item.feature for item in items], dim=0).type(
85
- self.visual.dtype
86
- )
87
- image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
88
- assert pixel_values.dim() == 2, pixel_values.dim()
89
- assert image_grid_thw.dim() == 2, image_grid_thw.dim()
90
- image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
91
- return image_embeds
92
-
93
55
  def forward(
94
56
  self,
95
57
  input_ids: torch.Tensor,
@@ -114,7 +76,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
114
76
  for layer_idx, layer in enumerate(
115
77
  self.layers[self.start_layer : self.end_layer]
116
78
  ):
117
- layer_idx = layer_idx + self.start_layer
79
+ layer_idx += self.start_layer
118
80
  if layer_idx in self.layers_to_capture:
119
81
  aux_hidden_states.append(
120
82
  hidden_states + residual if residual is not None else hidden_states
@@ -128,11 +90,10 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
128
90
  )
129
91
 
130
92
  # process deepstack
131
- if input_deepstack_embeds is not None and layer_idx in range(3):
93
+ if input_deepstack_embeds is not None and layer_idx < 3:
132
94
  sep = self.hidden_size * layer_idx
133
- hidden_states = (
134
- hidden_states
135
- + input_deepstack_embeds[:, sep : sep + self.hidden_size]
95
+ hidden_states.add_(
96
+ input_deepstack_embeds[:, sep : sep + self.hidden_size]
136
97
  )
137
98
 
138
99
  if not self.pp_group.is_last_rank:
@@ -155,144 +116,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
155
116
  return hidden_states, aux_hidden_states
156
117
 
157
118
 
158
- class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
159
- def __init__(
160
- self,
161
- *,
162
- config: Qwen3VLMoeConfig,
163
- quant_config: Optional[QuantizationConfig] = None,
164
- prefix: str = "",
165
- ):
166
- super(Qwen3VLForConditionalGeneration, self).__init__()
167
- self.config = config
168
-
169
- self.visual = Qwen3_VisionTransformer(
170
- config.vision_config,
171
- norm_eps=getattr(config, "rms_norm_eps", 1e-6),
172
- # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
173
- # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
174
- quant_config=quant_config,
175
- prefix=add_prefix("visual", prefix),
176
- )
177
-
178
- self.model = Qwen3MoeLLMModel(
179
- config=config,
180
- quant_config=quant_config,
181
- prefix=add_prefix("model", prefix),
182
- )
183
-
184
- if config.tie_word_embeddings:
185
- self.lm_head = self.model.embed_tokens
186
- else:
187
- self.lm_head = ParallelLMHead(
188
- config.vocab_size,
189
- config.hidden_size,
190
- quant_config=quant_config,
191
- prefix=add_prefix("lm_head", prefix),
119
+ def load_fused_expert_weights(
120
+ name: str,
121
+ params_dict: dict,
122
+ loaded_weight: torch.Tensor,
123
+ shard_id: str,
124
+ num_experts: int,
125
+ ):
126
+ param = params_dict[name]
127
+ # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
128
+ weight_loader = param.weight_loader
129
+ ep_rank = get_tensor_model_parallel_rank()
130
+ ep_size = get_moe_expert_parallel_world_size()
131
+ if ep_size == 1:
132
+ for expert_id in range(num_experts):
133
+ curr_expert_weight = loaded_weight[expert_id]
134
+ weight_loader(
135
+ param,
136
+ curr_expert_weight,
137
+ name,
138
+ shard_id,
139
+ expert_id,
192
140
  )
193
- self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
194
-
195
- self.logits_processor = LogitsProcessor(config)
196
- self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
197
-
198
- # deepstack
199
- self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
200
- self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
201
-
202
- @property
203
- def use_deepstack(self) -> bool:
204
- return hasattr(self, "deepstack_visual_indexes")
205
-
206
- def forward(
207
- self,
208
- input_ids: torch.Tensor,
209
- positions: torch.Tensor,
210
- forward_batch: ForwardBatch,
211
- get_embedding: bool = False,
212
- ):
213
- """Run forward pass for Qwen3-VL.
214
-
215
- Args:
216
- input_ids: Flattened (concatenated) input_ids corresponding to a
217
- batch.
218
- positions: Flattened (concatenated) position ids corresponding to a
219
- batch.
220
- **NOTE**: If mrope is enabled (default setting for Qwen2-VL
221
- opensource models), the shape will be `(3, seq_len)`,
222
- otherwise it will be `(seq_len,).
223
- (Use input_metadata.mrope_positions to replace it)
224
- """
225
- if self.is_mrope_enabled:
226
- positions = forward_batch.mrope_positions
227
-
228
- if not (
229
- forward_batch.forward_mode.is_decode()
230
- or not forward_batch.contains_image_inputs()
231
- ):
232
- if self.is_mrope_enabled:
233
- assert positions.ndim == 2 and positions.size(0) == 3, (
234
- "multimodal section rotary embedding requires "
235
- f"(3, seq_len) positions, but got {positions.size()}"
236
- )
237
-
238
- hidden_states = general_mm_embed_routine(
239
- input_ids=input_ids,
240
- forward_batch=forward_batch,
241
- language_model=self.model,
242
- multimodal_model=self,
243
- positions=positions,
244
- use_deepstack=self.use_deepstack,
141
+ else:
142
+ experts_per_ep = num_experts // ep_size
143
+ start_expert = ep_rank * experts_per_ep
144
+ end_expert = (
145
+ (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
245
146
  )
246
147
 
247
- if not get_embedding:
248
- return self.logits_processor(
249
- input_ids, hidden_states, self.lm_head, forward_batch
148
+ for idx, expert_id in enumerate(range(start_expert, end_expert)):
149
+ curr_expert_weight = loaded_weight[expert_id]
150
+ weight_loader(
151
+ param,
152
+ curr_expert_weight,
153
+ name,
154
+ shard_id,
155
+ idx,
250
156
  )
251
- else:
252
- return self.pooler(hidden_states, forward_batch)
157
+ return True
253
158
 
254
- def load_fused_expert_weights(
159
+
160
+ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
161
+ def __init__(
255
162
  self,
256
- name: str,
257
- params_dict: dict,
258
- loaded_weight: torch.Tensor,
259
- shard_id: str,
260
- num_experts: int,
163
+ config: Qwen3VLMoeConfig,
164
+ quant_config: Optional[QuantizationConfig] = None,
165
+ prefix: str = "",
166
+ language_model_cls=Qwen3MoeLLMModel,
261
167
  ):
262
- param = params_dict[name]
263
- # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
264
- weight_loader = param.weight_loader
265
- ep_rank = get_tensor_model_parallel_rank()
266
- ep_size = get_moe_expert_parallel_world_size()
267
- if ep_size == 1:
268
- for expert_id in range(num_experts):
269
- curr_expert_weight = loaded_weight[expert_id]
270
- weight_loader(
271
- param,
272
- curr_expert_weight,
273
- name,
274
- shard_id,
275
- expert_id,
276
- )
277
- else:
278
- experts_per_ep = num_experts // ep_size
279
- start_expert = ep_rank * experts_per_ep
280
- end_expert = (
281
- (ep_rank + 1) * experts_per_ep
282
- if ep_rank != ep_size - 1
283
- else num_experts
284
- )
285
-
286
- for idx, expert_id in enumerate(range(start_expert, end_expert)):
287
- curr_expert_weight = loaded_weight[expert_id]
288
- weight_loader(
289
- param,
290
- curr_expert_weight,
291
- name,
292
- shard_id,
293
- idx,
294
- )
295
- return True
168
+ super().__init__(config, quant_config, prefix, language_model_cls)
296
169
 
297
170
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
298
171
  stacked_params_mapping = [
@@ -338,8 +211,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
338
211
  self._cached_params_dict = dict(self.named_parameters())
339
212
  params_dict = self._cached_params_dict
340
213
  for name, loaded_weight in weights:
341
- if "language_model" in name:
342
- name = name.replace(r"model.language_model.", r"model.")
214
+ name = name.replace(r"model.language_model.", r"model.")
343
215
 
344
216
  for param_name, weight_name, shard_id in stacked_params_mapping:
345
217
  if "experts.gate_up_proj" in name or "experts.down_proj" in name:
@@ -393,14 +265,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
393
265
  loaded_weight = loaded_weight.transpose(-1, -2) # no bias
394
266
  if "experts.gate_up_proj" in name:
395
267
  loaded_weight = loaded_weight.chunk(2, dim=-2)
396
- self.load_fused_expert_weights(
268
+ load_fused_expert_weights(
397
269
  name_mapped,
398
270
  params_dict,
399
271
  loaded_weight[0],
400
272
  "w1",
401
273
  num_experts,
402
274
  )
403
- self.load_fused_expert_weights(
275
+ load_fused_expert_weights(
404
276
  name_mapped,
405
277
  params_dict,
406
278
  loaded_weight[1],
@@ -408,7 +280,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
408
280
  num_experts,
409
281
  )
410
282
  else:
411
- self.load_fused_expert_weights(
283
+ load_fused_expert_weights(
412
284
  name_mapped,
413
285
  params_dict,
414
286
  loaded_weight,
@@ -1,6 +1,6 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
 
3
- import itertools
3
+ import os
4
4
  from typing import Iterable, Optional, Tuple
5
5
 
6
6
  import torch
@@ -8,10 +8,12 @@ from torch import nn
8
8
 
9
9
  from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
10
10
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
11
+ from sglang.srt.layers.sparse_pooler import SparsePooler
11
12
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
12
13
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
13
14
  from sglang.srt.model_loader.weight_utils import default_weight_loader
14
15
  from sglang.srt.models.bert import BertEncoder
16
+ from sglang.srt.utils.hf_transformers_utils import download_from_hf
15
17
 
16
18
  RobertaConfig = None
17
19
 
@@ -206,12 +208,29 @@ class XLMRobertaModel(nn.Module):
206
208
  config: RobertaConfig,
207
209
  quant_config: Optional[QuantizationConfig] = None,
208
210
  prefix: str = "",
211
+ sparse_head: Optional[str] = None,
212
+ model_path: Optional[str] = None,
209
213
  ):
210
214
  super().__init__()
211
215
  self.roberta = XLMRobertaBaseModel(
212
216
  config=config, quant_config=quant_config, prefix=prefix
213
217
  )
214
- self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
218
+ if sparse_head is not None:
219
+ self._is_sparse = True
220
+ self._model_path = model_path
221
+ self._sparse_head = sparse_head
222
+ self.pooler = SparsePooler(config=config)
223
+ # Zero out special tokens
224
+ self._special_tokens = [
225
+ config.bos_token_id,
226
+ config.eos_token_id,
227
+ config.pad_token_id,
228
+ # self.config.unk_token_id # not available in the XLMRobertaConfig
229
+ ]
230
+ self._special_tokens = [t for t in self._special_tokens if t is not None]
231
+ else:
232
+ self._is_sparse = False
233
+ self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
215
234
 
216
235
  def forward(
217
236
  self,
@@ -224,11 +243,44 @@ class XLMRobertaModel(nn.Module):
224
243
  hidden_states = self.roberta(
225
244
  input_ids, positions, forward_batch, input_embeds, get_embedding
226
245
  )
227
- return self.pooler(hidden_states, forward_batch)
246
+ embeddings = self.pooler(hidden_states, forward_batch)
247
+
248
+ if self._is_sparse:
249
+ for token_id in self._special_tokens:
250
+ embeddings.embeddings[:, token_id] = 0.0
251
+ embeddings.embeddings = embeddings.embeddings.to_sparse()
252
+
253
+ return embeddings
228
254
 
229
255
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
230
256
  self.roberta.load_weights(weights)
231
257
 
258
+ if self._is_sparse:
259
+ sparse_dict = XLMRobertaModel._load_sparse_linear(
260
+ self._model_path, self._sparse_head
261
+ )
262
+ self.pooler.load_weights(sparse_dict)
263
+
264
@staticmethod
def _load_sparse_linear(model_path_or_dir: str, sparse_head: str) -> dict:
    """Load the sparse-head weights from a local directory or the HF Hub.

    Args:
        model_path_or_dir: Path of a local model directory. When it is not
            an existing directory it is handed to ``download_from_hf`` as a
            remote model identifier instead.
        sparse_head: File name of the serialized sparse head, relative to
            the model directory.

    Returns:
        The state dict stored in the sparse-head file, suitable for
        ``nn.Linear.load_state_dict()``.

    Raises:
        FileNotFoundError: If ``model_path_or_dir`` is a local directory
            that does not contain ``sparse_head``.
    """
    if os.path.isdir(model_path_or_dir):
        path = os.path.join(model_path_or_dir, sparse_head)
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"'{sparse_head}' not found in {model_path_or_dir}"
            )
    else:
        # Not a local directory: fetch only the sparse-head file through
        # SGLang's HF download helper and read it from the returned dir.
        local_dir = download_from_hf(model_path_or_dir, allow_patterns=sparse_head)
        path = os.path.join(local_dir, sparse_head)

    # torch.load unpickles its input, and this file may have just been
    # downloaded from a remote repository. weights_only=True limits
    # unpickling to tensors and plain containers (all a state dict needs)
    # and makes that explicit on PyTorch releases older than 2.6, where the
    # default was still weights_only=False.
    return torch.load(path, weights_only=True)
283
+
232
284
 
233
285
  class XLMRobertaForSequenceClassification(nn.Module):
234
286
  def __init__(
@@ -17,7 +17,6 @@ import logging
17
17
  from typing import Iterable, List, Optional, Tuple
18
18
 
19
19
  import torch
20
- import torch.nn.functional as F
21
20
  from torch import nn
22
21
  from transformers import LlamaConfig
23
22
 
@@ -1,8 +1,7 @@
1
1
  import logging
2
2
  import math
3
- from collections.abc import Iterable
4
3
  from math import sqrt
5
- from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
4
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
6
5
 
7
6
  import torch
8
7
  from torch import nn
@@ -57,7 +56,6 @@ from sglang.srt.managers.schedule_batch import (
57
56
  Modality,
58
57
  MultimodalDataItem,
59
58
  MultimodalInputs,
60
- global_server_args_dict,
61
59
  )
62
60
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
63
61
  from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -300,7 +298,7 @@ class Step3TextDecoderLayer(nn.Module):
300
298
  # self.n_shared_experts = 1
301
299
  # self.num_fused_shared_experts = (
302
300
  # 0
303
- # if global_server_args_dict["disable_shared_experts_fusion"]
301
+ # if global_server_args.disable_shared_experts_fusion
304
302
  # else self.n_shared_experts
305
303
  # )
306
304
  self.num_fused_shared_experts = 0
@@ -774,7 +772,7 @@ class Step3VLForConditionalGeneration(nn.Module):
774
772
  # self.n_shared_experts = 1
775
773
  # self.num_fused_shared_experts = (
776
774
  # 0
777
- # if global_server_args_dict["disable_shared_experts_fusion"]
775
+ # if global_server_args.disable_shared_experts_fusion
778
776
  # else self.n_shared_experts
779
777
  # )
780
778
  self.num_fused_shared_experts = 0