sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,6 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -16,13 +16,10 @@
 Using mistral-community/pixtral-12b as reference.
 """
 
-import logging
-import math
 from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from transformers import PixtralVisionConfig, PretrainedConfig
 from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
 from transformers.models.pixtral.modeling_pixtral import (
@@ -0,0 +1,186 @@
+import copy
+from typing import Iterable, List, Optional, Set, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen2_vl import Qwen2VisionPatchMerger, Qwen2VisionTransformer
+from sglang.srt.utils import add_prefix
+
+
+class Qwen2VisionTransformerForNavitPOINTS(Qwen2VisionTransformer):
+    def __init__(
+        self,
+        vision_config: POINTSV15ChatConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            vision_config,
+            norm_eps=norm_eps,
+            quant_config=quant_config,
+            prefix=prefix,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        # patchify
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+
+        # compute position embedding
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
+
+        # compute cu_seqlens
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+        # transformers
+        x = x.unsqueeze(1)
+        for blk in self.blocks:
+            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+
+        return x
+
+
+class POINTSV15ChatModel(nn.Module):
+    def __init__(
+        self,
+        config: POINTSV15ChatConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        config.llm_config._attn_implementation = "flash_attention_2"
+        config._attn_implementation_autoset = False
+        self.config = config
+        self.quant_config = quant_config
+
+        llm_config = copy.deepcopy(config.llm_config)
+        llm_config.architectures = ["Qwen2ForCausalLM"]
+        self.llm = Qwen2ForCausalLM(
+            config=llm_config,
+            quant_config=quant_config,
+            prefix=add_prefix("llm", prefix),
+        )
+
+        self.vision_encoder = Qwen2VisionTransformerForNavitPOINTS(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_encoder", prefix),
+        )
+
+        self.vision_projector = Qwen2VisionPatchMerger(
+            d_model=config.llm_config.hidden_size,
+            context_dim=1280,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_projector", prefix),
+        )
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.vision_encoder.dtype
+        )
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
+
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+
+        image_features = self.vision_encoder(pixel_values, grid_thw=image_grid_thw)
+        image_features = self.vision_projector(image_features)
+        return image_features
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "vision_encoder" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [POINTSV15ChatModel]
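
Note on the new POINTS v1.5 vision path added above: cu_seqlens is the running per-frame token count (h * w patch tokens per frame, repeated t times) that the vision blocks use as attention sequence boundaries. A minimal standalone sketch of that arithmetic, using an illustrative grid_thw value that is not taken from the package:

import torch
import torch.nn.functional as F

# Hypothetical grids: one image of 1 x 4 x 6 patches, one of 2 x 2 x 2 patches.
grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]])
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), "constant", 0)
print(cu_seqlens)  # tensor([ 0, 24, 28, 32], dtype=torch.int32)
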
sglang/srt/models/qwen.py CHANGED
@@ -15,7 +15,6 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
 
-import time
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -49,6 +49,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 
 Qwen2Config = None
@@ -89,6 +90,9 @@ class Qwen2MLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            x = x.bfloat16()
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -275,6 +279,11 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
+                params_dtype=(
+                    torch.float32
+                    if get_global_server_args().rl_on_policy_target == "fsdp"
+                    else None
+                ),
             )
         else:
             self.embed_tokens = PPMissingLayer()
@@ -295,7 +304,19 @@ class Qwen2Model(nn.Module):
             prefix=add_prefix("layers", prefix),
         )
         if self.pp_group.is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            norm_kwargs = (
+                dict(
+                    weight_dtype=torch.float32,
+                    cast_x_before_out_mul=True,
+                    override_orig_dtype=torch.float32,
+                    fp32_residual=True,
+                )
+                if get_global_server_args().rl_on_policy_target == "fsdp"
+                else {}
+            )
+            self.norm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+            )
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
@@ -59,6 +59,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInp
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.models.utils import permute_inv
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
@@ -405,6 +406,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
         # Move window_index to the same device as x before using it to index x
         window_index = window_index.to(device=x.device)
+        reverse_indices = permute_inv(window_index)
 
         # Ensure rotary_pos_emb is on the same device/dtype as x
         rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
@@ -436,7 +438,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 .to(device=x.device, dtype=torch.int32),
             ]
         )
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         # transformers
         x = x.unsqueeze(1)
@@ -451,8 +453,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
         # adapter
         x = self.merger(x)
-
-        reverse_indices = torch.argsort(window_index)
         x = x[reverse_indices, :]
 
         return x
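
The edits in the hunks above are intended to be behavior-preserving: prepending a zero with torch.cat matches the old F.pad call, and permute_inv(window_index) replaces torch.argsort(window_index), i.e. the inverse of a permutation. A rough equivalence check; permute_inv is written here only as a stand-in, since its real definition lives in sglang.srt.models.utils and is not shown in this diff:

import torch
import torch.nn.functional as F

def permute_inv(perm: torch.Tensor) -> torch.Tensor:
    # Stand-in: inv[perm[i]] = i, equivalent to argsort for a permutation tensor.
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.numel(), device=perm.device)
    return inv

window_index = torch.tensor([2, 0, 3, 1])
assert torch.equal(permute_inv(window_index), torch.argsort(window_index))

cu_seqlens = torch.tensor([4, 9, 12], dtype=torch.int32)
assert torch.equal(
    F.pad(cu_seqlens, (1, 0), "constant", 0),          # old form
    torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]),  # new form
)
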
@@ -23,30 +23,18 @@
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 import logging
-import math
-from functools import lru_cache, partial
-from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict
+from typing import Any, Iterable, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config
-from transformers.activations import ACT2FN
+from transformers import Qwen2AudioEncoderConfig, Qwen2Config
 from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig
 from transformers.models.qwen2_audio.modeling_qwen2_audio import (
     Qwen2AudioEncoder,
     Qwen2AudioMultiModalProjector,
 )
 
-from sglang.srt.layers.activation import QuickGELU
-from sglang.srt.layers.attention.vision import VisionAttention
-from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.utils import get_layer_id
-from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
@@ -60,7 +48,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix
-from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -17,6 +17,7 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
 import logging
+from contextlib import nullcontext
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
@@ -64,10 +65,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
@@ -156,7 +157,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
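
A recurring change in these hunks is the switch from the module-level global_server_args_dict mapping to the get_global_server_args() accessor imported above. Roughly, reading the same setting before and after this release looks like the following sketch; it is only meaningful inside a running sglang process where the global server args have been initialized:

# 0.5.3rc2 style: dictionary lookup
from sglang.srt.managers.schedule_batch import global_server_args_dict
n_redundant = global_server_args_dict["ep_num_redundant_experts"]

# 0.5.4.post1 style: attribute access on the global server-args object
from sglang.srt.server_args import get_global_server_args
n_redundant = get_global_server_args().ep_num_redundant_experts
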
@@ -192,7 +193,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         # TODO: we will support tp < ep in the future
         self.ep_size = get_moe_expert_parallel_world_size()
         self.num_experts = (
-            config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            config.num_experts + get_global_server_args().ep_num_redundant_experts
         )
         self.top_k = config.num_experts_per_tok
 
@@ -219,7 +220,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -228,14 +229,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
-                hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
         )
 
         if shared_output is not None:
@@ -518,6 +515,7 @@ class Qwen2MoeModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.pp_group = get_pp_group()
@@ -593,7 +591,12 @@ class Qwen2MoeModel(nn.Module):
                     if residual is not None
                     else hidden_states
                 )
-            with get_global_expert_distribution_recorder().with_current_layer(i):
+            ctx = (
+                nullcontext()
+                if get_global_server_args().enable_piecewise_cuda_graph
+                else get_global_expert_distribution_recorder().with_current_layer(i)
+            )
+            with ctx:
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual
@@ -643,7 +646,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         # For EAGLE3 support
@@ -28,7 +28,6 @@ from typing import Iterable, List, Optional, Tuple, Type, TypedDict
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
@@ -407,7 +406,7 @@ class Qwen2VisionTransformer(nn.Module):
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         # transformers
         x = x.unsqueeze(1)
@@ -514,6 +513,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        # skip visual tower
+        return not module_name.startswith("visual")
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -29,6 +29,7 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     add_prefix,
     get_cmo_stream,
@@ -88,8 +89,16 @@ class Qwen3Attention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.tp_rank = get_tensor_model_parallel_rank()
 
-        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -158,10 +167,18 @@ class Qwen3Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            hidden_states = hidden_states.bfloat16()
+
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
+
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            q = q.to(torch.bfloat16)
+            k = k.to(torch.bfloat16)
+
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -204,9 +221,22 @@ class Qwen3DecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+                override_orig_dtype=torch.float32,
+                fp32_residual=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+        )
         self.post_attention_layernorm = RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
         )
 
         self.layer_scatter_modes = LayerScatterModes.init_new(