sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/points_v15_chat.py ADDED
@@ -0,0 +1,186 @@
+ import copy
+ from typing import Iterable, List, Optional, Set, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.managers.mm_utils import (
+     MultiModalityDataPaddingPatternMultimodalTokens,
+     general_mm_embed_routine,
+ )
+ from sglang.srt.managers.schedule_batch import (
+     Modality,
+     MultimodalDataItem,
+     MultimodalInputs,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+ from sglang.srt.models.qwen2_vl import Qwen2VisionPatchMerger, Qwen2VisionTransformer
+ from sglang.srt.utils import add_prefix
+
+
+ class Qwen2VisionTransformerForNavitPOINTS(Qwen2VisionTransformer):
+     def __init__(
+         self,
+         vision_config: POINTSV15ChatConfig,
+         norm_eps: float = 1e-6,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(
+             vision_config,
+             norm_eps=norm_eps,
+             quant_config=quant_config,
+             prefix=prefix,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         grid_thw: torch.Tensor,
+     ) -> torch.Tensor:
+         # patchify
+         x = x.to(device=self.device, dtype=self.dtype)
+         x = self.patch_embed(x)
+
+         # compute position embedding
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+         position_embeddings = (emb.cos(), emb.sin())
+
+         # compute cu_seqlens
+         cu_seqlens = torch.repeat_interleave(
+             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+         ).cumsum(dim=0, dtype=torch.int32)
+         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+         # transformers
+         x = x.unsqueeze(1)
+         for blk in self.blocks:
+             x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+
+         return x
+
+
+ class POINTSV15ChatModel(nn.Module):
+     def __init__(
+         self,
+         config: POINTSV15ChatConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         **kwargs,
+     ) -> None:
+         super().__init__()
+         config.llm_config._attn_implementation = "flash_attention_2"
+         config._attn_implementation_autoset = False
+         self.config = config
+         self.quant_config = quant_config
+
+         llm_config = copy.deepcopy(config.llm_config)
+         llm_config.architectures = ["Qwen2ForCausalLM"]
+         self.llm = Qwen2ForCausalLM(
+             config=llm_config,
+             quant_config=quant_config,
+             prefix=add_prefix("llm", prefix),
+         )
+
+         self.vision_encoder = Qwen2VisionTransformerForNavitPOINTS(
+             config.vision_config,
+             quant_config=quant_config,
+             prefix=add_prefix("vision_encoder", prefix),
+         )
+
+         self.vision_projector = Qwen2VisionPatchMerger(
+             d_model=config.llm_config.hidden_size,
+             context_dim=1280,
+             quant_config=quant_config,
+             prefix=add_prefix("vision_projector", prefix),
+         )
+
+     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+         return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+         pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+             self.vision_encoder.dtype
+         )
+         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
+
+         assert pixel_values.dim() == 2, pixel_values.dim()
+         assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+
+         image_features = self.vision_encoder(pixel_values, grid_thw=image_grid_thw)
+         image_features = self.vision_projector(image_features)
+         return image_features
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         get_embedding: bool = False,
+     ):
+         hidden_states = general_mm_embed_routine(
+             input_ids=input_ids,
+             forward_batch=forward_batch,
+             language_model=self.llm,
+             data_embedding_funcs={
+                 Modality.IMAGE: self.get_image_feature,
+             },
+             positions=positions,
+         )
+
+         return hidden_states
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 if "vision_encoder" in name:
+                     # adapt to VisionAttention
+                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                 try:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     param = params_dict[name]
+                 except KeyError:
+                     print(params_dict.keys())
+                     raise
+
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = [POINTSV15ChatModel]
sglang/srt/models/qwen.py CHANGED
@@ -15,7 +15,6 @@
  # Adapted from
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1

- import time
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -59,6 +59,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInp
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2 import Qwen2Model
+ from sglang.srt.models.utils import permute_inv
  from sglang.srt.utils import add_prefix
  from sglang.srt.utils.hf_transformers_utils import get_processor

@@ -405,6 +406,7 @@ class Qwen2_5_VisionTransformer(nn.Module):

          # Move window_index to the same device as x before using it to index x
          window_index = window_index.to(device=x.device)
+         reverse_indices = permute_inv(window_index)

          # Ensure rotary_pos_emb is on the same device/dtype as x
          rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
@@ -436,7 +438,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
                  .to(device=x.device, dtype=torch.int32),
              ]
          )
-         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

          # transformers
          x = x.unsqueeze(1)
@@ -451,8 +453,6 @@ class Qwen2_5_VisionTransformer(nn.Module):

          # adapter
          x = self.merger(x)
-
-         reverse_indices = torch.argsort(window_index)
          x = x[reverse_indices, :]

          return x
sglang/srt/models/qwen2_audio.py CHANGED
@@ -23,30 +23,18 @@
  # limitations under the License.
  """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
  import logging
- import math
- from functools import lru_cache, partial
- from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict
+ from typing import Any, Iterable, List, Optional, Tuple

  import torch
  import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange
- from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config
- from transformers.activations import ACT2FN
+ from transformers import Qwen2AudioEncoderConfig, Qwen2Config
  from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig
  from transformers.models.qwen2_audio.modeling_qwen2_audio import (
      Qwen2AudioEncoder,
      Qwen2AudioMultiModalProjector,
  )

- from sglang.srt.layers.activation import QuickGELU
- from sglang.srt.layers.attention.vision import VisionAttention
- from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
- from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.layers.utils import get_layer_id
- from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
  from sglang.srt.managers.mm_utils import (
      MultiModalityDataPaddingPatternMultimodalTokens,
      general_mm_embed_routine,
@@ -60,7 +48,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2 import Qwen2ForCausalLM
  from sglang.srt.utils import add_prefix
- from sglang.srt.utils.hf_transformers_utils import get_processor

  logger = logging.getLogger(__name__)

sglang/srt/models/qwen2_moe.py CHANGED
@@ -17,6 +17,7 @@
  """Inference-only Qwen2MoE model compatible with HuggingFace weights."""

  import logging
+ from contextlib import nullcontext
  from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

  import torch
@@ -64,10 +65,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.server_args import get_global_server_args
  from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
  from sglang.srt.utils import add_prefix, is_cuda, make_layers

@@ -156,7 +157,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
              layer_id=self.layer_id,
              top_k=config.num_experts_per_tok,
              num_experts=config.num_experts
-             + global_server_args_dict["ep_num_redundant_experts"],
+             + get_global_server_args().ep_num_redundant_experts,
              hidden_size=config.hidden_size,
              intermediate_size=config.moe_intermediate_size,
              quant_config=quant_config,
@@ -192,7 +193,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
          # TODO: we will support tp < ep in the future
          self.ep_size = get_moe_expert_parallel_world_size()
          self.num_experts = (
-             config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+             config.num_experts + get_global_server_args().ep_num_redundant_experts
          )
          self.top_k = config.num_experts_per_tok

@@ -219,7 +220,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
              # router_logits: (num_tokens, n_experts)
              router_logits, _ = self.gate(hidden_states)
              shared_output = self._forward_shared_experts(hidden_states)
-             topk_weights, topk_idx, _ = self.topk(
+             topk_output = self.topk(
                  hidden_states,
                  router_logits,
                  num_token_non_padded=forward_batch.num_token_non_padded,
@@ -228,14 +229,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                  ),
              )
          else:
-             topk_weights, topk_idx, _ = self.topk.empty_topk_output(
-                 hidden_states.device
-             )
+             topk_output = self.topk.empty_topk_output(hidden_states.device)
          final_hidden_states = self.experts(
              hidden_states=hidden_states,
-             topk_idx=topk_idx,
-             topk_weights=topk_weights,
-             forward_batch=forward_batch,
+             topk_output=topk_output,
          )

          if shared_output is not None:
@@ -518,6 +515,7 @@ class Qwen2MoeModel(nn.Module):
      ) -> None:
          super().__init__()
          self.config = config
+
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
          self.pp_group = get_pp_group()
@@ -593,7 +591,12 @@ class Qwen2MoeModel(nn.Module):
                  if residual is not None
                  else hidden_states
              )
-             with get_global_expert_distribution_recorder().with_current_layer(i):
+             ctx = (
+                 nullcontext()
+                 if get_global_server_args().enable_piecewise_cuda_graph
+                 else get_global_expert_distribution_recorder().with_current_layer(i)
+             )
+             with ctx:
                  layer = self.layers[i]
                  hidden_states, residual = layer(
                      positions, hidden_states, forward_batch, residual
@@ -643,7 +646,7 @@ class Qwen2MoeForCausalLM(nn.Module):
              config.hidden_size,
              quant_config=quant_config,
              prefix=add_prefix("lm_head", prefix),
-             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+             use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
          )
          self.logits_processor = LogitsProcessor(config)
          # For EAGLE3 support
sglang/srt/models/qwen2_vl.py CHANGED
@@ -28,7 +28,6 @@ from typing import Iterable, List, Optional, Tuple, Type, TypedDict

  import torch
  import torch.nn as nn
- import torch.nn.functional as F
  from einops import rearrange
  from transformers import Qwen2VLConfig
  from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
@@ -407,7 +406,7 @@ class Qwen2VisionTransformer(nn.Module):
          cu_seqlens = torch.repeat_interleave(
              grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
          ).cumsum(dim=0, dtype=torch.int32)
-         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

          # transformers
          x = x.unsqueeze(1)
@@ -514,6 +513,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
      def get_input_embeddings(self):
          return self.model.embed_tokens

+     def should_apply_lora(self, module_name: str) -> bool:
+         # skip visual tower
+         return not module_name.startswith("visual")
+
      def forward(
          self,
          input_ids: torch.Tensor,
sglang/srt/models/qwen3_moe.py CHANGED
@@ -54,7 +54,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
  from sglang.srt.layers.utils import get_layer_id
  from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -64,6 +63,7 @@ from sglang.srt.models.utils import (
      create_fused_set_kv_buffer_arg,
      enable_fused_set_kv_buffer,
  )
+ from sglang.srt.server_args import get_global_server_args
  from sglang.srt.utils import (
      add_prefix,
      is_cuda,
@@ -104,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):

          self.experts = get_moe_impl_class(quant_config)(
              num_experts=config.num_experts
-             + global_server_args_dict["ep_num_redundant_experts"],
+             + get_global_server_args().ep_num_redundant_experts,
              top_k=config.num_experts_per_tok,
              layer_id=layer_id,
              hidden_size=config.hidden_size,
@@ -125,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
          # TODO: we will support tp < ep in the future
          self.ep_size = get_moe_expert_parallel_world_size()
          self.num_experts = (
-             config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+             config.num_experts + get_global_server_args().ep_num_redundant_experts
          )
          self.top_k = config.num_experts_per_tok

@@ -180,7 +180,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
          if hidden_states.shape[0] > 0:
              # router_logits: (num_tokens, n_experts)
              router_logits, _ = self.gate(hidden_states)
-             topk_weights, topk_idx, _ = self.topk(
+             topk_output = self.topk(
                  hidden_states,
                  router_logits,
                  num_token_non_padded=forward_batch.num_token_non_padded,
@@ -189,17 +189,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                  ),
              )
          else:
-             topk_idx = torch.full(
-                 (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-             )
-             topk_weights = torch.empty(
-                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-             )
+             topk_output = self.topk.empty_topk_output(hidden_states.device)
          final_hidden_states = self.experts(
              hidden_states=hidden_states,
-             topk_idx=topk_idx,
-             topk_weights=topk_weights,
-             forward_batch=forward_batch,
+             topk_output=topk_output,
          )
          return final_hidden_states

@@ -219,7 +212,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
              with get_global_expert_distribution_recorder().with_current_layer(
                  self.layer_id
              ):
-                 state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+                 state.topk_output = self.topk(
                      hidden_states=hidden_states,
                      router_logits=router_logits,
                      num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -228,20 +221,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                      ),
                  )
          else:
-             state.topk_idx_local = torch.full(
-                 (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-             )
-             state.topk_weights_local = torch.empty(
-                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-             )
+             state.topk_output = self.topk.empty_topk_output(hidden_states.device)

      def op_dispatch_a(self, state):
          if self.ep_size > 1:
-             self.experts.deepep_dispatcher.dispatch_a(
+             self.experts.dispatcher.dispatch_a(
                  hidden_states=state.pop("hidden_states_mlp_input"),
-                 topk_idx=state.pop("topk_idx_local"),
-                 topk_weights=state.pop("topk_weights_local"),
-                 forward_batch=state.forward_batch,
+                 topk_output=state.pop("topk_output"),
                  tbo_subbatch_index=state.get("tbo_subbatch_index"),
              )

@@ -250,32 +236,29 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
              with get_global_expert_distribution_recorder().with_current_layer(
                  self.layer_id
              ):
-                 state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+                 state.dispatch_output = self.experts.dispatcher.dispatch_b(
                      tbo_subbatch_index=state.get("tbo_subbatch_index"),
                  )

      def op_experts(self, state):
-         state.hidden_states_experts_output = self.experts.moe_impl(
+         state.hidden_states_experts_output = self.experts.run_moe_core(
              dispatch_output=state.dispatch_output,
          )

      def op_combine_a(self, state):
          if self.ep_size > 1:
-             self.experts.deepep_dispatcher.combine_a(
+             self.experts.dispatcher.combine_a(
                  hidden_states=state.pop("hidden_states_experts_output"),
-                 topk_idx=state.dispatch_output.topk_idx,
+                 topk_ids=state.dispatch_output.topk_ids,
                  topk_weights=state.dispatch_output.topk_weights,
-                 forward_batch=state.forward_batch,
                  tbo_subbatch_index=state.get("tbo_subbatch_index"),
              )
              state.pop("dispatch_output")

      def op_combine_b(self, state):
          if self.ep_size > 1:
-             state.hidden_states_after_combine = (
-                 self.experts.deepep_dispatcher.combine_b(
-                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
-                 )
+             state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
              )

      def op_output(self, state):
@@ -661,13 +644,14 @@ class Qwen3MoeModel(Qwen2MoeModel):
          config: Qwen3MoeConfig,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
+         decoder_layer_type=Qwen3MoeDecoderLayer,
      ) -> None:
          alt_stream = torch.cuda.Stream() if _is_cuda else None
          super().__init__(
              config=config,
              quant_config=quant_config,
              prefix=prefix,
-             decoder_layer_type=Qwen3MoeDecoderLayer,
+             decoder_layer_type=decoder_layer_type,
              alt_stream=alt_stream,
          )

@@ -693,7 +677,7 @@ class Qwen3MoeForCausalLM(nn.Module):
              config.hidden_size,
              quant_config=quant_config,
              prefix=add_prefix("lm_head", prefix),
-             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+             use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
          )
          self.logits_processor = LogitsProcessor(config)
          self.capture_aux_hidden_states = False
sglang/srt/models/qwen3_next.py CHANGED
@@ -1,18 +1,12 @@
  import enum
  import logging
- from typing import Any, Dict, Iterable, Optional, Set, Tuple
+ from typing import Any, Iterable, Optional, Set, Tuple

  import torch
- import torch.nn.functional as F
  from torch import nn

  from sglang.srt.configs.qwen3_next import Qwen3NextConfig
- from sglang.srt.distributed import (
-     divide,
-     get_pp_group,
-     get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size,
- )
+ from sglang.srt.distributed import divide, get_pp_group
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
  from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
  from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
@@ -23,10 +17,9 @@ from sglang.srt.layers.dp_attention import (
      get_attention_tp_size,
      is_dp_attention_enabled,
  )
- from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+ from sglang.srt.layers.layernorm import GemmaRMSNorm
  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
-     MergedColumnParallelLinear,
      QKVParallelLinear,
      RowParallelLinear,
  )
@@ -39,7 +32,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import (
@@ -47,6 +39,7 @@ from sglang.srt.model_loader.weight_utils import (
      sharded_weight_loader,
  )
  from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
+ from sglang.srt.server_args import get_global_server_args
  from sglang.srt.utils import (
      LazyValue,
      add_prefix,
@@ -527,6 +520,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
                  config=config,
                  quant_config=quant_config,
                  alt_stream=alt_stream,
+                 prefix=add_prefix("mlp", prefix),
              )
          else:
              self.mlp = Qwen2MoeMLP(
@@ -680,6 +674,7 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
                  config=config,
                  quant_config=quant_config,
                  alt_stream=alt_stream,
+                 prefix=add_prefix("mlp", prefix),
              )
          else:
              self.mlp = Qwen2MoeMLP(
@@ -905,7 +900,7 @@ class Qwen3NextForCausalLM(nn.Module):
              quant_config=quant_config,
              org_num_embeddings=config.vocab_size,
              prefix=add_prefix("lm_head", prefix),
-             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+             use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
          )
          self.lm_head = self.lm_head.float()
          self.logits_processor = LogitsProcessor(config)
sglang/srt/models/qwen3_next_mtp.py CHANGED
@@ -21,14 +21,13 @@ from torch import nn
  from transformers import PretrainedConfig

  from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
- from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+ from sglang.srt.layers.layernorm import GemmaRMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.models.qwen3_moe import Qwen3MoeModel
  from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
+ from sglang.srt.server_args import get_global_server_args
  from sglang.srt.utils import add_prefix

  logger = logging.getLogger(__name__)
@@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
              config.hidden_size,
              quant_config=quant_config,
              prefix=add_prefix("model.shared_head.head", prefix),
-             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+             use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
          )
          self.logits_processor = LogitsProcessor(config)