sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -24,11 +24,12 @@ import threading
  import time
  from collections import defaultdict
  from dataclasses import dataclass
- from typing import List, Optional, Tuple, Union
+ from typing import Callable, List, Optional, Tuple, Union

  import torch
  import torch.distributed as dist

+ from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
  from sglang.srt.configs.device_config import DeviceConfig
  from sglang.srt.configs.load_config import LoadConfig, LoadFormat
  from sglang.srt.configs.model_config import (
@@ -50,6 +51,7 @@ from sglang.srt.distributed import (
  set_symm_mem_all_reduce,
  )
  from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+ from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
  from sglang.srt.eplb.eplb_manager import EPLBManager
  from sglang.srt.eplb.expert_distribution import (
  ExpertDistributionRecorder,
@@ -63,6 +65,7 @@ from sglang.srt.eplb.expert_location import (
  set_global_expert_location_metadata,
  )
  from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+ from sglang.srt.layers import deep_gemm_wrapper
  from sglang.srt.layers.attention.attention_registry import (
  ATTENTION_BACKENDS,
  attn_backend_wrapper,
@@ -74,18 +77,11 @@ from sglang.srt.layers.dp_attention import (
  initialize_dp_attention,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
- from sglang.srt.layers.quantization import (
- deep_gemm_wrapper,
- monkey_patch_isinstance_for_vllm_base_layer,
- )
+ from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
  from sglang.srt.layers.sampler import Sampler
  from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
  from sglang.srt.lora.lora_manager import LoRAManager
  from sglang.srt.lora.lora_registry import LoRARef
- from sglang.srt.managers.schedule_batch import (
- GLOBAL_SERVER_ARGS_KEYS,
- global_server_args_dict,
- )
  from sglang.srt.mem_cache.allocator import (
  BaseTokenToKVPoolAllocator,
  PagedTokenToKVPoolAllocator,
@@ -109,6 +105,9 @@ from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
  from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
+ from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
+ PiecewiseCudaGraphRunner,
+ )
  from sglang.srt.model_loader import get_model
  from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
  from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
@@ -116,15 +115,13 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
  )
  from sglang.srt.model_loader.utils import set_default_torch_dtype
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.offloader import (
- create_offloader_from_server_args,
- get_offloader,
- set_offloader,
- )
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
- from sglang.srt.server_args import ServerArgs
+ from sglang.srt.server_args import (
+ ServerArgs,
+ get_global_server_args,
+ set_global_server_args_for_scheduler,
+ )
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
- from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.utils import (
  MultiprocessingSerializer,
  cpu_has_amx_support,
@@ -134,20 +131,21 @@ from sglang.srt.utils import (
  get_bool_env_var,
  get_cpu_ids_by_node,
  init_custom_process_group,
- is_fa3_default_architecture,
- is_flashinfer_available,
  is_hip,
- is_hopper_with_cuda_12_3,
- is_no_spec_infer_or_topk_one,
  is_npu,
- is_sm100_supported,
  log_info_on_rank0,
  monkey_patch_p2p_access_check,
- monkey_patch_vllm_gguf_config,
  set_cuda_arch,
  slow_rank_detector,
+ xpu_has_xmx_support,
+ )
+ from sglang.srt.utils.offloader import (
+ create_offloader_from_server_args,
+ get_offloader,
+ set_offloader,
  )
  from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+ from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.weight_sync.tensor_bucket import (
  FlattenedTensorBucket,
  FlattenedTensorMetadata,
@@ -166,6 +164,15 @@ MLA_ATTENTION_BACKENDS = [
  "nsa",
  ]

+ CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
+ "flashinfer",
+ "fa3",
+ "fa4",
+ "flashmla",
+ "cutlass_mla",
+ "trtllm_mla",
+ ]
+

  def add_mla_attention_backend(backend_name):
  if backend_name not in MLA_ATTENTION_BACKENDS:
@@ -173,9 +180,18 @@ def add_mla_attention_backend(backend_name):
  logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")


+ def add_chunked_prefix_cache_attention_backend(backend_name):
+ if backend_name not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS:
+ CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS.append(backend_name)
+ logger.info(
+ f"Added {backend_name} to CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS."
+ )
+
+
  _is_hip = is_hip()
  _is_npu = is_npu()
  _is_cpu_amx_available = cpu_has_amx_support()
+ _is_xpu_xmx_available = xpu_has_xmx_support()

  # Use a small KV cache pool size for tests in CI
  SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -183,8 +199,10 @@ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
  # Detect stragger ranks in model loading
  UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

- logger = logging.getLogger(__name__)
+ # the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
+ MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

+ logger = logging.getLogger(__name__)

  if _is_npu:
  import torch_npu
@@ -257,25 +275,21 @@ class ModelRunner:
  self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
  self.attention_chunk_size = model_config.attention_chunk_size
  self.forward_pass_id = 0
+ self.init_new_workspace = False

  # Apply the rank zero filter to logger
- if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
- logger.addFilter(RankZeroFilter(tp_rank == 0))
  if server_args.show_time_cost:
  enable_show_time_cost()

  # Model-specific adjustment
  self.model_specific_adjustment()

- # Global vars
- global_server_args_dict.update(
- {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
- | {
- # TODO it is indeed not a "server args"
- "use_mla_backend": self.use_mla_backend,
- "speculative_algorithm": self.spec_algorithm,
- }
- )
+ # Set the global server_args in the scheduler process
+ set_global_server_args_for_scheduler(server_args)
+ global_server_args = get_global_server_args()
+
+ # FIXME: hacky set `use_mla_backend`
+ global_server_args.use_mla_backend = self.use_mla_backend

  # Init OpenMP threads binding for CPU
  if self.device == "cpu":
@@ -306,6 +320,26 @@ class ModelRunner:
  self._model_update_group = {}
  self._weights_send_group = {}

+ if (
+ self.server_args.enable_piecewise_cuda_graph
+ and self.can_run_piecewise_cuda_graph()
+ ):
+ self.attention_layers = []
+ for layer in self.model.model.layers:
+ if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
+ self.attention_layers.append(layer.self_attn.attn)
+ if len(self.attention_layers) < self.model_config.num_hidden_layers:
+ # TODO(yuwei): support Non-Standard GQA
+ log_info_on_rank0(
+ logger,
+ "Disable piecewise CUDA graph because some layers do not apply Standard GQA",
+ )
+ self.piecewise_cuda_graph_runner = None
+ else:
+ self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
+ else:
+ self.piecewise_cuda_graph_runner = None
+
  def initialize(self, min_per_gpu_memory: float):
  server_args = self.server_args

@@ -340,6 +374,11 @@ class ModelRunner:
  )
  self.expert_location_updater = ExpertLocationUpdater()

+ (
+ ElasticEPStateManager.init(self.server_args)
+ if self.server_args.elastic_ep_backend
+ else None
+ )
  # Load the model
  self.sampler = Sampler()
  self.load_model()
@@ -354,24 +393,10 @@ class ModelRunner:
  if architectures and not any("Llama4" in arch for arch in architectures):
  self.is_hybrid = self.model_config.is_hybrid = True

- if self.is_hybrid_gdn:
- logger.warning("Hybrid GDN model detected, disable radix cache")
+ if config := self.mamba2_config:
+ class_name = config.__class__.__name__
+ logger.warning(f"{class_name} model detected, disable radix cache")
  self.server_args.disable_radix_cache = True
- if self.server_args.max_mamba_cache_size is None:
- if self.server_args.max_running_requests is not None:
- self.server_args.max_mamba_cache_size = (
- self.server_args.max_running_requests
- )
- else:
- self.server_args.max_mamba_cache_size = 512
- self.server_args.max_mamba_cache_size = (
- self.server_args.max_mamba_cache_size
- // (
- self.server_args.dp_size
- if self.server_args.enable_dp_attention
- else 1
- )
- )

  # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
  # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
@@ -402,7 +427,7 @@ class ModelRunner:
  # In layered loading, torchao may have been applied
  if not torchao_applied:
  apply_torchao_config_to_model(
- self.model, global_server_args_dict["torchao_config"]
+ self.model, get_global_server_args().torchao_config
  )

  # Apply torch TP if the model supports it
@@ -472,110 +497,6 @@ class ModelRunner:
  def model_specific_adjustment(self):
  server_args = self.server_args

- if (
- server_args.attention_backend == "intel_amx"
- and server_args.device == "cpu"
- and not _is_cpu_amx_available
- ):
- logger.info(
- "The current platform does not support Intel AMX, will fallback to torch_native backend."
- )
- server_args.attention_backend = "torch_native"
-
- if server_args.prefill_attention_backend is not None and (
- server_args.prefill_attention_backend
- == server_args.decode_attention_backend
- ): # override the default attention backend
- server_args.attention_backend = server_args.prefill_attention_backend
-
- if (
- getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
- is not None
- ):
- if server_args.attention_backend is None:
- server_args.attention_backend = "dual_chunk_flash_attn"
- logger.info("Dual chunk attention is turned on by default.")
- elif server_args.attention_backend != "dual_chunk_flash_attn":
- raise ValueError(
- "Dual chunk attention is enabled, but attention backend is set to "
- f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
- )
-
- if server_args.attention_backend is None:
- """
- Auto select the fastest attention backend.
-
- 1. Models with MHA Architecture (e.g: Llama, QWen)
- 1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
- 1.2 In other cases, we will use flashinfer if available, otherwise use triton.
- 2. Models with MLA Architecture and using FA3
- 2.1 We will use FA3 backend on hopper.
- 2.2 We will use Flashinfer backend on blackwell.
- 2.3 Otherwise, we will use triton backend.
- """
-
- if not self.use_mla_backend:
- # MHA architecture
- if (
- is_hopper_with_cuda_12_3()
- and is_no_spec_infer_or_topk_one(server_args)
- and is_fa3_default_architecture(self.model_config.hf_config)
- ):
- server_args.attention_backend = "fa3"
- elif _is_hip:
- server_args.attention_backend = "aiter"
- elif _is_npu:
- server_args.attention_backend = "ascend"
- else:
- server_args.attention_backend = (
- "flashinfer" if is_flashinfer_available() else "triton"
- )
- else:
- # MLA architecture
- if is_hopper_with_cuda_12_3():
- server_args.attention_backend = "fa3"
- elif is_sm100_supported():
- server_args.attention_backend = "flashinfer"
- elif _is_hip:
- head_num = self.model_config.get_num_kv_heads(self.tp_size)
- # TODO current aiter only support head number 16 or 128 head number
- if head_num == 128 or head_num == 16:
- server_args.attention_backend = "aiter"
- else:
- server_args.attention_backend = "triton"
- elif _is_npu:
- server_args.attention_backend = "ascend"
- else:
- server_args.attention_backend = "triton"
- logger.info(
- f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default."
- )
- elif self.use_mla_backend:
- if server_args.device != "cpu":
- if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
- logger.info(
- f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
- )
- else:
- raise ValueError(
- f"Invalid attention backend for MLA: {server_args.attention_backend}"
- )
- else:
- if server_args.attention_backend != "intel_amx":
- raise ValueError(
- "MLA optimization not supported on CPU except for intel_amx backend."
- )
-
- if (
- server_args.attention_backend == "fa3"
- and server_args.kv_cache_dtype == "fp8_e5m2"
- ):
- logger.warning(
- "FlashAttention3 only supports fp8_e4m3 if using FP8; "
- "Setting attention backend to triton."
- )
- server_args.attention_backend = "triton"
-
  if server_args.enable_double_sparsity:
  logger.info(
  "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -591,36 +512,44 @@ class ModelRunner:
  f"{self.model_config.hf_config.model_type}"
  )

- if not self.use_mla_backend:
+ if (
+ not self.use_mla_backend
+ or server_args.attention_backend
+ not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS
+ ):
  server_args.disable_chunked_prefix_cache = True

  if not server_args.disable_chunked_prefix_cache:
- logger.info("Chunked prefix cache is turned on.")
+ log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

- if server_args.attention_backend == "aiter":
- if self.model_config.context_len > 8192:
- self.mem_fraction_static *= 0.85
+ if self.model_config.hf_config.model_type == "qwen3_vl_moe":
+ if (
+ quantization_config := getattr(
+ self.model_config.hf_config, "quantization_config", None
+ )
+ ) is not None and "weight_block_size" in quantization_config:
+ weight_block_size_n = quantization_config["weight_block_size"][0]

- if (
- server_args.enable_hierarchical_cache
- and server_args.hicache_io_backend == "kernel"
- ):
- # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
- if server_args.decode_attention_backend is None:
- if not self.use_mla_backend:
- server_args.decode_attention_backend = (
- "flashinfer" if is_flashinfer_available() else "triton"
- )
- else:
- server_args.decode_attention_backend = (
- "flashinfer" if is_sm100_supported() else "triton"
+ if self.tp_size % self.moe_ep_size != 0:
+ raise ValueError(
+ f"tp_size {self.tp_size} must be divisible by moe_ep_size {self.moe_ep_size}"
  )
- elif server_args.decode_attention_backend == "fa3":
- server_args.hicache_io_backend = "direct"
- logger.warning(
- "FlashAttention3 decode backend is not compatible with hierarchical cache. "
- "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+ moe_tp_size = self.tp_size // self.moe_ep_size
+
+ moe_intermediate_size = (
+ self.model_config.hf_text_config.moe_intermediate_size
  )
+ if moe_intermediate_size % moe_tp_size != 0:
+ raise ValueError(
+ f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})."
+ )
+
+ if (moe_intermediate_size // moe_tp_size) % weight_block_size_n != 0:
+ raise ValueError(
+ f"For qwen3-vl-fp8 models, please make sure ({moe_intermediate_size=} / {moe_tp_size=}) % {weight_block_size_n=} == 0 "
+ f"where moe_tp_size is equal to tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size}). "
+ f"You can fix this by setting arguments `--tp-size` and `--ep-size` correctly."
+ )

  def init_torch_distributed(self):
  logger.info("Init torch distributed begin.")
@@ -634,7 +563,18 @@ class ModelRunner:
  raise

  if self.device == "cuda":
- backend = "nccl"
+ if self.server_args.elastic_ep_backend == "mooncake":
+ backend = "mooncake"
+ if self.server_args.mooncake_ib_device:
+ mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
+ try:
+ from mooncake import ep as mooncake_ep
+
+ mooncake_ep.set_device_filter(mooncake_ib_device)
+ except:
+ pass # A warning will be raised in `init_distributed_environment`
+ else:
+ backend = "nccl"
  elif self.device == "xpu":
  backend = "xccl"
  elif self.device == "hpu":
@@ -689,6 +629,7 @@ class ModelRunner:
  pipeline_model_parallel_size=self.pp_size,
  expert_model_parallel_size=self.moe_ep_size,
  duplicate_tp_group=self.server_args.enable_pdmux,
+ torch_compile=self.server_args.enable_piecewise_cuda_graph,
  )
  initialize_dp_attention(
  server_args=self.server_args,
@@ -747,6 +688,16 @@ class ModelRunner:
  set_cuda_arch()

  # Prepare the model config
+ from sglang.srt.configs.modelopt_config import ModelOptConfig
+
+ modelopt_config = ModelOptConfig(
+ quant=self.server_args.modelopt_quant,
+ checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
+ checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
+ export_path=self.server_args.modelopt_export_path,
+ quantize_and_serve=self.server_args.quantize_and_serve,
+ )
+
  self.load_config = LoadConfig(
  load_format=self.server_args.load_format,
  download_dir=self.server_args.download_dir,
@@ -755,13 +706,12 @@ class ModelRunner:
  remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
  remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
  remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+ modelopt_config=modelopt_config,
  )
  if self.device == "cpu":
  self.model_config = adjust_config_with_unaligned_cpu_tp(
  self.model_config, self.load_config, self.tp_size
  )
- if self.server_args.load_format == "gguf":
- monkey_patch_vllm_gguf_config()

  if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
  if self.tp_rank == 0:
@@ -841,33 +791,56 @@ class ModelRunner:
  f"mem usage={self.weight_load_mem_usage:.2f} GB."
  )

- # Handle the case where some ranks do not finish loading.
- try:
- dist.monitored_barrier(
- group=get_tp_group().cpu_group,
- timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
- wait_all_ranks=True,
- )
- except RuntimeError:
- raise ValueError(
- f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
- ) from None
+ if self.server_args.elastic_ep_backend == "mooncake":
+ # Mooncake does not support `monitored_barrier`
+ dist.barrier(group=get_tp_group().cpu_group)
+ else:
+ # Handle the case where some ranks do not finish loading.
+ try:
+ dist.monitored_barrier(
+ group=get_tp_group().cpu_group,
+ timeout=datetime.timedelta(
+ seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S
+ ),
+ wait_all_ranks=True,
+ )
+ except RuntimeError:
+ raise ValueError(
+ f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+ ) from None

  def update_expert_location(
  self,
  new_expert_location_metadata: ExpertLocationMetadata,
  update_layer_ids: List[int],
  ):
- self.expert_location_updater.update(
- self.model.routed_experts_weights_of_layer,
- new_expert_location_metadata,
- update_layer_ids=update_layer_ids,
- nnodes=self.server_args.nnodes,
- rank=self.tp_rank,
- )
+ if ElasticEPStateManager.instance() is not None:
+ # TODO: refactor the weights update when elastic ep
+ old_expert_location_metadata = get_global_expert_location_metadata()
+ assert old_expert_location_metadata is not None
+ old_expert_location_metadata.update(
+ new_expert_location_metadata,
+ update_layer_ids=update_layer_ids,
+ )
+ self.update_weights_from_disk(
+ self.server_args.model_path,
+ self.server_args.load_format,
+ lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
+ )
+ else:
+ self.expert_location_updater.update(
+ self.model.routed_experts_weights_of_layer,
+ new_expert_location_metadata,
+ update_layer_ids=update_layer_ids,
+ nnodes=self.server_args.nnodes,
+ rank=self.tp_rank,
+ )

  def update_weights_from_disk(
- self, model_path: str, load_format: str
+ self,
+ model_path: str,
+ load_format: str,
+ weight_name_filter: Optional[Callable[[str], bool]] = None,
  ) -> tuple[bool, str]:
  """Update engine weights in-place from the disk."""
  logger.info(
@@ -880,7 +853,7 @@ class ModelRunner:
  load_config = LoadConfig(load_format=load_format)

  # Only support DefaultModelLoader for now
- loader = get_model_loader(load_config)
+ loader = get_model_loader(load_config, self.model_config)
  if not isinstance(loader, DefaultModelLoader):
  message = f"Failed to get model loader: {loader}."
  return False, message
@@ -889,6 +862,11 @@ class ModelRunner:
  iter = loader._get_weights_iterator(
  DefaultModelLoader.Source.init_new(config, self.model)
  )
+ if weight_name_filter is not None:
+ iter = (
+ (name, weight) for name, weight in iter if weight_name_filter(name)
+ )
+
  return iter

  def model_load_weights(model, iter):
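The `weight_name_filter` hook added above wraps the loader's (name, tensor) iterator in a lazy generator, so non-matching weights are never handed to the model; the elastic-EP path earlier in this diff passes a predicate that keeps routed-expert weights and skips shared experts. A tiny standalone sketch of the same pattern (illustrative only, with made-up weight names, not code from the package):

    weights = [
        ("model.layers.0.mlp.experts.0.w1", "t0"),
        ("model.layers.0.mlp.shared_experts.w1", "t1"),
        ("model.layers.0.self_attn.q_proj", "t2"),
    ]

    def weight_name_filter(name):
        # Keep routed experts, drop shared experts and everything else.
        return "mlp.experts" in name and "mlp.shared_experts" not in name

    # Lazy generator: filtered entries are skipped without being materialized.
    it = ((name, w) for name, w in weights if weight_name_filter(name))
    print(list(it))  # [('model.layers.0.mlp.experts.0.w1', 't0')]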
@@ -1267,8 +1245,8 @@ class ModelRunner:
  "num_nextn_predict_layers",
  self.num_effective_layers,
  )
- elif self.is_hybrid_gdn:
- num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
+ elif config := self.mambaish_config:
+ num_layers = len(config.full_attention_layer_ids)
  else:
  num_layers = self.num_effective_layers
  if self.use_mla_backend:
@@ -1277,6 +1255,17 @@ class ModelRunner:
  * num_layers
  * torch._utils._element_size(self.kv_cache_dtype)
  )
+ # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
+ if is_deepseek_nsa(self.model_config.hf_config):
+ index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
+ indexer_size_per_token = (
+ index_head_dim
+ + index_head_dim // NSATokenToKVPool.quant_block_size * 4
+ )
+ element_size = torch._utils._element_size(
+ NSATokenToKVPool.index_k_with_scale_buffer_dtype
+ )
+ cell_size += indexer_size_per_token * num_layers * element_size
  else:
  cell_size = (
  self.model_config.get_num_kv_heads(get_attention_tp_size())
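The NSA branch above grows the per-token KV cell by the indexer cache: `index_head_dim` entries plus what reads as one 4-byte scale per quantization block (the `index_head_dim // quant_block_size * 4` term), summed over the layers. A back-of-the-envelope sketch of that arithmetic (standalone; the dimensions below are assumed example values, not read from any model config):

    index_head_dim = 128       # assumed indexer key width per token
    quant_block_size = 128     # assumed quantization block size
    num_layers = 61            # assumed number of full-attention layers
    element_size = 1           # assumed bytes per element of the index_k buffer dtype

    # index_head_dim entries plus 4 bytes of scale per quant block, per layer
    indexer_size_per_token = index_head_dim + index_head_dim // quant_block_size * 4
    print(indexer_size_per_token * num_layers * element_size)  # 8052 extra bytes per token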
@@ -1288,22 +1277,77 @@ class ModelRunner:
  rest_memory = available_gpu_memory - total_gpu_memory * (
  1 - self.mem_fraction_static
  )
- if self.is_hybrid_gdn:
- rest_memory -= (
- self.server_args.max_mamba_cache_size
- * self.model_config.hf_config.mamba_cache_per_req
- / (1 << 30)
- )
+ if self.mambaish_config is not None:
+ rest_memory = self.handle_max_mamba_cache(rest_memory)
  max_num_token = int(rest_memory * (1 << 30) // cell_size)
  return max_num_token

+ def handle_max_mamba_cache(self, total_rest_memory):
+ config = self.mambaish_config
+ server_args = self.server_args
+ assert config is not None
+
+ speculativa_ratio = (
+ 0
+ if server_args.speculative_num_draft_tokens is None
+ else server_args.speculative_num_draft_tokens
+ )
+ if (
+ server_args.disable_radix_cache
+ or config.mamba2_cache_params.mamba_cache_per_req == 0
+ ):
+ # with disable radix cache, sets the max_mamba_cache_size based on the max_running_requests
+ if server_args.max_mamba_cache_size is None:
+ if server_args.max_running_requests is not None:
+ server_args.max_mamba_cache_size = server_args.max_running_requests
+ else:
+ server_args.max_mamba_cache_size = 512
+ else:
+ # allocate the memory based on the ratio between mamba state memory vs. full kv cache memory
+ # solve the equations:
+ # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
+ # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
+ mamba_state_memory_raw = (
+ total_rest_memory
+ * server_args.mamba_full_memory_ratio
+ / (1 + server_args.mamba_full_memory_ratio)
+ )
+ # calculate the max_mamba_cache_size based on the given total mamba memory
+ server_args.max_mamba_cache_size = int(
+ (mamba_state_memory_raw * (1 << 30))
+ // config.mamba2_cache_params.mamba_cache_per_req
+ // (1 + speculativa_ratio)
+ )
+
+ if self.hybrid_gdn_config is not None:
+ server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
+ server_args.dp_size if server_args.enable_dp_attention else 1
+ )
+ mamba_state_memory = (
+ server_args.max_mamba_cache_size
+ * config.mamba2_cache_params.mamba_cache_per_req
+ * (1 + speculativa_ratio)
+ / (1 << 30)
+ )
+ return total_rest_memory - mamba_state_memory
+
  @property
- def is_hybrid_gdn(self):
- return self.model_config.hf_config.architectures[0] in [
- "Qwen3NextForCausalLM",
- "Qwen3NextForCausalLMMTP",
- "FalconH1ForCausalLM",
- ]
+ def hybrid_gdn_config(self):
+ config = self.model_config.hf_config
+ if isinstance(config, Qwen3NextConfig):
+ return config
+ return None
+
+ @property
+ def mamba2_config(self):
+ config = self.model_config.hf_config
+ if isinstance(config, FalconH1Config | NemotronHConfig):
+ return config
+ return None
+
+ @property
+ def mambaish_config(self):
+ return self.mamba2_config or self.hybrid_gdn_config

  def set_num_token_hybrid(self):
  if (
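`handle_max_mamba_cache` above splits the leftover memory using the two constraints spelled out in its comments: mamba_state_memory + full_kv_cache_memory == total_rest_memory and mamba_state_memory / full_kv_cache_memory == mamba_full_memory_ratio, so mamba_state_memory = total_rest_memory * ratio / (1 + ratio), which is then rounded down to a whole number of per-request mamba states. A minimal standalone sketch of that split (the byte counts and ratio are made-up example values, not defaults from the package):

    def split_rest_memory(total_rest_gb, ratio, mamba_cache_per_req_bytes, spec_ratio=0):
        # mamba + full_kv == total and mamba / full_kv == ratio  =>  mamba = total * ratio / (1 + ratio)
        mamba_gb = total_rest_gb * ratio / (1 + ratio)
        max_mamba_cache_size = int(
            (mamba_gb * (1 << 30)) // mamba_cache_per_req_bytes // (1 + spec_ratio)
        )
        # memory actually reserved for the rounded-down request count
        reserved_gb = max_mamba_cache_size * mamba_cache_per_req_bytes * (1 + spec_ratio) / (1 << 30)
        return max_mamba_cache_size, total_rest_gb - reserved_gb

    # e.g. 40 GB left over, ratio 0.2, 5 MiB of mamba state per request
    # -> 1365 mamba slots and roughly 33.3 GB kept for the full KV cache
    print(split_rest_memory(40.0, 0.2, 5 * (1 << 20)))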
@@ -1387,6 +1431,27 @@ class ModelRunner:
  f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
  )

+ def can_run_piecewise_cuda_graph(self):
+ if self.server_args.disable_cuda_graph:
+ log_info_on_rank0(
+ logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
+ )
+ return False
+ if self.server_args.enable_torch_compile:
+ log_info_on_rank0(
+ logger,
+ "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
+ )
+ return False
+ if self.pp_size > 1:
+ # TODO(yuwei): support PP
+ log_info_on_rank0(
+ logger,
+ "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
+ )
+ return False
+ return True
+
  def init_memory_pool(
  self,
  total_gpu_memory: int,
@@ -1417,6 +1482,8 @@ class ModelRunner:
  self.kv_cache_dtype = torch.float8_e4m3fnuz
  else:
  self.kv_cache_dtype = torch.float8_e4m3fn
+ elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
+ self.kv_cache_dtype = torch.bfloat16
  else:
  raise ValueError(
  f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -1438,8 +1505,16 @@ class ModelRunner:
  ),
  4096,
  )
- if self.is_hybrid_gdn:
- max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
+
+ if self.mambaish_config is not None:
+ ratio = (
+ MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
+ if not self.server_args.disable_radix_cache
+ else 1
+ )
+ max_num_reqs = min(
+ max_num_reqs, self.server_args.max_mamba_cache_size // ratio
+ )

  if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
  if self.is_draft_worker:
@@ -1506,39 +1581,43 @@ class ModelRunner:
  extra_max_context_len += self.server_args.speculative_num_draft_tokens

  if self.server_args.disaggregation_mode == "decode":
- from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
+ from sglang.srt.disaggregation.decode import (
+ DecodeReqToTokenPool,
+ HybridMambaDecodeReqToTokenPool,
+ )

  # subscribe memory for pre-allocated requests
  # if max_num_reqs <= 32, we pre-allocate 2x requests
  pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
- self.req_to_token_pool = DecodeReqToTokenPool(
- size=max_num_reqs,
- max_context_len=self.model_config.context_len
- + extra_max_context_len,
- device=self.device,
- enable_memory_saver=self.server_args.enable_memory_saver,
- pre_alloc_size=pre_alloc_size,
- )
- elif self.is_hybrid_gdn:
- config = self.model_config.hf_config
- (
- conv_state_shape,
- temporal_state_shape,
- conv_dtype,
- ssm_dtype,
- mamba_layers,
- ) = config.hybrid_gdn_params
+ if config := self.mambaish_config:
+ self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
+ size=max_num_reqs,
+ max_context_len=self.model_config.context_len
+ + extra_max_context_len,
+ device=self.device,
+ enable_memory_saver=self.server_args.enable_memory_saver,
+ cache_params=config.mamba2_cache_params,
+ speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+ pre_alloc_size=pre_alloc_size,
+ )
+ else:
+ self.req_to_token_pool = DecodeReqToTokenPool(
+ size=max_num_reqs,
+ max_context_len=self.model_config.context_len
+ + extra_max_context_len,
+ device=self.device,
+ enable_memory_saver=self.server_args.enable_memory_saver,
+ pre_alloc_size=pre_alloc_size,
+ )
+ elif config := self.mambaish_config:
  self.req_to_token_pool = HybridReqToTokenPool(
  size=max_num_reqs,
+ mamba_size=self.server_args.max_mamba_cache_size,
  max_context_len=self.model_config.context_len
  + extra_max_context_len,
  device=self.device,
  enable_memory_saver=self.server_args.enable_memory_saver,
- conv_state_shape=conv_state_shape,
- temporal_state_shape=temporal_state_shape,
- conv_dtype=conv_dtype,
- ssm_dtype=ssm_dtype,
- mamba_layers=mamba_layers,
+ cache_params=config.mamba2_cache_params,
  speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
  )
  else:
@@ -1640,7 +1719,7 @@ class ModelRunner:
  enable_kvcache_transpose=False,
  device=self.device,
  )
- elif self.is_hybrid_gdn:
+ elif config := self.mambaish_config:
  self.token_to_kv_pool = HybridLinearKVPool(
  page_size=self.page_size,
  size=self.max_total_num_tokens,
@@ -1651,12 +1730,11 @@ class ModelRunner:
  head_dim=self.model_config.head_dim,
  # if draft worker, we only need 1 attention layer's kv pool
  full_attention_layer_ids=(
- [0]
- if self.is_draft_worker
- else self.model_config.hf_config.full_attention_layer_ids
+ [0] if self.is_draft_worker else config.full_attention_layer_ids
  ),
  enable_kvcache_transpose=False,
  device=self.device,
+ mamba_pool=self.req_to_token_pool.mamba_pool,
  )
  else:
  self.token_to_kv_pool = MHATokenToKVPool(
@@ -1672,13 +1750,17 @@ class ModelRunner:
  enable_memory_saver=self.server_args.enable_memory_saver,
  start_layer=self.start_layer,
  end_layer=self.end_layer,
+ enable_kv_cache_copy=(
+ self.server_args.speculative_algorithm is not None
+ ),
  )

  # Initialize token_to_kv_pool_allocator
  need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
  if self.token_to_kv_pool_allocator is None:
  if _is_npu and (
- self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+ self.server_args.attention_backend == "ascend"
+ or self.hybrid_gdn_config is not None
  ):
  self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
  self.max_total_num_tokens,
@@ -1743,16 +1825,10 @@ class ModelRunner:

  def _get_attention_backend(self):
  """Init attention kernel backend."""
- self.decode_attention_backend_str = (
- self.server_args.decode_attention_backend
- if self.server_args.decode_attention_backend
- else self.server_args.attention_backend
- )
- self.prefill_attention_backend_str = (
- self.server_args.prefill_attention_backend
- if self.server_args.prefill_attention_backend
- else self.server_args.attention_backend
+ self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+ self.server_args.get_attention_backends()
  )
+
  if self.decode_attention_backend_str != self.prefill_attention_backend_str:
  from sglang.srt.layers.attention.hybrid_attn_backend import (
  HybridAttnBackend,
@@ -1781,12 +1857,10 @@ class ModelRunner:
  self.server_args.attention_backend
  )

- global_server_args_dict.update(
- {
- "decode_attention_backend": self.decode_attention_backend_str,
- "prefill_attention_backend": self.prefill_attention_backend_str,
- }
- )
+ (
+ get_global_server_args().prefill_attention_backend,
+ get_global_server_args().decode_attention_backend,
+ ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
  return attn_backend

  def _get_attention_backend_from_str(self, backend_str: str):
@@ -1924,6 +1998,11 @@ class ModelRunner:
  kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
  if not self.is_generation:
  kwargs["get_embedding"] = True
+
+ if self.piecewise_cuda_graph_runner is not None:
+ if self.piecewise_cuda_graph_runner.can_run(forward_batch):
+ return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
+
  return self.model.forward(
  forward_batch.input_ids,
  forward_batch.positions,
@@ -2057,15 +2136,11 @@ class ModelRunner:
  def _preprocess_logits(
  self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
  ):
- # Apply logit bias
- if sampling_info.sampling_info_done:
- # Overlap mode: the function update_regex_vocab_mask was executed
- # in process_batch_result of the last batch.
- if sampling_info.grammars:
- sampling_info.sampling_info_done.wait()
- else:
- # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
- sampling_info.update_regex_vocab_mask()
+ # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
+ # was executed after we processed last batch's results.
+
+ # Calculate logits bias and apply it to next_token_logits.
+ sampling_info.update_regex_vocab_mask()
  sampling_info.apply_logits_bias(logits_output.next_token_logits)

  def sample(
@@ -2164,6 +2239,23 @@ class ModelRunner:
  )
  ShardedStateLoader.save_model(self.model, path, pattern, max_size)

+ def update_weights_from_ipc(self, recv_req):
+ """Update weights from IPC for checkpoint-engine integration."""
+ try:
+ from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
+ SGLangCheckpointEngineWorkerExtensionImpl,
+ )
+
+ # Create a worker extension that integrates with SGLang's model
+ worker = SGLangCheckpointEngineWorkerExtensionImpl(self)
+ worker.update_weights_from_ipc(recv_req.zmq_handles)
+ return True, "IPC weight update completed successfully"
+ except ImportError as e:
+ return False, f"IPC weight update failed: ImportError {e}"
+ except Exception as e:
+ logger.error(f"IPC weight update failed: {e}")
+ return False, str(e)
+

  def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
  params_dict = dict(model.named_parameters())