sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -51,15 +51,19 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
51
51
  from sglang.srt.layers.moe.topk import TopK
52
52
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
53
53
  from sglang.srt.layers.radix_attention import RadixAttention
54
- from sglang.srt.layers.rotary_embedding import get_rope
54
+ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
55
55
  from sglang.srt.layers.utils import get_layer_id
56
56
  from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
57
- from sglang.srt.managers.schedule_batch import global_server_args_dict
58
57
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
59
58
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
60
59
  from sglang.srt.model_loader.weight_utils import default_weight_loader
61
60
  from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
62
61
  from sglang.srt.models.qwen2_moe import Qwen2MoeModel
62
+ from sglang.srt.models.utils import (
63
+ create_fused_set_kv_buffer_arg,
64
+ enable_fused_set_kv_buffer,
65
+ )
66
+ from sglang.srt.server_args import get_global_server_args
63
67
  from sglang.srt.utils import (
64
68
  add_prefix,
65
69
  is_cuda,
@@ -100,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
100
104
 
101
105
  self.experts = get_moe_impl_class(quant_config)(
102
106
  num_experts=config.num_experts
103
- + global_server_args_dict["ep_num_redundant_experts"],
107
+ + get_global_server_args().ep_num_redundant_experts,
104
108
  top_k=config.num_experts_per_tok,
105
109
  layer_id=layer_id,
106
110
  hidden_size=config.hidden_size,
@@ -121,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
121
125
  # TODO: we will support tp < ep in the future
122
126
  self.ep_size = get_moe_expert_parallel_world_size()
123
127
  self.num_experts = (
124
- config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
128
+ config.num_experts + get_global_server_args().ep_num_redundant_experts
125
129
  )
126
130
  self.top_k = config.num_experts_per_tok
127
131
 
@@ -176,7 +180,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
176
180
  if hidden_states.shape[0] > 0:
177
181
  # router_logits: (num_tokens, n_experts)
178
182
  router_logits, _ = self.gate(hidden_states)
179
- topk_weights, topk_idx, _ = self.topk(
183
+ topk_output = self.topk(
180
184
  hidden_states,
181
185
  router_logits,
182
186
  num_token_non_padded=forward_batch.num_token_non_padded,
@@ -185,17 +189,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
185
189
  ),
186
190
  )
187
191
  else:
188
- topk_idx = torch.full(
189
- (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
190
- )
191
- topk_weights = torch.empty(
192
- (0, self.top_k), dtype=torch.float32, device=hidden_states.device
193
- )
192
+ topk_output = self.topk.empty_topk_output(hidden_states.device)
194
193
  final_hidden_states = self.experts(
195
194
  hidden_states=hidden_states,
196
- topk_idx=topk_idx,
197
- topk_weights=topk_weights,
198
- forward_batch=forward_batch,
195
+ topk_output=topk_output,
199
196
  )
200
197
  return final_hidden_states
201
198
 
@@ -215,7 +212,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
215
212
  with get_global_expert_distribution_recorder().with_current_layer(
216
213
  self.layer_id
217
214
  ):
218
- state.topk_weights_local, state.topk_idx_local, _ = self.topk(
215
+ state.topk_output = self.topk(
219
216
  hidden_states=hidden_states,
220
217
  router_logits=router_logits,
221
218
  num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -224,20 +221,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
224
221
  ),
225
222
  )
226
223
  else:
227
- state.topk_idx_local = torch.full(
228
- (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
229
- )
230
- state.topk_weights_local = torch.empty(
231
- (0, self.top_k), dtype=torch.float32, device=hidden_states.device
232
- )
224
+ state.topk_output = self.topk.empty_topk_output(hidden_states.device)
233
225
 
234
226
  def op_dispatch_a(self, state):
235
227
  if self.ep_size > 1:
236
- self.experts.deepep_dispatcher.dispatch_a(
228
+ self.experts.dispatcher.dispatch_a(
237
229
  hidden_states=state.pop("hidden_states_mlp_input"),
238
- topk_idx=state.pop("topk_idx_local"),
239
- topk_weights=state.pop("topk_weights_local"),
240
- forward_batch=state.forward_batch,
230
+ topk_output=state.pop("topk_output"),
241
231
  tbo_subbatch_index=state.get("tbo_subbatch_index"),
242
232
  )
243
233
 
@@ -246,32 +236,29 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
246
236
  with get_global_expert_distribution_recorder().with_current_layer(
247
237
  self.layer_id
248
238
  ):
249
- state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
239
+ state.dispatch_output = self.experts.dispatcher.dispatch_b(
250
240
  tbo_subbatch_index=state.get("tbo_subbatch_index"),
251
241
  )
252
242
 
253
243
  def op_experts(self, state):
254
- state.hidden_states_experts_output = self.experts.moe_impl(
244
+ state.hidden_states_experts_output = self.experts.run_moe_core(
255
245
  dispatch_output=state.dispatch_output,
256
246
  )
257
247
 
258
248
  def op_combine_a(self, state):
259
249
  if self.ep_size > 1:
260
- self.experts.deepep_dispatcher.combine_a(
250
+ self.experts.dispatcher.combine_a(
261
251
  hidden_states=state.pop("hidden_states_experts_output"),
262
- topk_idx=state.dispatch_output.topk_idx,
252
+ topk_ids=state.dispatch_output.topk_ids,
263
253
  topk_weights=state.dispatch_output.topk_weights,
264
- forward_batch=state.forward_batch,
265
254
  tbo_subbatch_index=state.get("tbo_subbatch_index"),
266
255
  )
267
256
  state.pop("dispatch_output")
268
257
 
269
258
  def op_combine_b(self, state):
270
259
  if self.ep_size > 1:
271
- state.hidden_states_after_combine = (
272
- self.experts.deepep_dispatcher.combine_b(
273
- tbo_subbatch_index=state.get("tbo_subbatch_index"),
274
- )
260
+ state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
261
+ tbo_subbatch_index=state.get("tbo_subbatch_index"),
275
262
  )
276
263
 
277
264
  def op_output(self, state):
@@ -354,6 +341,10 @@ class Qwen3MoeAttention(nn.Module):
354
341
  rope_scaling=rope_scaling,
355
342
  dual_chunk_attention_config=dual_chunk_attention_config,
356
343
  )
344
+ self.compatible_with_fused_kv_buffer = (
345
+ False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
346
+ )
347
+
357
348
  self.attn = RadixAttention(
358
349
  self.num_heads,
359
350
  self.head_dim,
@@ -412,7 +403,21 @@ class Qwen3MoeAttention(nn.Module):
412
403
  qkv, _ = self.qkv_proj(hidden_states)
413
404
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
414
405
  q, k = self._apply_qk_norm(q, k)
415
- q, k = self.rotary_emb(positions, q, k)
406
+ q, k = self.rotary_emb(
407
+ positions,
408
+ q,
409
+ k,
410
+ fused_set_kv_buffer_arg=(
411
+ create_fused_set_kv_buffer_arg(
412
+ value=v,
413
+ layer=self.attn,
414
+ forward_batch=forward_batch,
415
+ )
416
+ if enable_fused_set_kv_buffer(forward_batch)
417
+ and self.compatible_with_fused_kv_buffer
418
+ else None
419
+ ),
420
+ )
416
421
  inner_state = q, k, v, forward_batch
417
422
  return None, forward_batch, inner_state
418
423
 
@@ -420,7 +425,13 @@ class Qwen3MoeAttention(nn.Module):
420
425
  hidden_states, forward_batch, inner_state = intermediate_state
421
426
  if inner_state is None:
422
427
  return hidden_states
423
- attn_output = self.attn(*inner_state)
428
+ attn_output = self.attn(
429
+ *inner_state,
430
+ save_kv_cache=not (
431
+ enable_fused_set_kv_buffer(forward_batch)
432
+ and self.compatible_with_fused_kv_buffer
433
+ ),
434
+ )
424
435
  output, _ = self.o_proj(attn_output)
425
436
  return output
426
437
 
@@ -633,13 +644,14 @@ class Qwen3MoeModel(Qwen2MoeModel):
633
644
  config: Qwen3MoeConfig,
634
645
  quant_config: Optional[QuantizationConfig] = None,
635
646
  prefix: str = "",
647
+ decoder_layer_type=Qwen3MoeDecoderLayer,
636
648
  ) -> None:
637
649
  alt_stream = torch.cuda.Stream() if _is_cuda else None
638
650
  super().__init__(
639
651
  config=config,
640
652
  quant_config=quant_config,
641
653
  prefix=prefix,
642
- decoder_layer_type=Qwen3MoeDecoderLayer,
654
+ decoder_layer_type=decoder_layer_type,
643
655
  alt_stream=alt_stream,
644
656
  )
645
657
 
@@ -665,7 +677,7 @@ class Qwen3MoeForCausalLM(nn.Module):
665
677
  config.hidden_size,
666
678
  quant_config=quant_config,
667
679
  prefix=add_prefix("lm_head", prefix),
668
- use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
680
+ use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
669
681
  )
670
682
  self.logits_processor = LogitsProcessor(config)
671
683
  self.capture_aux_hidden_states = False
@@ -1,18 +1,13 @@
1
1
  import enum
2
2
  import logging
3
- from typing import Any, Dict, Iterable, Optional, Set, Tuple
3
+ from typing import Any, Iterable, Optional, Set, Tuple
4
4
 
5
5
  import torch
6
- import torch.nn.functional as F
7
6
  from torch import nn
8
7
 
9
8
  from sglang.srt.configs.qwen3_next import Qwen3NextConfig
10
- from sglang.srt.distributed import (
11
- divide,
12
- get_pp_group,
13
- get_tensor_model_parallel_rank,
14
- get_tensor_model_parallel_world_size,
15
- )
9
+ from sglang.srt.distributed import divide, get_pp_group
10
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
16
11
  from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
17
12
  from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
18
13
  from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
@@ -22,10 +17,9 @@ from sglang.srt.layers.dp_attention import (
22
17
  get_attention_tp_size,
23
18
  is_dp_attention_enabled,
24
19
  )
25
- from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
20
+ from sglang.srt.layers.layernorm import GemmaRMSNorm
26
21
  from sglang.srt.layers.linear import (
27
22
  ColumnParallelLinear,
28
- MergedColumnParallelLinear,
29
23
  QKVParallelLinear,
30
24
  RowParallelLinear,
31
25
  )
@@ -38,7 +32,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
38
32
  ParallelLMHead,
39
33
  VocabParallelEmbedding,
40
34
  )
41
- from sglang.srt.managers.schedule_batch import global_server_args_dict
42
35
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
43
36
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
44
37
  from sglang.srt.model_loader.weight_utils import (
@@ -46,7 +39,15 @@ from sglang.srt.model_loader.weight_utils import (
46
39
  sharded_weight_loader,
47
40
  )
48
41
  from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
49
- from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
42
+ from sglang.srt.server_args import get_global_server_args
43
+ from sglang.srt.utils import (
44
+ LazyValue,
45
+ add_prefix,
46
+ is_cuda,
47
+ is_npu,
48
+ make_layers,
49
+ set_weight_attrs,
50
+ )
50
51
 
51
52
  logger = logging.getLogger(__name__)
52
53
  _is_cuda = is_cuda()
@@ -239,6 +240,7 @@ class Qwen3GatedDeltaNet(nn.Module):
239
240
  self,
240
241
  config: Qwen3NextConfig,
241
242
  layer_id: int,
243
+ quant_config: Optional[QuantizationConfig] = None,
242
244
  alt_stream: Optional[torch.cuda.Stream] = None,
243
245
  ) -> None:
244
246
  super().__init__()
@@ -278,6 +280,7 @@ class Qwen3GatedDeltaNet(nn.Module):
278
280
  input_size=self.hidden_size,
279
281
  output_size=projection_size_qkvz,
280
282
  bias=False,
283
+ quant_config=quant_config,
281
284
  tp_rank=self.attn_tp_rank,
282
285
  tp_size=self.attn_tp_size,
283
286
  )
@@ -285,6 +288,7 @@ class Qwen3GatedDeltaNet(nn.Module):
285
288
  input_size=self.hidden_size,
286
289
  output_size=projection_size_ba,
287
290
  bias=False,
291
+ quant_config=None,
288
292
  tp_rank=self.attn_tp_rank,
289
293
  tp_size=self.attn_tp_size,
290
294
  )
@@ -336,6 +340,7 @@ class Qwen3GatedDeltaNet(nn.Module):
336
340
  self.value_dim,
337
341
  self.hidden_size,
338
342
  bias=False,
343
+ quant_config=quant_config,
339
344
  input_is_parallel=True,
340
345
  reduce_results=False,
341
346
  tp_rank=self.attn_tp_rank,
@@ -493,7 +498,9 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
493
498
  ) -> None:
494
499
  super().__init__()
495
500
  self.config = config
496
- self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
501
+ self.linear_attn = Qwen3GatedDeltaNet(
502
+ config, layer_id, quant_config, alt_stream
503
+ )
497
504
 
498
505
  # Qwen3Next all layers are sparse and have no nextn now
499
506
  self.is_layer_sparse = True
@@ -513,6 +520,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
513
520
  config=config,
514
521
  quant_config=quant_config,
515
522
  alt_stream=alt_stream,
523
+ prefix=add_prefix("mlp", prefix),
516
524
  )
517
525
  else:
518
526
  self.mlp = Qwen2MoeMLP(
@@ -666,6 +674,7 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
666
674
  config=config,
667
675
  quant_config=quant_config,
668
676
  alt_stream=alt_stream,
677
+ prefix=add_prefix("mlp", prefix),
669
678
  )
670
679
  else:
671
680
  self.mlp = Qwen2MoeMLP(
@@ -843,13 +852,14 @@ class Qwen3NextModel(nn.Module):
843
852
  residual = None
844
853
  for i in range(len(self.layers)):
845
854
  layer = self.layers[i]
846
- hidden_states, residual = layer(
847
- layer_id=i,
848
- positions=positions,
849
- hidden_states=hidden_states,
850
- residual=residual,
851
- forward_batch=forward_batch,
852
- )
855
+ with get_global_expert_distribution_recorder().with_current_layer(i):
856
+ hidden_states, residual = layer(
857
+ layer_id=i,
858
+ positions=positions,
859
+ hidden_states=hidden_states,
860
+ residual=residual,
861
+ forward_batch=forward_batch,
862
+ )
853
863
 
854
864
  if not forward_batch.forward_mode.is_idle():
855
865
  if residual is None:
@@ -890,11 +900,23 @@ class Qwen3NextForCausalLM(nn.Module):
890
900
  quant_config=quant_config,
891
901
  org_num_embeddings=config.vocab_size,
892
902
  prefix=add_prefix("lm_head", prefix),
893
- use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
903
+ use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
894
904
  )
895
905
  self.lm_head = self.lm_head.float()
896
906
  self.logits_processor = LogitsProcessor(config)
897
907
 
908
+ self._routed_experts_weights_of_layer = LazyValue(
909
+ lambda: {
910
+ layer_id: layer.mlp.get_moe_weights()
911
+ for layer_id, layer in enumerate(self.model.layers)
912
+ if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
913
+ }
914
+ )
915
+
916
+ @property
917
+ def routed_experts_weights_of_layer(self):
918
+ return self._routed_experts_weights_of_layer.value
919
+
898
920
  @torch.no_grad()
899
921
  def forward(
900
922
  self,
@@ -21,14 +21,13 @@ from torch import nn
21
21
  from transformers import PretrainedConfig
22
22
 
23
23
  from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
24
- from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
24
+ from sglang.srt.layers.layernorm import GemmaRMSNorm
25
25
  from sglang.srt.layers.logits_processor import LogitsProcessor
26
26
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
27
27
  from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
28
- from sglang.srt.managers.schedule_batch import global_server_args_dict
29
28
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
30
- from sglang.srt.models.qwen3_moe import Qwen3MoeModel
31
29
  from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
30
+ from sglang.srt.server_args import get_global_server_args
32
31
  from sglang.srt.utils import add_prefix
33
32
 
34
33
  logger = logging.getLogger(__name__)
@@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
69
68
  config.hidden_size,
70
69
  quant_config=quant_config,
71
70
  prefix=add_prefix("model.shared_head.head", prefix),
72
- use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
71
+ use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
73
72
  )
74
73
  self.logits_processor = LogitsProcessor(config)
75
74