sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -24,11 +24,12 @@ import threading
 import time
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist

+from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
@@ -50,6 +51,7 @@ from sglang.srt.distributed import (
     set_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
 from sglang.srt.eplb.eplb_manager import EPLBManager
 from sglang.srt.eplb.expert_distribution import (
     ExpertDistributionRecorder,
@@ -63,6 +65,7 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.attention.attention_registry import (
     ATTENTION_BACKENDS,
     attn_backend_wrapper,
@@ -74,18 +77,11 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import (
-    deep_gemm_wrapper,
-    monkey_patch_isinstance_for_vllm_base_layer,
-)
+from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
-from sglang.srt.managers.schedule_batch import (
-    GLOBAL_SERVER_ARGS_KEYS,
-    global_server_args_dict,
-)
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
@@ -109,6 +105,9 @@ from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
+from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
+    PiecewiseCudaGraphRunner,
+)
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
@@ -116,15 +115,13 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.offloader import (
-    create_offloader_from_server_args,
-    get_offloader,
-    set_offloader,
-)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import (
+    ServerArgs,
+    get_global_server_args,
+    set_global_server_args_for_scheduler,
+)
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
@@ -146,8 +143,15 @@ from sglang.srt.utils import (
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
     slow_rank_detector,
+    xpu_has_xmx_support,
+)
+from sglang.srt.utils.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
 )
 from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorBucket,
     FlattenedTensorMetadata,
@@ -166,6 +170,15 @@ MLA_ATTENTION_BACKENDS = [
     "nsa",
 ]

+CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+]
+

 def add_mla_attention_backend(backend_name):
     if backend_name not in MLA_ATTENTION_BACKENDS:
@@ -173,9 +186,18 @@ def add_mla_attention_backend(backend_name):
         logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")


+def add_chunked_prefix_cache_attention_backend(backend_name):
+    if backend_name not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS:
+        CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(
+            f"Added {backend_name} to CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS."
+        )
+
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_xpu_xmx_available = xpu_has_xmx_support()

 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
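Note: the new CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS registry mirrors the MLA_ATTENTION_BACKENDS pattern above. A minimal sketch of how an out-of-tree backend could register itself with both hooks; the backend name "my_backend" is hypothetical:

    from sglang.srt.model_executor.model_runner import (
        add_chunked_prefix_cache_attention_backend,
        add_mla_attention_backend,
    )

    # Both helpers are idempotent: they append only if the name is not yet listed.
    add_mla_attention_backend("my_backend")
    add_chunked_prefix_cache_attention_backend("my_backend")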
@@ -183,8 +205,10 @@ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 # Detect stragger ranks in model loading
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

-logger = logging.getLogger(__name__)
+# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
+MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

+logger = logging.getLogger(__name__)

 if _is_npu:
     import torch_npu
@@ -257,25 +281,21 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
         self.forward_pass_id = 0
+        self.init_new_workspace = False

         # Apply the rank zero filter to logger
-        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
-            logger.addFilter(RankZeroFilter(tp_rank == 0))
         if server_args.show_time_cost:
             enable_show_time_cost()

         # Model-specific adjustment
         self.model_specific_adjustment()

-        # Global vars
-        global_server_args_dict.update(
-            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
-            | {
-                # TODO it is indeed not a "server args"
-                "use_mla_backend": self.use_mla_backend,
-                "speculative_algorithm": self.spec_algorithm,
-            }
-        )
+        # Set the global server_args in the scheduler process
+        set_global_server_args_for_scheduler(server_args)
+        global_server_args = get_global_server_args()
+
+        # FIXME: hacky set `use_mla_backend`
+        global_server_args.use_mla_backend = self.use_mla_backend

         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
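Note: this hunk replaces the untyped global_server_args_dict with a typed singleton from sglang.srt.server_args. A minimal sketch of the new access pattern, assuming set_global_server_args_for_scheduler(server_args) has already run in the scheduler process, as in the hunk above:

    from sglang.srt.server_args import get_global_server_args

    # Typed attribute access replaces dictionary lookups such as
    # global_server_args_dict["torchao_config"].
    args = get_global_server_args()
    print(args.torchao_config, args.use_mla_backend)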
@@ -306,6 +326,26 @@ class ModelRunner:
         self._model_update_group = {}
         self._weights_send_group = {}

+        if (
+            self.server_args.enable_piecewise_cuda_graph
+            and self.can_run_piecewise_cuda_graph()
+        ):
+            self.attention_layers = []
+            for layer in self.model.model.layers:
+                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
+                    self.attention_layers.append(layer.self_attn.attn)
+            if len(self.attention_layers) < self.model_config.num_hidden_layers:
+                # TODO(yuwei): support Non-Standard GQA
+                log_info_on_rank0(
+                    logger,
+                    "Disable piecewise CUDA graph because some layers do not apply Standard GQA",
+                )
+                self.piecewise_cuda_graph_runner = None
+            else:
+                self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
+        else:
+            self.piecewise_cuda_graph_runner = None
+
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args

@@ -340,6 +380,11 @@ class ModelRunner:
         )
         self.expert_location_updater = ExpertLocationUpdater()

+        (
+            ElasticEPStateManager.init(self.server_args)
+            if self.server_args.elastic_ep_backend
+            else None
+        )
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -354,24 +399,10 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True

-        if self.is_hybrid_gdn:
-            logger.warning("Hybrid GDN model detected, disable radix cache")
+        if config := self.mamba2_config:
+            class_name = config.__class__.__name__
+            logger.warning(f"{class_name} model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
-            if self.server_args.max_mamba_cache_size is None:
-                if self.server_args.max_running_requests is not None:
-                    self.server_args.max_mamba_cache_size = (
-                        self.server_args.max_running_requests
-                    )
-                else:
-                    self.server_args.max_mamba_cache_size = 512
-            self.server_args.max_mamba_cache_size = (
-                self.server_args.max_mamba_cache_size
-                // (
-                    self.server_args.dp_size
-                    if self.server_args.enable_dp_attention
-                    else 1
-                )
-            )

         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
@@ -402,7 +433,7 @@ class ModelRunner:
             # In layered loading, torchao may have been applied
             if not torchao_applied:
                 apply_torchao_config_to_model(
-                    self.model, global_server_args_dict["torchao_config"]
+                    self.model, get_global_server_args().torchao_config
                 )

         # Apply torch TP if the model supports it
@@ -482,6 +513,16 @@ class ModelRunner:
             )
             server_args.attention_backend = "torch_native"

+        if (
+            server_args.attention_backend == "intel_xpu"
+            and server_args.device == "xpu"
+            and not _is_xpu_xmx_available
+        ):
+            logger.info(
+                "The current platform does not support Intel XMX, will fallback to triton backend."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.prefill_attention_backend is not None and (
             server_args.prefill_attention_backend
             == server_args.decode_attention_backend
@@ -547,8 +588,9 @@ class ModelRunner:
                 server_args.attention_backend = "ascend"
             else:
                 server_args.attention_backend = "triton"
-            logger.info(
-                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default."
+            log_info_on_rank0(
+                logger,
+                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
@@ -591,11 +633,15 @@ class ModelRunner:
                     f"{self.model_config.hf_config.model_type}"
                 )

-        if not self.use_mla_backend:
+        if (
+            not self.use_mla_backend
+            or server_args.attention_backend
+            not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS
+        ):
             server_args.disable_chunked_prefix_cache = True

         if not server_args.disable_chunked_prefix_cache:
-            logger.info("Chunked prefix cache is turned on.")
+            log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

         if server_args.attention_backend == "aiter":
             if self.model_config.context_len > 8192:
@@ -622,6 +668,35 @@ class ModelRunner:
                     "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
                 )

+        if self.model_config.hf_config.model_type == "qwen3_vl_moe":
+            if (
+                quantization_config := getattr(
+                    self.model_config.hf_config, "quantization_config", None
+                )
+            ) is not None:
+                weight_block_size_n = quantization_config["weight_block_size"][0]
+
+                if self.tp_size % self.moe_ep_size != 0:
+                    raise ValueError(
+                        f"tp_size {self.tp_size} must be divisible by moe_ep_size {self.moe_ep_size}"
+                    )
+                moe_tp_size = self.tp_size // self.moe_ep_size
+
+                moe_intermediate_size = (
+                    self.model_config.hf_text_config.moe_intermediate_size
+                )
+                if moe_intermediate_size % moe_tp_size != 0:
+                    raise ValueError(
+                        f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})."
+                    )
+
+                if (moe_intermediate_size // moe_tp_size) % weight_block_size_n != 0:
+                    raise ValueError(
+                        f"For qwen3-vl-fp8 models, please make sure ({moe_intermediate_size=} / {moe_tp_size=}) % {weight_block_size_n=} == 0 "
+                        f"where moe_tp_size is equal to tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size}). "
+                        f"You can fix this by setting arguments `--tp-size` and `--ep-size` correctly."
+                    )
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

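Note: the qwen3_vl_moe check above encodes divisibility constraints between the parallel sizes and the FP8 weight block. A worked example with illustrative sizes (tp_size=8, moe_ep_size=2, moe_intermediate_size=1536, and weight_block_size_n=128 are assumptions, not values from this diff):

    tp_size, moe_ep_size = 8, 2
    moe_tp_size = tp_size // moe_ep_size          # 4
    moe_intermediate_size = 1536                  # illustrative
    weight_block_size_n = 128                     # illustrative FP8 block size
    # All three constraints from the hunk above hold for this configuration:
    assert tp_size % moe_ep_size == 0                                         # 8 % 2 == 0
    assert moe_intermediate_size % moe_tp_size == 0                           # 1536 % 4 == 0
    assert (moe_intermediate_size // moe_tp_size) % weight_block_size_n == 0  # 384 % 128 == 0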
@@ -634,7 +709,18 @@ class ModelRunner:
             raise

         if self.device == "cuda":
-            backend = "nccl"
+            if self.server_args.elastic_ep_backend == "mooncake":
+                backend = "mooncake"
+                if self.server_args.mooncake_ib_device:
+                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
+                    try:
+                        from mooncake import ep as mooncake_ep
+
+                        mooncake_ep.set_device_filter(mooncake_ib_device)
+                    except:
+                        pass  # A warning will be raised in `init_distributed_environment`
+            else:
+                backend = "nccl"
         elif self.device == "xpu":
             backend = "xccl"
         elif self.device == "hpu":
@@ -689,6 +775,7 @@ class ModelRunner:
             pipeline_model_parallel_size=self.pp_size,
             expert_model_parallel_size=self.moe_ep_size,
             duplicate_tp_group=self.server_args.enable_pdmux,
+            torch_compile=self.server_args.enable_piecewise_cuda_graph,
         )
         initialize_dp_attention(
             server_args=self.server_args,
@@ -747,6 +834,16 @@ class ModelRunner:
         set_cuda_arch()

         # Prepare the model config
+        from sglang.srt.configs.modelopt_config import ModelOptConfig
+
+        modelopt_config = ModelOptConfig(
+            quant=self.server_args.modelopt_quant,
+            checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
+            checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
+            export_path=self.server_args.modelopt_export_path,
+            quantize_and_serve=self.server_args.quantize_and_serve,
+        )
+
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
@@ -755,6 +852,7 @@ class ModelRunner:
             remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
             remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
             remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+            modelopt_config=modelopt_config,
         )
         if self.device == "cpu":
             self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -841,33 +939,56 @@ class ModelRunner:
                 f"mem usage={self.weight_load_mem_usage:.2f} GB."
             )

-        # Handle the case where some ranks do not finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
+        if self.server_args.elastic_ep_backend == "mooncake":
+            # Mooncake does not support `monitored_barrier`
+            dist.barrier(group=get_tp_group().cpu_group)
+        else:
+            # Handle the case where some ranks do not finish loading.
+            try:
+                dist.monitored_barrier(
+                    group=get_tp_group().cpu_group,
+                    timeout=datetime.timedelta(
+                        seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S
+                    ),
+                    wait_all_ranks=True,
+                )
+            except RuntimeError:
+                raise ValueError(
+                    f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+                ) from None

     def update_expert_location(
         self,
         new_expert_location_metadata: ExpertLocationMetadata,
         update_layer_ids: List[int],
     ):
-        self.expert_location_updater.update(
-            self.model.routed_experts_weights_of_layer,
-            new_expert_location_metadata,
-            update_layer_ids=update_layer_ids,
-            nnodes=self.server_args.nnodes,
-            rank=self.tp_rank,
-        )
+        if ElasticEPStateManager.instance() is not None:
+            # TODO: refactor the weights update when elastic ep
+            old_expert_location_metadata = get_global_expert_location_metadata()
+            assert old_expert_location_metadata is not None
+            old_expert_location_metadata.update(
+                new_expert_location_metadata,
+                update_layer_ids=update_layer_ids,
+            )
+            self.update_weights_from_disk(
+                self.server_args.model_path,
+                self.server_args.load_format,
+                lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
+            )
+        else:
+            self.expert_location_updater.update(
+                self.model.routed_experts_weights_of_layer,
+                new_expert_location_metadata,
+                update_layer_ids=update_layer_ids,
+                nnodes=self.server_args.nnodes,
+                rank=self.tp_rank,
+            )

     def update_weights_from_disk(
-        self, model_path: str, load_format: str
+        self,
+        model_path: str,
+        load_format: str,
+        weight_name_filter: Optional[Callable[[str], bool]] = None,
     ) -> tuple[bool, str]:
         """Update engine weights in-place from the disk."""
         logger.info(
@@ -880,7 +1001,7 @@ class ModelRunner:
         load_config = LoadConfig(load_format=load_format)

         # Only support DefaultModelLoader for now
-        loader = get_model_loader(load_config)
+        loader = get_model_loader(load_config, self.model_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
             return False, message
@@ -889,6 +1010,11 @@ class ModelRunner:
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source.init_new(config, self.model)
             )
+            if weight_name_filter is not None:
+                iter = (
+                    (name, weight) for name, weight in iter if weight_name_filter(name)
+                )
+
             return iter

         def model_load_weights(model, iter):
@@ -1267,8 +1393,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
-        elif self.is_hybrid_gdn:
-            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
+        elif config := self.mambaish_config:
+            num_layers = len(config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1277,6 +1403,17 @@ class ModelRunner:
                 * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
+            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
+            if is_deepseek_nsa(self.model_config.hf_config):
+                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
+                indexer_size_per_token = (
+                    index_head_dim
+                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
+                )
+                element_size = torch._utils._element_size(
+                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
+                )
+                cell_size += indexer_size_per_token * num_layers * element_size
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
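Note: the NSA indexer overhead added above is per token and per layer. A small sanity-check sketch, where index_head_dim=128 and quant_block_size=128 are illustrative stand-ins for the values read from the HF config and NSATokenToKVPool:

    index_head_dim = 128    # illustrative
    quant_block_size = 128  # illustrative
    # index_k entries plus one 4-byte scale per quantization block,
    # mirroring indexer_size_per_token in the hunk above.
    indexer_size_per_token = index_head_dim + index_head_dim // quant_block_size * 4
    assert indexer_size_per_token == 132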
@@ -1288,22 +1425,77 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        if self.is_hybrid_gdn:
-            rest_memory -= (
-                self.server_args.max_mamba_cache_size
-                * self.model_config.hf_config.mamba_cache_per_req
-                / (1 << 30)
-            )
+        if self.mambaish_config is not None:
+            rest_memory = self.handle_max_mamba_cache(rest_memory)
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

+    def handle_max_mamba_cache(self, total_rest_memory):
+        config = self.mambaish_config
+        server_args = self.server_args
+        assert config is not None
+
+        speculativa_ratio = (
+            0
+            if server_args.speculative_num_draft_tokens is None
+            else server_args.speculative_num_draft_tokens
+        )
+        if (
+            server_args.disable_radix_cache
+            or config.mamba2_cache_params.mamba_cache_per_req == 0
+        ):
+            # with disable radix cache, sets the max_mamba_cache_size based on the max_running_requests
+            if server_args.max_mamba_cache_size is None:
+                if server_args.max_running_requests is not None:
+                    server_args.max_mamba_cache_size = server_args.max_running_requests
+                else:
+                    server_args.max_mamba_cache_size = 512
+        else:
+            # allocate the memory based on the ratio between mamba state memory vs. full kv cache memory
+            # solve the equations:
+            # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
+            # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
+            mamba_state_memory_raw = (
+                total_rest_memory
+                * server_args.mamba_full_memory_ratio
+                / (1 + server_args.mamba_full_memory_ratio)
+            )
+            # calculate the max_mamba_cache_size based on the given total mamba memory
+            server_args.max_mamba_cache_size = int(
+                (mamba_state_memory_raw * (1 << 30))
+                // config.mamba2_cache_params.mamba_cache_per_req
+                // (1 + speculativa_ratio)
+            )
+
+        if self.hybrid_gdn_config is not None:
+            server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
+                server_args.dp_size if server_args.enable_dp_attention else 1
+            )
+        mamba_state_memory = (
+            server_args.max_mamba_cache_size
+            * config.mamba2_cache_params.mamba_cache_per_req
+            * (1 + speculativa_ratio)
+            / (1 << 30)
+        )
+        return total_rest_memory - mamba_state_memory
+
     @property
-    def is_hybrid_gdn(self):
-        return self.model_config.hf_config.architectures[0] in [
-            "Qwen3NextForCausalLM",
-            "Qwen3NextForCausalLMMTP",
-            "FalconH1ForCausalLM",
-        ]
+    def hybrid_gdn_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, Qwen3NextConfig):
+            return config
+        return None
+
+    @property
+    def mamba2_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, FalconH1Config | NemotronHConfig):
+            return config
+        return None
+
+    @property
+    def mambaish_config(self):
+        return self.mamba2_config or self.hybrid_gdn_config

     def set_num_token_hybrid(self):
         if (
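Note: handle_max_mamba_cache splits the remaining memory by solving the two equations in its comment. A worked sketch of that split; the 40 GiB budget and 0.1 ratio are illustrative, not defaults:

    # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
    # 2. mamba_state_memory / full_kv_cache_memory == mamba_full_memory_ratio
    # => mamba_state_memory = total_rest_memory * ratio / (1 + ratio)
    total_rest_memory = 40.0          # GiB, illustrative
    ratio = 0.1                       # server_args.mamba_full_memory_ratio, illustrative
    mamba_state_memory = total_rest_memory * ratio / (1 + ratio)   # ~3.64 GiB
    full_kv_cache_memory = total_rest_memory - mamba_state_memory  # ~36.36 GiB
    assert abs(mamba_state_memory / full_kv_cache_memory - ratio) < 1e-9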
@@ -1387,6 +1579,27 @@ class ModelRunner:
             f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
         )

+    def can_run_piecewise_cuda_graph(self):
+        if self.server_args.disable_cuda_graph:
+            log_info_on_rank0(
+                logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
+            )
+            return False
+        if self.server_args.enable_torch_compile:
+            log_info_on_rank0(
+                logger,
+                "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
+            )
+            return False
+        if self.pp_size > 1:
+            # TODO(yuwei): support PP
+            log_info_on_rank0(
+                logger,
+                "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
+            )
+            return False
+        return True
+
     def init_memory_pool(
         self,
@@ -1417,6 +1630,8 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e4m3fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
+        elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
+            self.kv_cache_dtype = torch.bfloat16
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -1438,8 +1653,16 @@ class ModelRunner:
             ),
             4096,
         )
-        if self.is_hybrid_gdn:
-            max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
+
+        if self.mambaish_config is not None:
+            ratio = (
+                MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
+                if not self.server_args.disable_radix_cache
+                else 1
+            )
+            max_num_reqs = min(
+                max_num_reqs, self.server_args.max_mamba_cache_size // ratio
+            )

         if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             if self.is_draft_worker:
@@ -1506,39 +1729,43 @@ class ModelRunner:
             extra_max_context_len += self.server_args.speculative_num_draft_tokens

         if self.server_args.disaggregation_mode == "decode":
-            from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
+            from sglang.srt.disaggregation.decode import (
+                DecodeReqToTokenPool,
+                HybridMambaDecodeReqToTokenPool,
+            )

             # subscribe memory for pre-allocated requests
             # if max_num_reqs <= 32, we pre-allocate 2x requests
             pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
-            self.req_to_token_pool = DecodeReqToTokenPool(
-                size=max_num_reqs,
-                max_context_len=self.model_config.context_len
-                + extra_max_context_len,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-                pre_alloc_size=pre_alloc_size,
-            )
-        elif self.is_hybrid_gdn:
-            config = self.model_config.hf_config
-            (
-                conv_state_shape,
-                temporal_state_shape,
-                conv_dtype,
-                ssm_dtype,
-                mamba_layers,
-            ) = config.hybrid_gdn_params
+            if config := self.mambaish_config:
+                self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
+                    size=max_num_reqs,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    cache_params=config.mamba2_cache_params,
+                    speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+                    pre_alloc_size=pre_alloc_size,
+                )
+            else:
+                self.req_to_token_pool = DecodeReqToTokenPool(
+                    size=max_num_reqs,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    pre_alloc_size=pre_alloc_size,
+                )
+        elif config := self.mambaish_config:
             self.req_to_token_pool = HybridReqToTokenPool(
                 size=max_num_reqs,
+                mamba_size=self.server_args.max_mamba_cache_size,
                 max_context_len=self.model_config.context_len
                 + extra_max_context_len,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
-                conv_state_shape=conv_state_shape,
-                temporal_state_shape=temporal_state_shape,
-                conv_dtype=conv_dtype,
-                ssm_dtype=ssm_dtype,
-                mamba_layers=mamba_layers,
+                cache_params=config.mamba2_cache_params,
                 speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             )
         else:
@@ -1640,7 +1867,7 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
-        elif self.is_hybrid_gdn:
+        elif config := self.mambaish_config:
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
@@ -1651,12 +1878,11 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 # if draft worker, we only need 1 attention layer's kv pool
                 full_attention_layer_ids=(
-                    [0]
-                    if self.is_draft_worker
-                    else self.model_config.hf_config.full_attention_layer_ids
+                    [0] if self.is_draft_worker else config.full_attention_layer_ids
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
+                mamba_pool=self.req_to_token_pool.mamba_pool,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -1672,13 +1898,17 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_kv_cache_copy=(
+                    self.server_args.speculative_algorithm is not None
+                ),
             )

         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if _is_npu and (
-                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+                self.server_args.attention_backend == "ascend"
+                or self.hybrid_gdn_config is not None
             ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
@@ -1743,16 +1973,10 @@ class ModelRunner:

     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
         )
+
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
@@ -1781,12 +2005,10 @@ class ModelRunner:
                 self.server_args.attention_backend
             )

-        global_server_args_dict.update(
-            {
-                "decode_attention_backend": self.decode_attention_backend_str,
-                "prefill_attention_backend": self.prefill_attention_backend_str,
-            }
-        )
+        (
+            get_global_server_args().prefill_attention_backend,
+            get_global_server_args().decode_attention_backend,
+        ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend

     def _get_attention_backend_from_str(self, backend_str: str):
@@ -1924,6 +2146,11 @@ class ModelRunner:
             kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
         if not self.is_generation:
             kwargs["get_embedding"] = True
+
+        if self.piecewise_cuda_graph_runner is not None:
+            if self.piecewise_cuda_graph_runner.can_run(forward_batch):
+                return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
+
         return self.model.forward(
             forward_batch.input_ids,
             forward_batch.positions,
@@ -2057,15 +2284,11 @@ class ModelRunner:
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
-        # Apply logit bias
-        if sampling_info.sampling_info_done:
-            # Overlap mode: the function update_regex_vocab_mask was executed
-            # in process_batch_result of the last batch.
-            if sampling_info.grammars:
-                sampling_info.sampling_info_done.wait()
-        else:
-            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-            sampling_info.update_regex_vocab_mask()
+        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
+        # was executed after we processed last batch's results.
+
+        # Calculate logits bias and apply it to next_token_logits.
+        sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)

     def sample(
@@ -2164,6 +2387,23 @@ class ModelRunner:
         )
         ShardedStateLoader.save_model(self.model, path, pattern, max_size)

+    def update_weights_from_ipc(self, recv_req):
+        """Update weights from IPC for checkpoint-engine integration."""
+        try:
+            from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
+                SGLangCheckpointEngineWorkerExtensionImpl,
+            )
+
+            # Create a worker extension that integrates with SGLang's model
+            worker = SGLangCheckpointEngineWorkerExtensionImpl(self)
+            worker.update_weights_from_ipc(recv_req.zmq_handles)
+            return True, "IPC weight update completed successfully"
+        except ImportError as e:
+            return False, f"IPC weight update failed: ImportError {e}"
+        except Exception as e:
+            logger.error(f"IPC weight update failed: {e}")
+            return False, str(e)
+

 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
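Note: the new update_weights_from_ipc hook returns a (success, message) pair like the other weight-update paths. A hedged sketch of a caller; the SimpleNamespace request object and the ZMQ endpoint are illustrative, since the real request arrives through the tokenizer manager's IPC plumbing:

    from types import SimpleNamespace

    # recv_req only needs a `zmq_handles` attribute, matching what the
    # checkpoint-engine worker extension reads in the hunk above.
    recv_req = SimpleNamespace(zmq_handles={"device-0": "ipc:///tmp/ckpt_engine_rank0"})
    ok, message = model_runner.update_weights_from_ipc(recv_req)
    if not ok:
        raise RuntimeError(message)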