sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -115,7 +115,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
115
115
  def __init__(
116
116
  self,
117
117
  tokenizer,
118
- whitespace_pattern: bool,
118
+ whitespace_pattern: str | None,
119
119
  ):
120
120
  super().__init__()
121
121
 
@@ -17,7 +17,11 @@ from typing import List, Optional, Tuple
17
17
 
18
18
  import torch
19
19
 
20
- from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
20
+ from .base_grammar_backend import (
21
+ INVALID_GRAMMAR_OBJ,
22
+ BaseGrammarBackend,
23
+ BaseGrammarObject,
24
+ )
21
25
 
22
26
 
23
27
  class ReasonerGrammarObject(BaseGrammarObject):
@@ -81,10 +85,9 @@ class ReasonerGrammarBackend(BaseGrammarBackend):
81
85
  self.grammar_backend = grammar_backend
82
86
  self.think_end_id = think_end_id
83
87
 
84
- def _init_value_dispatch(
85
- self, key: Tuple[str, str]
86
- ) -> Optional[ReasonerGrammarObject]:
88
+ def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
87
89
  ret = self.grammar_backend._init_value_dispatch(key)
88
- if ret is None:
89
- return None
90
+ # avoid wrapping invalid grammar, so that the scheduler can detect it
91
+ if ret is None or ret is INVALID_GRAMMAR_OBJ:
92
+ return ret
90
93
  return ReasonerGrammarObject(ret, self.think_end_id)
@@ -0,0 +1,12 @@
1
+ from typing import Dict
2
+
3
+
4
+ def is_legacy_structural_tag(obj: Dict) -> bool:
5
+ # test whether an object is a legacy structural tag
6
+ # see `StructuralTagResponseFormat` at `sglang.srt.entrypoints.openai.protocol`
7
+ if obj.get("structures", None) is not None:
8
+ assert obj.get("triggers", None) is not None
9
+ return True
10
+ else:
11
+ assert obj.get("format", None) is not None
12
+ return False
@@ -34,6 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
34
34
  BaseGrammarObject,
35
35
  GrammarStats,
36
36
  )
37
+ from sglang.srt.constrained.utils import is_legacy_structural_tag
37
38
  from sglang.srt.utils import is_hip
38
39
 
39
40
  _is_hip = is_hip()
@@ -167,6 +168,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
167
168
  tokenizer,
168
169
  vocab_size: int,
169
170
  model_eos_token_ids: Optional[List[int]] = None,
171
+ any_whitespace: bool = True,
170
172
  ):
171
173
  super().__init__()
172
174
 
@@ -188,6 +190,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
188
190
  self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
189
191
  self.vocab_size = vocab_size
190
192
  self.override_stop_tokens = override_stop_tokens
193
+ self.any_whitespace = any_whitespace
191
194
 
192
195
  def _from_context(
193
196
  self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
@@ -212,7 +215,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
212
215
  # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
213
216
  ctx = self.grammar_compiler.compile_builtin_json_grammar()
214
217
  else:
215
- ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
218
+ ctx = self.grammar_compiler.compile_json_schema(
219
+ schema=key_string, any_whitespace=self.any_whitespace
220
+ )
216
221
 
217
222
  except (RuntimeError, json.decoder.JSONDecodeError) as e:
218
223
  logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
@@ -237,18 +242,22 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
237
242
 
238
243
  def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
239
244
  try:
245
+ # TODO(dark): it's REALLY stupid to construct object from string and decode it again
240
246
  structural_tag = json.loads(key_string)
241
- tags = [
242
- StructuralTagItem(
243
- begin=structure["begin"],
244
- schema=json.dumps(structure["schema"]),
245
- end=structure["end"],
247
+ if is_legacy_structural_tag(structural_tag):
248
+ tags = [
249
+ StructuralTagItem(
250
+ begin=structure["begin"],
251
+ schema=json.dumps(structure["schema"]),
252
+ end=structure["end"],
253
+ )
254
+ for structure in structural_tag["structures"]
255
+ ]
256
+ ctx = self.grammar_compiler.compile_structural_tag(
257
+ tags, structural_tag["triggers"]
246
258
  )
247
- for structure in structural_tag["structures"]
248
- ]
249
- ctx = self.grammar_compiler.compile_structural_tag(
250
- tags, structural_tag["triggers"]
251
- )
259
+ else:
260
+ ctx = self.grammar_compiler.compile_structural_tag(key_string)
252
261
  except (RuntimeError, json.decoder.JSONDecodeError) as e:
253
262
  logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
254
263
  return INVALID_GRAMMAR_OBJ
@@ -1,6 +1,6 @@
1
1
  import logging
2
2
  import os
3
- from typing import List, Optional
3
+ from typing import List
4
4
 
5
5
  import torch
6
6
 
@@ -20,6 +20,10 @@ class KVArgs:
20
20
  aux_data_ptrs: List[int]
21
21
  aux_data_lens: List[int]
22
22
  aux_item_lens: List[int]
23
+ state_data_ptrs: List[int]
24
+ state_data_lens: List[int]
25
+ state_item_lens: List[int]
26
+ state_type: str # "none", "mamba", "swa"
23
27
  ib_device: str
24
28
  ib_traffic_class: str
25
29
  gpu_id: int
@@ -76,9 +80,13 @@ class BaseKVSender(ABC):
76
80
  ...
77
81
 
78
82
  @abstractmethod
79
- def send(self, kv_indices: npt.NDArray[np.int32]):
83
+ def send(
84
+ self,
85
+ kv_indices: npt.NDArray[np.int32],
86
+ state_indices: Optional[List[int]] = None,
87
+ ):
80
88
  """
81
- Send the kv cache at the given kv indices to the decoder server
89
+ Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server
82
90
  """
83
91
  ...
84
92
 
@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC):
108
116
  ): ...
109
117
 
110
118
  @abstractmethod
111
- def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
119
+ def init(
120
+ self,
121
+ kv_indices: npt.NDArray[np.int32],
122
+ aux_index: Optional[int] = None,
123
+ state_indices: Optional[List[int]] = None,
124
+ ):
112
125
  """
113
- Notify the prefill server about the kv indices and aux index
126
+ Notify the prefill server about the kv indices, aux index, and state_indices.
114
127
  """
115
128
  ...
116
129
 
@@ -77,8 +77,8 @@ class CommonKVManager(BaseKVManager):
77
77
 
78
78
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
79
79
  self._register_to_bootstrap()
80
- self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
81
- self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
80
+ self.transfer_infos = {}
81
+ self.decode_kv_args_table = {}
82
82
  self.pp_group = get_pp_group()
83
83
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
84
84
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
@@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender):
201
201
  def send(
202
202
  self,
203
203
  kv_indices: npt.NDArray[np.int32],
204
+ state_indices: Optional[List[int]] = None,
204
205
  ):
205
206
  pass
206
207
 
@@ -245,6 +246,7 @@ class CommonKVReceiver(BaseKVReceiver):
245
246
  f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
246
247
  )
247
248
  self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
249
+ self.bootstrap_infos = None
248
250
  return
249
251
  else:
250
252
  logger.debug(
@@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
30
30
  import torch
31
31
  from torch.distributed import ProcessGroup
32
32
 
33
+ from sglang.srt.configs.mamba_utils import Mamba2CacheParams
33
34
  from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
34
35
  from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
35
36
  from sglang.srt.disaggregation.utils import (
@@ -49,10 +50,16 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
49
50
  from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
50
51
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
51
52
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
52
- from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
53
- from sglang.srt.model_executor.forward_batch_info import ForwardMode
54
- from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
53
+ from sglang.srt.mem_cache.memory_pool import (
54
+ HybridLinearKVPool,
55
+ HybridReqToTokenPool,
56
+ KVCache,
57
+ NSATokenToKVPool,
58
+ ReqToTokenPool,
59
+ SWAKVPool,
60
+ )
55
61
  from sglang.srt.utils import get_int_env_var, require_mlp_sync
62
+ from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
56
63
 
57
64
  logger = logging.getLogger(__name__)
58
65
 
@@ -124,6 +131,35 @@ class DecodeReqToTokenPool:
124
131
  self.free_slots = list(range(self.size + self.pre_alloc_size))
125
132
 
126
133
 
134
+ class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
135
+
136
+ def __init__(
137
+ self,
138
+ size: int,
139
+ max_context_len: int,
140
+ device: str,
141
+ enable_memory_saver: bool,
142
+ cache_params: "Mamba2CacheParams",
143
+ speculative_num_draft_tokens: int,
144
+ pre_alloc_size: int,
145
+ ):
146
+ DecodeReqToTokenPool.__init__(
147
+ self,
148
+ size=size,
149
+ max_context_len=max_context_len,
150
+ device=device,
151
+ enable_memory_saver=enable_memory_saver,
152
+ pre_alloc_size=pre_alloc_size,
153
+ )
154
+ self._init_mamba_pool(
155
+ size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
156
+ )
157
+
158
+ def clear(self):
159
+ self.free_slots = list(range(self.size + self.pre_alloc_size))
160
+ self.mamba_pool.clear()
161
+
162
+
127
163
  @dataclass
128
164
  class DecodeRequest:
129
165
  req: Req
@@ -217,6 +253,28 @@ class DecodePreallocQueue:
217
253
  self.metadata_buffers.get_buf_infos()
218
254
  )
219
255
 
256
+ if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
257
+ state_data_ptrs, state_data_lens, state_item_lens = (
258
+ self.token_to_kv_pool.get_state_buf_infos()
259
+ )
260
+ kv_args.state_data_ptrs = state_data_ptrs
261
+ kv_args.state_data_lens = state_data_lens
262
+ kv_args.state_item_lens = state_item_lens
263
+
264
+ if isinstance(self.token_to_kv_pool, SWAKVPool):
265
+ kv_args.state_type = "swa"
266
+ elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
267
+ kv_args.state_type = "mamba"
268
+ elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
269
+ kv_args.state_type = "nsa"
270
+ else:
271
+ kv_args.state_type = "none"
272
+ else:
273
+ kv_args.state_data_ptrs = []
274
+ kv_args.state_data_lens = []
275
+ kv_args.state_item_lens = []
276
+ kv_args.state_type = "none"
277
+
220
278
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
221
279
  kv_args.gpu_id = self.scheduler.gpu_id
222
280
  kv_manager_class: Type[BaseKVManager] = get_kv_class(
@@ -414,16 +472,56 @@ class DecodePreallocQueue:
414
472
  .cpu()
415
473
  .numpy()
416
474
  )
475
+ page_size = self.token_to_kv_pool_allocator.page_size
476
+
477
+ # Prepare extra pool indices for hybrid models
478
+ if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
479
+ # Mamba hybrid model: single mamba state index
480
+ state_indices = [
481
+ self.req_to_token_pool.req_index_to_mamba_index_mapping[
482
+ decode_req.req.req_pool_idx
483
+ ]
484
+ .cpu()
485
+ .numpy()
486
+ ]
487
+ elif isinstance(self.token_to_kv_pool, SWAKVPool):
488
+ # SWA hybrid model: send decode-side SWA window indices
489
+ seq_len = len(decode_req.req.origin_input_ids)
490
+ window_size = self.scheduler.sliding_window_size
491
+
492
+ window_start = max(0, seq_len - window_size)
493
+ window_start = (window_start // page_size) * page_size
494
+ window_kv_indices_full = self.req_to_token_pool.req_to_token[
495
+ decode_req.req.req_pool_idx, window_start:seq_len
496
+ ]
497
+
498
+ # Translate to SWA pool indices
499
+ window_kv_indices_swa = (
500
+ self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
501
+ window_kv_indices_full
502
+ )
503
+ )
504
+ state_indices = window_kv_indices_swa.cpu().numpy()
505
+ state_indices = kv_to_page_indices(state_indices, page_size)
506
+ elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
507
+ seq_len = len(decode_req.req.origin_input_ids)
508
+ kv_indices_full = self.req_to_token_pool.req_to_token[
509
+ decode_req.req.req_pool_idx, :seq_len
510
+ ]
511
+ state_indices = kv_indices_full.cpu().numpy()
512
+ state_indices = kv_to_page_indices(state_indices, page_size)
513
+ else:
514
+ state_indices = None
417
515
 
418
516
  decode_req.metadata_buffer_index = (
419
517
  self.req_to_metadata_buffer_idx_allocator.alloc()
420
518
  )
421
519
  assert decode_req.metadata_buffer_index is not None
422
- page_indices = kv_to_page_indices(
423
- kv_indices, self.token_to_kv_pool_allocator.page_size
520
+ page_indices = kv_to_page_indices(kv_indices, page_size)
521
+ decode_req.kv_receiver.init(
522
+ page_indices, decode_req.metadata_buffer_index, state_indices
424
523
  )
425
- decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
426
-
524
+ decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
427
525
  preallocated_reqs.append(decode_req)
428
526
  indices_to_remove.add(i)
429
527
  decode_req.req.time_stats.decode_transfer_queue_entry_time = (
@@ -503,7 +601,10 @@ class DecodePreallocQueue:
503
601
 
504
602
  def _pre_alloc(self, req: Req) -> torch.Tensor:
505
603
  """Pre-allocate the memory for req_to_token and token_kv_pool"""
506
- req_pool_indices = self.req_to_token_pool.alloc(1)
604
+ if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
605
+ req_pool_indices = self.req_to_token_pool.alloc(1, [req])
606
+ else:
607
+ req_pool_indices = self.req_to_token_pool.alloc(1)
507
608
 
508
609
  assert (
509
610
  req_pool_indices is not None
@@ -611,8 +712,8 @@ class DecodeTransferQueue:
611
712
  self.scheduler.stream_output(
612
713
  [decode_req.req], decode_req.req.return_logprob
613
714
  )
614
- # unlock the kv cache or it will have memory leak
615
- self.tree_cache.cache_finished_req(decode_req.req)
715
+ # release pre-allocated kv cache, but don't insert into the tree since it's failed
716
+ self.tree_cache.cache_finished_req(decode_req.req, is_insert=False)
616
717
  indices_to_remove.add(i)
617
718
  if self.scheduler.enable_metrics:
618
719
  self.scheduler.metrics_collector.increment_transfer_failed_reqs()
@@ -747,11 +848,12 @@ class SchedulerDisaggregationDecodeMixin:
747
848
 
748
849
  @torch.no_grad()
749
850
  def event_loop_overlap_disagg_decode(self: Scheduler):
750
- result_queue = deque()
851
+ self.result_queue = deque()
751
852
  self.last_batch: Optional[ScheduleBatch] = None
752
853
  self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
753
854
 
754
855
  while True:
856
+
755
857
  recv_reqs = self.recv_requests()
756
858
  self.process_input_requests(recv_reqs)
757
859
  # polling and allocating kv cache
@@ -762,6 +864,7 @@ class SchedulerDisaggregationDecodeMixin:
762
864
 
763
865
  prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
764
866
 
867
+ batch_result = None
765
868
  if batch:
766
869
  # Generate fake extend output.
767
870
  if batch.forward_mode.is_extend():
@@ -770,45 +873,34 @@ class SchedulerDisaggregationDecodeMixin:
770
873
  batch.reqs, any(req.return_logprob for req in batch.reqs)
771
874
  )
772
875
  if prepare_mlp_sync_flag:
773
- batch_, result = self._prepare_idle_batch_and_run(
876
+ batch_, batch_result = self._prepare_idle_batch_and_run(
774
877
  None, delay_process=True
775
878
  )
776
879
  if batch_:
777
- result_queue.append((batch_.copy(), result))
880
+ self.result_queue.append((batch_.copy(), batch_result))
778
881
  last_batch_in_queue = True
779
882
  else:
780
883
  if prepare_mlp_sync_flag:
781
884
  self.prepare_mlp_sync_batch(batch)
782
- result = self.run_batch(batch)
783
- result_queue.append((batch.copy(), result))
784
-
785
- if (self.last_batch is None) or (not self.last_batch_in_queue):
786
- # Create a dummy first batch to start the pipeline for overlap schedule.
787
- # It is now used for triggering the sampling_info_done event.
788
- tmp_batch = ScheduleBatch(
789
- reqs=None,
790
- forward_mode=ForwardMode.DUMMY_FIRST,
791
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
792
- )
793
- self.set_next_batch_sampling_info_done(tmp_batch)
885
+ batch_result = self.run_batch(batch)
886
+ self.result_queue.append((batch.copy(), batch_result))
794
887
  last_batch_in_queue = True
795
888
 
796
889
  elif prepare_mlp_sync_flag:
797
- batch, result = self._prepare_idle_batch_and_run(
890
+ batch, batch_result = self._prepare_idle_batch_and_run(
798
891
  None, delay_process=True
799
892
  )
800
893
  if batch:
801
- result_queue.append((batch.copy(), result))
894
+ self.result_queue.append((batch.copy(), batch_result))
802
895
  last_batch_in_queue = True
803
896
 
804
897
  # Process the results of the previous batch but skip if the last batch is extend
805
898
  if self.last_batch and self.last_batch_in_queue:
806
- tmp_batch, tmp_result = result_queue.popleft()
807
- tmp_batch.next_batch_sampling_info = (
808
- self.tp_worker.cur_sampling_info if batch else None
809
- )
899
+ tmp_batch, tmp_result = self.result_queue.popleft()
810
900
  self.process_batch_result(tmp_batch, tmp_result)
811
901
 
902
+ self.launch_batch_sample_if_needed(batch_result)
903
+
812
904
  queue_size = (
813
905
  len(self.waiting_queue)
814
906
  + len(self.disagg_decode_transfer_queue.queue)
@@ -4,7 +4,6 @@ import time
4
4
 
5
5
  import torch
6
6
 
7
- from sglang import ServerArgs
8
7
  from sglang.srt.managers.cache_controller import HiCacheController
9
8
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
10
9
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -17,6 +16,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
17
16
  MHATokenToKVPoolHost,
18
17
  MLATokenToKVPoolHost,
19
18
  )
19
+ from sglang.srt.server_args import ServerArgs
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
48
48
  def send(
49
49
  self,
50
50
  kv_indices: npt.NDArray[np.int32],
51
+ state_indices: Optional[List[int]] = None,
51
52
  ):
52
53
  self.has_sent = True
53
- logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
54
+ logger.debug(
55
+ f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
56
+ )
54
57
 
55
58
  def failure_exception(self):
56
59
  raise Exception("Fake KVSender Exception")
@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
75
78
  logger.debug("FakeKVReceiver poll success")
76
79
  return KVPoll.Success
77
80
 
78
- def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
81
+ def init(
82
+ self,
83
+ kv_indices: list[int],
84
+ aux_index: Optional[int] = None,
85
+ state_indices: Optional[List[int]] = None,
86
+ ):
79
87
  self.has_init = True
80
88
  logger.debug(
81
- f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
89
+ f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
82
90
  )
83
91
 
84
92
  def failure_exception(self):