sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -24,18 +24,21 @@ import threading
 import time
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
-from urllib.parse import urlparse
+from typing import Callable, List, Optional, Tuple, Union
 
-import requests
 import torch
 import torch.distributed as dist
 
+from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.model_config import (
+    AttentionArch,
+    ModelConfig,
+    get_nsa_index_head_dim,
+    is_deepseek_nsa,
+)
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
-from sglang.srt.connector import ConnectorType
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_pp_group,
@@ -45,8 +48,10 @@ from sglang.srt.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
     set_mscclpp_all_reduce,
+    set_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
 from sglang.srt.eplb.eplb_manager import EPLBManager
 from sglang.srt.eplb.expert_distribution import (
     ExpertDistributionRecorder,
@@ -60,6 +65,11 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers import deep_gemm_wrapper
+from sglang.srt.layers.attention.attention_registry import (
+    ATTENTION_BACKENDS,
+    attn_backend_wrapper,
+)
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -67,18 +77,11 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import (
-    deep_gemm_wrapper,
-    monkey_patch_isinstance_for_vllm_base_layer,
-)
+from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
-from sglang.srt.managers.schedule_batch import (
-    GLOBAL_SERVER_ARGS_KEYS,
-    global_server_args_dict,
-)
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
@@ -94,6 +97,7 @@ from sglang.srt.mem_cache.memory_pool import (
     HybridReqToTokenPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
+    NSATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
 )
@@ -101,23 +105,23 @@ from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
+from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
+    PiecewiseCudaGraphRunner,
+)
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
-from sglang.srt.model_loader.utils import set_default_torch_dtype
-from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.offloader import (
-    create_offloader_from_server_args,
-    get_offloader,
-    set_offloader,
-)
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
-from sglang.srt.remote_instance_weight_loader_utils import (
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
     trigger_init_weights_send_group_for_remote_instance_request,
 )
+from sglang.srt.model_loader.utils import set_default_torch_dtype
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import (
+    ServerArgs,
+    get_global_server_args,
+    set_global_server_args_for_scheduler,
+)
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
@@ -127,7 +131,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_blackwell,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -135,19 +138,66 @@ from sglang.srt.utils import (
     is_no_spec_infer_or_topk_one,
     is_npu,
     is_sm100_supported,
+    log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    parse_connector_type,
     set_cuda_arch,
+    slow_rank_detector,
+    xpu_has_xmx_support,
 )
+from sglang.srt.utils.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
+)
+from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorBucket,
     FlattenedTensorMetadata,
 )
 
+MLA_ATTENTION_BACKENDS = [
+    "aiter",
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "triton",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+    "nsa",
+]
+
+CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+]
+
+
+def add_mla_attention_backend(backend_name):
+    if backend_name not in MLA_ATTENTION_BACKENDS:
+        MLA_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
+
+
+def add_chunked_prefix_cache_attention_backend(backend_name):
+    if backend_name not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS:
+        CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(
+            f"Added {backend_name} to CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS."
+        )
+
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_xpu_xmx_available = xpu_has_xmx_support()
 
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
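The hunk above replaces the hard-coded MLA backend list with module-level registries (MLA_ATTENTION_BACKENDS, CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS) plus helper functions to extend them. A minimal sketch, assuming a hypothetical out-of-tree backend named "my_mla", of how integration code could register it before the model runner checks backend support:

# Illustrative only: "my_mla" is a made-up backend name.
from sglang.srt.model_executor.model_runner import (
    add_chunked_prefix_cache_attention_backend,
    add_mla_attention_backend,
)

add_mla_attention_backend("my_mla")                   # accepted by the MLA backend check
add_chunked_prefix_cache_attention_backend("my_mla")  # keeps chunked prefix cache enabled for it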
@@ -155,8 +205,17 @@ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 # Detect stragger ranks in model loading
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 
+# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
+MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
+
 logger = logging.getLogger(__name__)
 
+if _is_npu:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+
 
 class RankZeroFilter(logging.Filter):
     """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
@@ -222,25 +281,21 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
         self.forward_pass_id = 0
+        self.init_new_workspace = False
 
         # Apply the rank zero filter to logger
-        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
-            logger.addFilter(RankZeroFilter(tp_rank == 0))
         if server_args.show_time_cost:
             enable_show_time_cost()
 
         # Model-specific adjustment
         self.model_specific_adjustment()
 
-        # Global vars
-        global_server_args_dict.update(
-            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
-            | {
-                # TODO it is indeed not a "server args"
-                "use_mla_backend": self.use_mla_backend,
-                "speculative_algorithm": self.spec_algorithm,
-            }
-        )
+        # Set the global server_args in the scheduler process
+        set_global_server_args_for_scheduler(server_args)
+        global_server_args = get_global_server_args()
+
+        # FIXME: hacky set `use_mla_backend`
+        global_server_args.use_mla_backend = self.use_mla_backend
 
         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
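In this hunk the mutable global_server_args_dict is replaced by a process-wide ServerArgs singleton (set_global_server_args_for_scheduler / get_global_server_args). A minimal sketch of the read side, using fields that later hunks access the same way:

# Illustrative only: dict lookups such as global_server_args_dict["torchao_config"]
# become attribute access on the singleton.
from sglang.srt.server_args import get_global_server_args

torchao_config = get_global_server_args().torchao_config
use_mla_backend = get_global_server_args().use_mla_backend  # set by ModelRunner above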
@@ -252,6 +307,9 @@ class ModelRunner:
         # CPU offload
         set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
 
+        if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
+            slow_rank_detector.execute()
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -268,6 +326,26 @@ class ModelRunner:
         self._model_update_group = {}
         self._weights_send_group = {}
 
+        if (
+            self.server_args.enable_piecewise_cuda_graph
+            and self.can_run_piecewise_cuda_graph()
+        ):
+            self.attention_layers = []
+            for layer in self.model.model.layers:
+                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
+                    self.attention_layers.append(layer.self_attn.attn)
+            if len(self.attention_layers) < self.model_config.num_hidden_layers:
+                # TODO(yuwei): support Non-Standard GQA
+                log_info_on_rank0(
+                    logger,
+                    "Disable piecewise CUDA graph because some layers do not apply Standard GQA",
+                )
+                self.piecewise_cuda_graph_runner = None
+            else:
+                self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
+        else:
+            self.piecewise_cuda_graph_runner = None
+
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
 
@@ -302,6 +380,11 @@ class ModelRunner:
         )
         self.expert_location_updater = ExpertLocationUpdater()
 
+        (
+            ElasticEPStateManager.init(self.server_args)
+            if self.server_args.elastic_ep_backend
+            else None
+        )
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -316,25 +399,10 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True
 
-        if self.is_hybrid_gdn:
-            logger.warning("Hybrid GDN model detected, disable radix cache")
+        if config := self.mamba2_config:
+            class_name = config.__class__.__name__
+            logger.warning(f"{class_name} model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
-            self.server_args.attention_backend = "hybrid_linear_attn"
-            if self.server_args.max_mamba_cache_size is None:
-                if self.server_args.max_running_requests is not None:
-                    self.server_args.max_mamba_cache_size = (
-                        self.server_args.max_running_requests
-                    )
-                else:
-                    self.server_args.max_mamba_cache_size = 512
-            self.server_args.max_mamba_cache_size = (
-                self.server_args.max_mamba_cache_size
-                // (
-                    self.server_args.dp_size
-                    if self.server_args.enable_dp_attention
-                    else 1
-                )
-            )
 
         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
@@ -365,7 +433,7 @@ class ModelRunner:
         # In layered loading, torchao may have been applied
         if not torchao_applied:
             apply_torchao_config_to_model(
-                self.model, global_server_args_dict["torchao_config"]
+                self.model, get_global_server_args().torchao_config
             )
 
         # Apply torch TP if the model supports it
@@ -385,6 +453,12 @@ class ModelRunner:
             )
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
+        # Enable batch invariant mode
+        if server_args.enable_deterministic_inference:
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
+
+            enable_batch_invariant_mode()
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
@@ -439,6 +513,16 @@ class ModelRunner:
             )
             server_args.attention_backend = "torch_native"
 
+        if (
+            server_args.attention_backend == "intel_xpu"
+            and server_args.device == "xpu"
+            and not _is_xpu_xmx_available
+        ):
+            logger.info(
+                "The current platform does not support Intel XMX, will fallback to triton backend."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.prefill_attention_backend is not None and (
             server_args.prefill_attention_backend
             == server_args.decode_attention_backend
@@ -496,9 +580,7 @@ class ModelRunner:
                 elif _is_hip:
                     head_num = self.model_config.get_num_kv_heads(self.tp_size)
                     # TODO current aiter only support head number 16 or 128 head number
-                    if (
-                        head_num == 128 or head_num == 16
-                    ) and self.spec_algorithm.is_none():
+                    if head_num == 128 or head_num == 16:
                         server_args.attention_backend = "aiter"
                     else:
                         server_args.attention_backend = "triton"
@@ -506,21 +588,13 @@ class ModelRunner:
                     server_args.attention_backend = "ascend"
                 else:
                     server_args.attention_backend = "triton"
-                logger.info(
-                    f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default."
+                log_info_on_rank0(
+                    logger,
+                    f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
                 )
             elif self.use_mla_backend:
                 if server_args.device != "cpu":
-                    if server_args.attention_backend in [
-                        "aiter",
-                        "flashinfer",
-                        "fa3",
-                        "triton",
-                        "flashmla",
-                        "cutlass_mla",
-                        "trtllm_mla",
-                        "ascend",
-                    ]:
+                    if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                         logger.info(
                             f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                         )
@@ -559,23 +633,15 @@ class ModelRunner:
                 f"{self.model_config.hf_config.model_type}"
             )
 
-        if not self.use_mla_backend:
-            server_args.disable_chunked_prefix_cache = True
-
-        # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
-        # For more details, see: https://github.com/sgl-project/sglang/issues/8616
-        elif (
-            self.dp_size > 1
-            and is_sm100_supported()
-            and server_args.attention_backend != "triton"
-            and server_args.attention_backend == "trtllm_mla"
+        if (
+            not self.use_mla_backend
+            or server_args.attention_backend
+            not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS
         ):
-            logger.info(
-                "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
-            )
             server_args.disable_chunked_prefix_cache = True
+
         if not server_args.disable_chunked_prefix_cache:
-            logger.info("Chunked prefix cache is turned on.")
+            log_info_on_rank0(logger, "Chunked prefix cache is turned on.")
 
         if server_args.attention_backend == "aiter":
             if self.model_config.context_len > 8192:
@@ -599,8 +665,37 @@ class ModelRunner:
                 server_args.hicache_io_backend = "direct"
                 logger.warning(
                     "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                    f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                )
+
+        if self.model_config.hf_config.model_type == "qwen3_vl_moe":
+            if (
+                quantization_config := getattr(
+                    self.model_config.hf_config, "quantization_config", None
+                )
+            ) is not None:
+                weight_block_size_n = quantization_config["weight_block_size"][0]
+
+                if self.tp_size % self.moe_ep_size != 0:
+                    raise ValueError(
+                        f"tp_size {self.tp_size} must be divisible by moe_ep_size {self.moe_ep_size}"
+                    )
+                moe_tp_size = self.tp_size // self.moe_ep_size
+
+                moe_intermediate_size = (
+                    self.model_config.hf_text_config.moe_intermediate_size
                 )
+                if moe_intermediate_size % moe_tp_size != 0:
+                    raise ValueError(
+                        f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})."
+                    )
+
+                if (moe_intermediate_size // moe_tp_size) % weight_block_size_n != 0:
+                    raise ValueError(
+                        f"For qwen3-vl-fp8 models, please make sure ({moe_intermediate_size=} / {moe_tp_size=}) % {weight_block_size_n=} == 0 "
+                        f"where moe_tp_size is equal to tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size}). "
+                        f"You can fix this by setting arguments `--tp-size` and `--ep-size` correctly."
+                    )
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -614,7 +709,18 @@ class ModelRunner:
             raise
 
         if self.device == "cuda":
-            backend = "nccl"
+            if self.server_args.elastic_ep_backend == "mooncake":
+                backend = "mooncake"
+                if self.server_args.mooncake_ib_device:
+                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
+                    try:
+                        from mooncake import ep as mooncake_ep
+
+                        mooncake_ep.set_device_filter(mooncake_ib_device)
+                    except:
+                        pass  # A warning will be raised in `init_distributed_environment`
+            else:
+                backend = "nccl"
         elif self.device == "xpu":
             backend = "xccl"
         elif self.device == "hpu":
@@ -634,6 +740,7 @@ class ModelRunner:
             dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
+        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
 
         if not self.is_draft_worker:
             if self.device == "cpu":
@@ -668,6 +775,7 @@ class ModelRunner:
                 pipeline_model_parallel_size=self.pp_size,
                 expert_model_parallel_size=self.moe_ep_size,
                 duplicate_tp_group=self.server_args.enable_pdmux,
+                torch_compile=self.server_args.enable_piecewise_cuda_graph,
             )
             initialize_dp_attention(
                 server_args=self.server_args,
@@ -726,10 +834,25 @@ class ModelRunner:
         set_cuda_arch()
 
         # Prepare the model config
+        from sglang.srt.configs.modelopt_config import ModelOptConfig
+
+        modelopt_config = ModelOptConfig(
+            quant=self.server_args.modelopt_quant,
+            checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
+            checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
+            export_path=self.server_args.modelopt_export_path,
+            quantize_and_serve=self.server_args.quantize_and_serve,
+        )
+
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
             model_loader_extra_config=self.server_args.model_loader_extra_config,
+            tp_rank=self.tp_rank,
+            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
+            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+            modelopt_config=modelopt_config,
         )
         if self.device == "cpu":
             self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -757,7 +880,10 @@ class ModelRunner:
             monkey_patch_vllm_parallel_state()
             monkey_patch_isinstance_for_vllm_base_layer()
 
-        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
+        with self.memory_saver_adapter.region(
+            GPU_MEMORY_TYPE_WEIGHTS,
+            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+        ):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -813,33 +939,56 @@ class ModelRunner:
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
 
-        # Handle the case where some ranks do not finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
+        if self.server_args.elastic_ep_backend == "mooncake":
+            # Mooncake does not support `monitored_barrier`
+            dist.barrier(group=get_tp_group().cpu_group)
+        else:
+            # Handle the case where some ranks do not finish loading.
+            try:
+                dist.monitored_barrier(
+                    group=get_tp_group().cpu_group,
+                    timeout=datetime.timedelta(
+                        seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S
+                    ),
+                    wait_all_ranks=True,
+                )
+            except RuntimeError:
+                raise ValueError(
+                    f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+                ) from None
 
     def update_expert_location(
         self,
         new_expert_location_metadata: ExpertLocationMetadata,
         update_layer_ids: List[int],
     ):
-        self.expert_location_updater.update(
-            self.model.routed_experts_weights_of_layer,
-            new_expert_location_metadata,
-            update_layer_ids=update_layer_ids,
-            nnodes=self.server_args.nnodes,
-            rank=self.tp_rank,
-        )
+        if ElasticEPStateManager.instance() is not None:
+            # TODO: refactor the weights update when elastic ep
+            old_expert_location_metadata = get_global_expert_location_metadata()
+            assert old_expert_location_metadata is not None
+            old_expert_location_metadata.update(
+                new_expert_location_metadata,
+                update_layer_ids=update_layer_ids,
+            )
+            self.update_weights_from_disk(
+                self.server_args.model_path,
+                self.server_args.load_format,
+                lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
+            )
+        else:
+            self.expert_location_updater.update(
+                self.model.routed_experts_weights_of_layer,
+                new_expert_location_metadata,
+                update_layer_ids=update_layer_ids,
+                nnodes=self.server_args.nnodes,
+                rank=self.tp_rank,
+            )
 
     def update_weights_from_disk(
-        self, model_path: str, load_format: str
+        self,
+        model_path: str,
+        load_format: str,
+        weight_name_filter: Optional[Callable[[str], bool]] = None,
     ) -> tuple[bool, str]:
         """Update engine weights in-place from the disk."""
         logger.info(
@@ -852,7 +1001,7 @@ class ModelRunner:
         load_config = LoadConfig(load_format=load_format)
 
         # Only support DefaultModelLoader for now
-        loader = get_model_loader(load_config)
+        loader = get_model_loader(load_config, self.model_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
             return False, message
@@ -861,6 +1010,11 @@ class ModelRunner:
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source.init_new(config, self.model)
             )
+            if weight_name_filter is not None:
+                iter = (
+                    (name, weight) for name, weight in iter if weight_name_filter(name)
+                )
+
             return iter
 
         def model_load_weights(model, iter):
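With the new weight_name_filter parameter (added to the update_weights_from_disk signature above and applied to the weight iterator here), callers can reload only a subset of weights from disk. A minimal sketch, assuming model_runner is an initialized ModelRunner and the checkpoint path is hypothetical:

# Illustrative only: reload just the routed-expert weights, mirroring the
# predicate used by the elastic-EP path in update_expert_location above.
ok, message = model_runner.update_weights_from_disk(
    model_path="/models/my-checkpoint",  # hypothetical path
    load_format="auto",
    weight_name_filter=lambda name: "mlp.experts" in name
    and "mlp.shared_experts" not in name,
)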
@@ -1035,6 +1189,19 @@ class ModelRunner:
             logger.error(message)
             return False, message
 
+    def destroy_weights_update_group(self, group_name):
+        try:
+            if group_name in self._model_update_group:
+                pg = self._model_update_group.pop(group_name)
+                torch.distributed.destroy_process_group(pg)
+                return True, "Succeeded to destroy custom process group."
+            else:
+                return False, "The group to be destroyed does not exist."
+        except Exception as e:
+            message = f"Failed to destroy custom process group: {e}."
+            logger.error(message)
+            return False, message
+
     def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
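destroy_weights_update_group, added in the hunk above, is the teardown counterpart for process groups created for online weight updates (tracked in self._model_update_group). A minimal usage sketch with a made-up group name:

# Illustrative only: "rlhf_actor_group" is a made-up group name.
ok, msg = model_runner.destroy_weights_update_group("rlhf_actor_group")
if not ok:
    print(msg)  # e.g. "The group to be destroyed does not exist."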
@@ -1072,7 +1239,7 @@ class ModelRunner:
                 handle.wait()
 
             self.model.load_weights(weights)
-            return True, f"Succeeded to update parameter online."
+            return True, "Succeeded to update parameter online."
 
         except Exception as e:
             error_msg = (
@@ -1176,6 +1343,7 @@ class ModelRunner:
             max_lora_rank=self.server_args.max_lora_rank,
             target_modules=self.server_args.lora_target_modules,
             lora_paths=self.server_args.lora_paths,
+            server_args=self.server_args,
        )
 
     def load_lora_adapter(self, lora_ref: LoRARef):
@@ -1225,8 +1393,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
-        elif self.is_hybrid_gdn:
-            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
+        elif config := self.mambaish_config:
+            num_layers = len(config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1235,6 +1403,17 @@ class ModelRunner:
                 * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
+            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
+            if is_deepseek_nsa(self.model_config.hf_config):
+                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
+                indexer_size_per_token = (
+                    index_head_dim
+                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
+                )
+                element_size = torch._utils._element_size(
+                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
+                )
+                cell_size += indexer_size_per_token * num_layers * element_size
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
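For DeepSeek-V3.2 NSA models, the per-token KV cell above now also accounts for the indexer buffers: index_head_dim key elements plus one 4-byte scale per quantization block, per layer. A worked number with assumed constants (the real values live on NSATokenToKVPool):

# Assumed values for illustration only.
index_head_dim = 128
quant_block_size = 128
element_size = 1  # bytes, assuming an 8-bit index_k_with_scale_buffer_dtype
num_layers = 61

indexer_size_per_token = index_head_dim + index_head_dim // quant_block_size * 4  # 132
extra_cell_bytes = indexer_size_per_token * num_layers * element_size             # 8052 bytes per token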
@@ -1246,21 +1425,77 @@ class ModelRunner:
1246
1425
  rest_memory = available_gpu_memory - total_gpu_memory * (
1247
1426
  1 - self.mem_fraction_static
1248
1427
  )
1249
- if self.is_hybrid_gdn:
1250
- rest_memory -= (
1251
- self.server_args.max_mamba_cache_size
1252
- * self.model_config.hf_config.mamba_cache_per_req
1253
- / (1 << 30)
1254
- )
1428
+ if self.mambaish_config is not None:
1429
+ rest_memory = self.handle_max_mamba_cache(rest_memory)
1255
1430
  max_num_token = int(rest_memory * (1 << 30) // cell_size)
1256
1431
  return max_num_token
1257
1432
 
1433
+ def handle_max_mamba_cache(self, total_rest_memory):
1434
+ config = self.mambaish_config
1435
+ server_args = self.server_args
1436
+ assert config is not None
1437
+
1438
+ speculative_ratio = (
1439
+ 0
1440
+ if server_args.speculative_num_draft_tokens is None
1441
+ else server_args.speculative_num_draft_tokens
1442
+ )
1443
+ if (
1444
+ server_args.disable_radix_cache
1445
+ or config.mamba2_cache_params.mamba_cache_per_req == 0
1446
+ ):
1447
+ # when radix cache is disabled, set max_mamba_cache_size based on max_running_requests
1448
+ if server_args.max_mamba_cache_size is None:
1449
+ if server_args.max_running_requests is not None:
1450
+ server_args.max_mamba_cache_size = server_args.max_running_requests
1451
+ else:
1452
+ server_args.max_mamba_cache_size = 512
1453
+ else:
1454
+ # allocate the memory based on the ratio between mamba state memory vs. full kv cache memory
1455
+ # solve the equations:
1456
+ # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
1457
+ # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
1458
+ mamba_state_memory_raw = (
1459
+ total_rest_memory
1460
+ * server_args.mamba_full_memory_ratio
1461
+ / (1 + server_args.mamba_full_memory_ratio)
1462
+ )
1463
+ # calculate the max_mamba_cache_size based on the given total mamba memory
1464
+ server_args.max_mamba_cache_size = int(
1465
+ (mamba_state_memory_raw * (1 << 30))
1466
+ // config.mamba2_cache_params.mamba_cache_per_req
1467
+ // (1 + speculativa_ratio)
1468
+ )
1469
+
1470
+ if self.hybrid_gdn_config is not None:
1471
+ server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
1472
+ server_args.dp_size if server_args.enable_dp_attention else 1
1473
+ )
1474
+ mamba_state_memory = (
1475
+ server_args.max_mamba_cache_size
1476
+ * config.mamba2_cache_params.mamba_cache_per_req
1477
+ * (1 + speculativa_ratio)
1478
+ / (1 << 30)
1479
+ )
1480
+ return total_rest_memory - mamba_state_memory
1481
+
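The two-equation comment above pins the split down in closed form: mamba_state_memory = total_rest_memory * ratio / (1 + ratio). A self-contained sketch with illustrative numbers (not defaults):

    # Illustrative numbers only; the real values come from GPU profiling and server args.
    total_rest_memory = 60.0           # GiB available for caches
    mamba_full_memory_ratio = 0.2      # target mamba_state / full_kv_cache ratio
    mamba_state_memory = (
        total_rest_memory * mamba_full_memory_ratio / (1 + mamba_full_memory_ratio)
    )
    full_kv_cache_memory = total_rest_memory - mamba_state_memory
    assert abs(mamba_state_memory + full_kv_cache_memory - total_rest_memory) < 1e-9
    assert abs(mamba_state_memory / full_kv_cache_memory - mamba_full_memory_ratio) < 1e-9
    # 60 GiB at ratio 0.2 -> 10 GiB of mamba state, 50 GiB of full-attention KV cache.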
1258
1482
  @property
1259
- def is_hybrid_gdn(self):
1260
- return self.model_config.hf_config.architectures[0] in [
1261
- "Qwen3NextForCausalLM",
1262
- "Qwen3NextForCausalLMMTP",
1263
- ]
1483
+ def hybrid_gdn_config(self):
1484
+ config = self.model_config.hf_config
1485
+ if isinstance(config, Qwen3NextConfig):
1486
+ return config
1487
+ return None
1488
+
1489
+ @property
1490
+ def mamba2_config(self):
1491
+ config = self.model_config.hf_config
1492
+ if isinstance(config, FalconH1Config | NemotronHConfig):
1493
+ return config
1494
+ return None
1495
+
1496
+ @property
1497
+ def mambaish_config(self):
1498
+ return self.mamba2_config or self.hybrid_gdn_config
1264
1499
 
1265
1500
  def set_num_token_hybrid(self):
1266
1501
  if (
@@ -1344,6 +1579,27 @@ class ModelRunner:
1344
1579
  f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
1345
1580
  )
1346
1581
 
1582
+ def can_run_piecewise_cuda_graph(self):
1583
+ if self.server_args.disable_cuda_graph:
1584
+ log_info_on_rank0(
1585
+ logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
1586
+ )
1587
+ return False
1588
+ if self.server_args.enable_torch_compile:
1589
+ log_info_on_rank0(
1590
+ logger,
1591
+ "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
1592
+ )
1593
+ return False
1594
+ if self.pp_size > 1:
1595
+ # TODO(yuwei): support PP
1596
+ log_info_on_rank0(
1597
+ logger,
1598
+ "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
1599
+ )
1600
+ return False
1601
+ return True
1602
+
1347
1603
  def init_memory_pool(
1348
1604
  self,
1349
1605
  total_gpu_memory: int,
@@ -1352,7 +1608,18 @@ class ModelRunner:
1352
1608
  ):
1353
1609
  # Determine the kv cache dtype
1354
1610
  if self.server_args.kv_cache_dtype == "auto":
1355
- self.kv_cache_dtype = self.dtype
1611
+ quant_config = getattr(self.model, "quant_config", None)
1612
+ kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
1613
+ if (
1614
+ isinstance(kv_cache_quant_algo, str)
1615
+ and kv_cache_quant_algo.upper() == "FP8"
1616
+ ):
1617
+ if _is_hip:
1618
+ self.kv_cache_dtype = torch.float8_e4m3fnuz
1619
+ else:
1620
+ self.kv_cache_dtype = torch.float8_e4m3fn
1621
+ else:
1622
+ self.kv_cache_dtype = self.dtype
1356
1623
  elif self.server_args.kv_cache_dtype == "fp8_e5m2":
1357
1624
  if _is_hip: # Using natively supported format
1358
1625
  self.kv_cache_dtype = torch.float8_e5m2fnuz
@@ -1363,11 +1630,15 @@ class ModelRunner:
1363
1630
  self.kv_cache_dtype = torch.float8_e4m3fnuz
1364
1631
  else:
1365
1632
  self.kv_cache_dtype = torch.float8_e4m3fn
1633
+ elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
1634
+ self.kv_cache_dtype = torch.bfloat16
1366
1635
  else:
1367
1636
  raise ValueError(
1368
1637
  f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
1369
1638
  )
1370
1639
 
1640
+ log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")
1641
+
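A condensed sketch of the dtype resolution above as a standalone function, assuming a recent PyTorch with the float8 dtypes; it mirrors the branches in this method but is not the method itself:

    import torch

    def resolve_kv_cache_dtype(kv_cache_dtype, model_dtype, kv_cache_quant_algo=None, is_hip=False):
        # "auto": honor an FP8 checkpoint quantization scheme, otherwise follow the model dtype.
        if kv_cache_dtype == "auto":
            if isinstance(kv_cache_quant_algo, str) and kv_cache_quant_algo.upper() == "FP8":
                return torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn
            return model_dtype
        if kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2fnuz if is_hip else torch.float8_e5m2
        if kv_cache_dtype == "fp8_e4m3":
            return torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn
        if kv_cache_dtype in ("bf16", "bfloat16"):
            return torch.bfloat16
        raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")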
1371
1642
  self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
1372
1643
  if SGLANG_CI_SMALL_KV_SIZE:
1373
1644
  self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
@@ -1382,10 +1653,18 @@ class ModelRunner:
1382
1653
  ),
1383
1654
  4096,
1384
1655
  )
1385
- if self.is_hybrid_gdn:
1386
- max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
1387
1656
 
1388
- if not self.spec_algorithm.is_none():
1657
+ if self.mambaish_config is not None:
1658
+ ratio = (
1659
+ MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
1660
+ if not self.server_args.disable_radix_cache
1661
+ else 1
1662
+ )
1663
+ max_num_reqs = min(
1664
+ max_num_reqs, self.server_args.max_mamba_cache_size // ratio
1665
+ )
1666
+
1667
+ if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
1389
1668
  if self.is_draft_worker:
1390
1669
  self.max_total_num_tokens = self.server_args.draft_runner_cache_size
1391
1670
  max_num_reqs = self.server_args.max_num_reqs
@@ -1438,7 +1717,8 @@ class ModelRunner:
1438
1717
 
1439
1718
  if self.max_total_num_tokens <= 0:
1440
1719
  raise RuntimeError(
1441
- "Not enough memory. Please try to increase --mem-fraction-static."
1720
+ f"Not enough memory. Please try to increase --mem-fraction-static. "
1721
+ f"Current value: {self.server_args.mem_fraction_static=}"
1442
1722
  )
1443
1723
 
1444
1724
  # Initialize req_to_token_pool
@@ -1449,39 +1729,43 @@ class ModelRunner:
1449
1729
  extra_max_context_len += self.server_args.speculative_num_draft_tokens
1450
1730
 
1451
1731
  if self.server_args.disaggregation_mode == "decode":
1452
- from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
1732
+ from sglang.srt.disaggregation.decode import (
1733
+ DecodeReqToTokenPool,
1734
+ HybridMambaDecodeReqToTokenPool,
1735
+ )
1453
1736
 
1454
1737
  # subscribe memory for pre-allocated requests
1455
1738
  # if max_num_reqs <= 32, we pre-allocate 2x requests
1456
1739
  pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
1457
- self.req_to_token_pool = DecodeReqToTokenPool(
1458
- size=max_num_reqs,
1459
- max_context_len=self.model_config.context_len
1460
- + extra_max_context_len,
1461
- device=self.device,
1462
- enable_memory_saver=self.server_args.enable_memory_saver,
1463
- pre_alloc_size=pre_alloc_size,
1464
- )
1465
- elif self.is_hybrid_gdn:
1466
- config = self.model_config.hf_config
1467
- (
1468
- conv_state_shape,
1469
- temporal_state_shape,
1470
- conv_dtype,
1471
- ssm_dtype,
1472
- mamba_layers,
1473
- ) = config.hybrid_gdn_params
1740
+ if config := self.mambaish_config:
1741
+ self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
1742
+ size=max_num_reqs,
1743
+ max_context_len=self.model_config.context_len
1744
+ + extra_max_context_len,
1745
+ device=self.device,
1746
+ enable_memory_saver=self.server_args.enable_memory_saver,
1747
+ cache_params=config.mamba2_cache_params,
1748
+ speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
1749
+ pre_alloc_size=pre_alloc_size,
1750
+ )
1751
+ else:
1752
+ self.req_to_token_pool = DecodeReqToTokenPool(
1753
+ size=max_num_reqs,
1754
+ max_context_len=self.model_config.context_len
1755
+ + extra_max_context_len,
1756
+ device=self.device,
1757
+ enable_memory_saver=self.server_args.enable_memory_saver,
1758
+ pre_alloc_size=pre_alloc_size,
1759
+ )
1760
+ elif config := self.mambaish_config:
1474
1761
  self.req_to_token_pool = HybridReqToTokenPool(
1475
1762
  size=max_num_reqs,
1763
+ mamba_size=self.server_args.max_mamba_cache_size,
1476
1764
  max_context_len=self.model_config.context_len
1477
1765
  + extra_max_context_len,
1478
1766
  device=self.device,
1479
1767
  enable_memory_saver=self.server_args.enable_memory_saver,
1480
- conv_state_shape=conv_state_shape,
1481
- temporal_state_shape=temporal_state_shape,
1482
- conv_dtype=conv_dtype,
1483
- ssm_dtype=ssm_dtype,
1484
- mamba_layers=mamba_layers,
1768
+ cache_params=config.mamba2_cache_params,
1485
1769
  speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
1486
1770
  )
1487
1771
  else:
@@ -1497,6 +1781,7 @@ class ModelRunner:
1497
1781
  assert self.is_draft_worker
1498
1782
 
1499
1783
  # Initialize token_to_kv_pool
1784
+ is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
1500
1785
  if self.server_args.attention_backend == "ascend":
1501
1786
  if self.use_mla_backend:
1502
1787
  self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1505,6 +1790,7 @@ class ModelRunner:
1505
1790
  dtype=self.kv_cache_dtype,
1506
1791
  kv_lora_rank=self.model_config.kv_lora_rank,
1507
1792
  qk_rope_head_dim=self.model_config.qk_rope_head_dim,
1793
+ index_head_dim=self.model_config.index_head_dim,
1508
1794
  layer_num=self.num_effective_layers,
1509
1795
  device=self.device,
1510
1796
  enable_memory_saver=self.server_args.enable_memory_saver,
@@ -1524,7 +1810,22 @@ class ModelRunner:
1524
1810
  device=self.device,
1525
1811
  enable_memory_saver=self.server_args.enable_memory_saver,
1526
1812
  )
1813
+ elif self.use_mla_backend and is_nsa_model:
1814
+ self.token_to_kv_pool = NSATokenToKVPool(
1815
+ self.max_total_num_tokens,
1816
+ page_size=self.page_size,
1817
+ dtype=self.kv_cache_dtype,
1818
+ kv_lora_rank=self.model_config.kv_lora_rank,
1819
+ qk_rope_head_dim=self.model_config.qk_rope_head_dim,
1820
+ layer_num=self.num_effective_layers,
1821
+ device=self.device,
1822
+ enable_memory_saver=self.server_args.enable_memory_saver,
1823
+ start_layer=self.start_layer,
1824
+ end_layer=self.end_layer,
1825
+ index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
1826
+ )
1527
1827
  elif self.use_mla_backend:
1828
+ assert not is_nsa_model
1528
1829
  self.token_to_kv_pool = MLATokenToKVPool(
1529
1830
  self.max_total_num_tokens,
1530
1831
  page_size=self.page_size,
@@ -1566,9 +1867,9 @@ class ModelRunner:
1566
1867
  enable_kvcache_transpose=False,
1567
1868
  device=self.device,
1568
1869
  )
1569
- elif self.is_hybrid_gdn:
1870
+ elif config := self.mambaish_config:
1570
1871
  self.token_to_kv_pool = HybridLinearKVPool(
1571
- page_size=self.page_size if _is_npu else 1,
1872
+ page_size=self.page_size,
1572
1873
  size=self.max_total_num_tokens,
1573
1874
  dtype=self.kv_cache_dtype,
1574
1875
  head_num=self.model_config.get_num_kv_heads(
@@ -1577,12 +1878,11 @@ class ModelRunner:
1577
1878
  head_dim=self.model_config.head_dim,
1578
1879
  # if draft worker, we only need 1 attention layer's kv pool
1579
1880
  full_attention_layer_ids=(
1580
- [0]
1581
- if self.is_draft_worker
1582
- else self.model_config.hf_config.full_attention_layer_ids
1881
+ [0] if self.is_draft_worker else config.full_attention_layer_ids
1583
1882
  ),
1584
1883
  enable_kvcache_transpose=False,
1585
1884
  device=self.device,
1885
+ mamba_pool=self.req_to_token_pool.mamba_pool,
1586
1886
  )
1587
1887
  else:
1588
1888
  self.token_to_kv_pool = MHATokenToKVPool(
@@ -1598,15 +1898,18 @@ class ModelRunner:
1598
1898
  enable_memory_saver=self.server_args.enable_memory_saver,
1599
1899
  start_layer=self.start_layer,
1600
1900
  end_layer=self.end_layer,
1901
+ enable_kv_cache_copy=(
1902
+ self.server_args.speculative_algorithm is not None
1903
+ ),
1601
1904
  )
1602
1905
 
1603
1906
  # Initialize token_to_kv_pool_allocator
1604
1907
  need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
1605
1908
  if self.token_to_kv_pool_allocator is None:
1606
- if _is_npu and self.server_args.attention_backend in [
1607
- "ascend",
1608
- "hybrid_linear_attn",
1609
- ]:
1909
+ if _is_npu and (
1910
+ self.server_args.attention_backend == "ascend"
1911
+ or self.hybrid_gdn_config is not None
1912
+ ):
1610
1913
  self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
1611
1914
  self.max_total_num_tokens,
1612
1915
  page_size=self.page_size,
@@ -1670,16 +1973,10 @@ class ModelRunner:
1670
1973
 
1671
1974
  def _get_attention_backend(self):
1672
1975
  """Init attention kernel backend."""
1673
- self.decode_attention_backend_str = (
1674
- self.server_args.decode_attention_backend
1675
- if self.server_args.decode_attention_backend
1676
- else self.server_args.attention_backend
1677
- )
1678
- self.prefill_attention_backend_str = (
1679
- self.server_args.prefill_attention_backend
1680
- if self.server_args.prefill_attention_backend
1681
- else self.server_args.attention_backend
1976
+ self.prefill_attention_backend_str, self.decode_attention_backend_str = (
1977
+ self.server_args.get_attention_backends()
1682
1978
  )
1979
+
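The tuple unpacking above replaces the explicit fallbacks spelled out in the removed lines. A plausible shape of `ServerArgs.get_attention_backends()`, inferred from the code it supersedes rather than taken from the actual implementation:

    def get_attention_backends(server_args):
        # Each phase-specific backend falls back to the generic --attention-backend value.
        prefill = server_args.prefill_attention_backend or server_args.attention_backend
        decode = server_args.decode_attention_backend or server_args.attention_backend
        return prefill, decode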
1683
1980
  if self.decode_attention_backend_str != self.prefill_attention_backend_str:
1684
1981
  from sglang.srt.layers.attention.hybrid_attn_backend import (
1685
1982
  HybridAttnBackend,
@@ -1700,157 +1997,25 @@ class ModelRunner:
1700
1997
  f"prefill_backend={self.prefill_attention_backend_str}."
1701
1998
  )
1702
1999
  logger.warning(
1703
- f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
1704
- f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
2000
+ "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
2001
+ "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
1705
2002
  )
1706
2003
  else:
1707
2004
  attn_backend = self._get_attention_backend_from_str(
1708
2005
  self.server_args.attention_backend
1709
2006
  )
1710
2007
 
1711
- global_server_args_dict.update(
1712
- {
1713
- "decode_attention_backend": self.decode_attention_backend_str,
1714
- "prefill_attention_backend": self.prefill_attention_backend_str,
1715
- }
1716
- )
2008
+ (
2009
+ get_global_server_args().prefill_attention_backend,
2010
+ get_global_server_args().decode_attention_backend,
2011
+ ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
1717
2012
  return attn_backend
1718
2013
 
1719
2014
  def _get_attention_backend_from_str(self, backend_str: str):
1720
- if backend_str == "flashinfer":
1721
- if not self.use_mla_backend:
1722
- from sglang.srt.layers.attention.flashinfer_backend import (
1723
- FlashInferAttnBackend,
1724
- )
1725
-
1726
- # Init streams
1727
- if self.server_args.speculative_algorithm == "EAGLE":
1728
- if (
1729
- not hasattr(self, "plan_stream_for_flashinfer")
1730
- or not self.plan_stream_for_flashinfer
1731
- ):
1732
- self.plan_stream_for_flashinfer = torch.cuda.Stream()
1733
- return FlashInferAttnBackend(self)
1734
- else:
1735
- from sglang.srt.layers.attention.flashinfer_mla_backend import (
1736
- FlashInferMLAAttnBackend,
1737
- )
1738
-
1739
- return FlashInferMLAAttnBackend(self)
1740
- elif backend_str == "aiter":
1741
- from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
1742
-
1743
- return AiterAttnBackend(self)
1744
- elif self.server_args.attention_backend == "wave":
1745
- from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
1746
-
1747
- return WaveAttnBackend(self)
1748
- elif backend_str == "ascend":
1749
- from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
1750
-
1751
- return AscendAttnBackend(self)
1752
- elif backend_str == "triton":
1753
- assert not self.model_config.is_encoder_decoder, (
1754
- "Cross attention is not supported in the triton attention backend. "
1755
- "Please use `--attention-backend flashinfer`."
1756
- )
1757
- if self.server_args.enable_double_sparsity:
1758
- from sglang.srt.layers.attention.double_sparsity_backend import (
1759
- DoubleSparseAttnBackend,
1760
- )
1761
-
1762
- return DoubleSparseAttnBackend(self)
1763
- else:
1764
- from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
1765
-
1766
- return TritonAttnBackend(self)
1767
- elif backend_str == "torch_native":
1768
- from sglang.srt.layers.attention.torch_native_backend import (
1769
- TorchNativeAttnBackend,
1770
- )
1771
-
1772
- return TorchNativeAttnBackend(self)
1773
- elif backend_str == "flashmla":
1774
- from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
1775
-
1776
- return FlashMLABackend(self)
1777
- elif backend_str == "fa3":
1778
- assert (
1779
- torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
1780
- ) or torch.cuda.get_device_capability()[0] == 9, (
1781
- "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
1782
- "Please use `--attention-backend flashinfer`."
1783
- )
1784
- from sglang.srt.layers.attention.flashattention_backend import (
1785
- FlashAttentionBackend,
1786
- )
1787
-
1788
- return FlashAttentionBackend(self)
1789
- elif backend_str == "cutlass_mla":
1790
- from sglang.srt.layers.attention.cutlass_mla_backend import (
1791
- CutlassMLABackend,
1792
- )
1793
-
1794
- return CutlassMLABackend(self)
1795
- elif backend_str == "trtllm_mla":
1796
- if not self.use_mla_backend:
1797
- raise ValueError("trtllm_mla backend can only be used with MLA models.")
1798
- from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
1799
-
1800
- return TRTLLMMLABackend(self)
1801
- elif backend_str == "trtllm_mha":
1802
- if self.use_mla_backend:
1803
- raise ValueError(
1804
- "trtllm_mha backend can only be used with non-MLA models."
1805
- )
1806
- from sglang.srt.layers.attention.trtllm_mha_backend import (
1807
- TRTLLMHAAttnBackend,
1808
- )
1809
-
1810
- return TRTLLMHAAttnBackend(self)
1811
- elif backend_str == "intel_amx":
1812
- from sglang.srt.layers.attention.intel_amx_backend import (
1813
- IntelAMXAttnBackend,
1814
- )
1815
-
1816
- return IntelAMXAttnBackend(self)
1817
- elif backend_str == "dual_chunk_flash_attn":
1818
- from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
1819
- DualChunkFlashAttentionBackend,
1820
- )
1821
-
1822
- return DualChunkFlashAttentionBackend(self)
1823
- elif backend_str == "hybrid_linear_attn":
1824
- assert (
1825
- self.is_hybrid_gdn
1826
- ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
1827
- from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
1828
- HybridLinearAttnBackend,
1829
- MambaAttnBackend,
1830
- )
1831
-
1832
- if _is_npu:
1833
- from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
1834
-
1835
- full_attn_backend = AscendAttnBackend(self)
1836
- elif is_blackwell():
1837
- from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
1838
-
1839
- full_attn_backend = TritonAttnBackend(self)
1840
- else:
1841
- from sglang.srt.layers.attention.flashattention_backend import (
1842
- FlashAttentionBackend,
1843
- )
1844
-
1845
- full_attn_backend = FlashAttentionBackend(self)
1846
-
1847
- linear_attn_backend = MambaAttnBackend(self)
1848
- full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
1849
- return HybridLinearAttnBackend(
1850
- full_attn_backend, linear_attn_backend, full_attn_layers
1851
- )
1852
- else:
2015
+ if backend_str not in ATTENTION_BACKENDS:
1853
2016
  raise ValueError(f"Invalid attention backend: {backend_str}")
2017
+ full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
2018
+ return attn_backend_wrapper(self, full_attention_backend)
1854
2019
 
1855
2020
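The long if/elif chain above is replaced by a lookup in `ATTENTION_BACKENDS` plus `attn_backend_wrapper`. A hedged sketch of how such a registry is typically populated; the decorator name is illustrative and not necessarily what sglang uses:

    ATTENTION_BACKENDS = {}

    def register_attention_backend(name):
        def decorator(factory):
            ATTENTION_BACKENDS[name] = factory
            return factory
        return decorator

    @register_attention_backend("triton")
    def create_triton_backend(runner):
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
        return TritonAttnBackend(runner)

    # Lookup then wrap, mirroring _get_attention_backend_from_str above:
    #   backend = attn_backend_wrapper(runner, ATTENTION_BACKENDS[backend_str](runner))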
  def init_double_sparsity_channel_config(self, selected_channel):
1856
2021
  selected_channel = "." + selected_channel + "_proj"
@@ -1981,6 +2146,11 @@ class ModelRunner:
1981
2146
  kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
1982
2147
  if not self.is_generation:
1983
2148
  kwargs["get_embedding"] = True
2149
+
2150
+ if self.piecewise_cuda_graph_runner is not None:
2151
+ if self.piecewise_cuda_graph_runner.can_run(forward_batch):
2152
+ return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
2153
+
1984
2154
  return self.model.forward(
1985
2155
  forward_batch.input_ids,
1986
2156
  forward_batch.positions,
@@ -2114,15 +2284,11 @@ class ModelRunner:
2114
2284
  def _preprocess_logits(
2115
2285
  self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
2116
2286
  ):
2117
- # Apply logit bias
2118
- if sampling_info.sampling_info_done:
2119
- # Overlap mode: the function update_regex_vocab_mask was executed
2120
- # in process_batch_result of the last batch.
2121
- if sampling_info.grammars:
2122
- sampling_info.sampling_info_done.wait()
2123
- else:
2124
- # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
2125
- sampling_info.update_regex_vocab_mask()
2287
+ # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
2288
+ # was executed after we processed last batch's results.
2289
+
2290
+ # Calculate logits bias and apply it to next_token_logits.
2291
+ sampling_info.update_regex_vocab_mask()
2126
2292
  sampling_info.apply_logits_bias(logits_output.next_token_logits)
2127
2293
 
2128
2294
  def sample(
@@ -2147,7 +2313,6 @@ class ModelRunner:
2147
2313
  )
2148
2314
 
2149
2315
  self._preprocess_logits(logits_output, forward_batch.sampling_info)
2150
-
2151
2316
  # Sample the next tokens
2152
2317
  next_token_ids = self.sampler(
2153
2318
  logits_output,
@@ -2155,6 +2320,12 @@ class ModelRunner:
2155
2320
  forward_batch.return_logprob,
2156
2321
  forward_batch.top_logprobs_nums,
2157
2322
  forward_batch.token_ids_logprobs,
2323
+ # For prefill, we only use the position of the last token.
2324
+ (
2325
+ forward_batch.positions
2326
+ if forward_batch.forward_mode.is_decode()
2327
+ else forward_batch.seq_lens - 1
2328
+ ),
2158
2329
  )
2159
2330
  return next_token_ids
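The positions argument added above selects which logits row is sampled per request. A tiny illustration with assumed lengths:

    import torch

    seq_lens = torch.tensor([5, 3])      # two prefill requests
    sampling_positions = seq_lens - 1    # tensor([4, 2]): 0-indexed last token of each request
    # For decode batches, forward_batch.positions already holds the current token position,
    # so it is passed through unchanged.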
2160
2331
 
@@ -2216,6 +2387,23 @@ class ModelRunner:
2216
2387
  )
2217
2388
  ShardedStateLoader.save_model(self.model, path, pattern, max_size)
2218
2389
 
2390
+ def update_weights_from_ipc(self, recv_req):
2391
+ """Update weights from IPC for checkpoint-engine integration."""
2392
+ try:
2393
+ from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
2394
+ SGLangCheckpointEngineWorkerExtensionImpl,
2395
+ )
2396
+
2397
+ # Create a worker extension that integrates with SGLang's model
2398
+ worker = SGLangCheckpointEngineWorkerExtensionImpl(self)
2399
+ worker.update_weights_from_ipc(recv_req.zmq_handles)
2400
+ return True, "IPC weight update completed successfully"
2401
+ except ImportError as e:
2402
+ return False, f"IPC weight update failed: ImportError {e}"
2403
+ except Exception as e:
2404
+ logger.error(f"IPC weight update failed: {e}")
2405
+ return False, str(e)
2406
+
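A minimal caller sketch for the IPC path above, assuming an initialized `ModelRunner` named `model_runner`; the structure of `recv_req` beyond the `zmq_handles` attribute, and the handle contents, are assumptions:

    from types import SimpleNamespace

    # zmq_handles are produced by the checkpoint-engine publisher (contents assumed).
    recv_req = SimpleNamespace(zmq_handles={"cuda:0": "ipc:///tmp/ckpt-engine-0"})
    ok, msg = model_runner.update_weights_from_ipc(recv_req)
    if not ok:
        logger.error(msg)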
2219
2407
 
2220
2408
  def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
2221
2409
  params_dict = dict(model.named_parameters())