sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
21
21
  from __future__ import annotations
22
22
 
23
23
  import logging
24
+ import time
24
25
  from collections import deque
25
26
  from dataclasses import dataclass
26
27
  from http import HTTPStatus
@@ -29,6 +30,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
29
30
  import torch
30
31
  from torch.distributed import ProcessGroup
31
32
 
33
+ from sglang.srt.configs.mamba_utils import Mamba2CacheParams
32
34
  from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
33
35
  from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
34
36
  from sglang.srt.disaggregation.utils import (
@@ -45,13 +47,19 @@ from sglang.srt.disaggregation.utils import (
45
47
  prepare_abort,
46
48
  )
47
49
  from sglang.srt.layers.dp_attention import get_attention_tp_size
48
- from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
50
+ from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
49
51
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
50
52
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
51
- from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
52
- from sglang.srt.model_executor.forward_batch_info import ForwardMode
53
- from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
53
+ from sglang.srt.mem_cache.memory_pool import (
54
+ HybridLinearKVPool,
55
+ HybridReqToTokenPool,
56
+ KVCache,
57
+ NSATokenToKVPool,
58
+ ReqToTokenPool,
59
+ SWAKVPool,
60
+ )
54
61
  from sglang.srt.utils import get_int_env_var, require_mlp_sync
62
+ from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
55
63
 
56
64
  logger = logging.getLogger(__name__)
57
65
 
@@ -123,6 +131,35 @@ class DecodeReqToTokenPool:
123
131
  self.free_slots = list(range(self.size + self.pre_alloc_size))
124
132
 
125
133
 
134
+ class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
135
+
136
+ def __init__(
137
+ self,
138
+ size: int,
139
+ max_context_len: int,
140
+ device: str,
141
+ enable_memory_saver: bool,
142
+ cache_params: "Mamba2CacheParams",
143
+ speculative_num_draft_tokens: int,
144
+ pre_alloc_size: int,
145
+ ):
146
+ DecodeReqToTokenPool.__init__(
147
+ self,
148
+ size=size,
149
+ max_context_len=max_context_len,
150
+ device=device,
151
+ enable_memory_saver=enable_memory_saver,
152
+ pre_alloc_size=pre_alloc_size,
153
+ )
154
+ self._init_mamba_pool(
155
+ size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
156
+ )
157
+
158
+ def clear(self):
159
+ self.free_slots = list(range(self.size + self.pre_alloc_size))
160
+ self.mamba_pool.clear()
161
+
162
+
126
163
  @dataclass
127
164
  class DecodeRequest:
128
165
  req: Req
@@ -216,6 +253,28 @@ class DecodePreallocQueue:
216
253
  self.metadata_buffers.get_buf_infos()
217
254
  )
218
255
 
256
+ if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
257
+ state_data_ptrs, state_data_lens, state_item_lens = (
258
+ self.token_to_kv_pool.get_state_buf_infos()
259
+ )
260
+ kv_args.state_data_ptrs = state_data_ptrs
261
+ kv_args.state_data_lens = state_data_lens
262
+ kv_args.state_item_lens = state_item_lens
263
+
264
+ if isinstance(self.token_to_kv_pool, SWAKVPool):
265
+ kv_args.state_type = "swa"
266
+ elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
267
+ kv_args.state_type = "mamba"
268
+ elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
269
+ kv_args.state_type = "nsa"
270
+ else:
271
+ kv_args.state_type = "none"
272
+ else:
273
+ kv_args.state_data_ptrs = []
274
+ kv_args.state_data_lens = []
275
+ kv_args.state_item_lens = []
276
+ kv_args.state_type = "none"
277
+
219
278
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
220
279
  kv_args.gpu_id = self.scheduler.gpu_id
221
280
  kv_manager_class: Type[BaseKVManager] = get_kv_class(
@@ -253,6 +312,7 @@ class DecodePreallocQueue:
253
312
  prefill_dp_rank=req.data_parallel_rank,
254
313
  )
255
314
 
315
+ req.add_latency(RequestStage.DECODE_PREPARE)
256
316
  self.queue.append(
257
317
  DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
258
318
  )
@@ -412,17 +472,62 @@ class DecodePreallocQueue:
412
472
  .cpu()
413
473
  .numpy()
414
474
  )
475
+ page_size = self.token_to_kv_pool_allocator.page_size
476
+
477
+ # Prepare extra pool indices for hybrid models
478
+ if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
479
+ # Mamba hybrid model: single mamba state index
480
+ state_indices = [
481
+ self.req_to_token_pool.req_index_to_mamba_index_mapping[
482
+ decode_req.req.req_pool_idx
483
+ ]
484
+ .cpu()
485
+ .numpy()
486
+ ]
487
+ elif isinstance(self.token_to_kv_pool, SWAKVPool):
488
+ # SWA hybrid model: send decode-side SWA window indices
489
+ seq_len = len(decode_req.req.origin_input_ids)
490
+ window_size = self.scheduler.sliding_window_size
491
+
492
+ window_start = max(0, seq_len - window_size)
493
+ window_start = (window_start // page_size) * page_size
494
+ window_kv_indices_full = self.req_to_token_pool.req_to_token[
495
+ decode_req.req.req_pool_idx, window_start:seq_len
496
+ ]
497
+
498
+ # Translate to SWA pool indices
499
+ window_kv_indices_swa = (
500
+ self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
501
+ window_kv_indices_full
502
+ )
503
+ )
504
+ state_indices = window_kv_indices_swa.cpu().numpy()
505
+ state_indices = kv_to_page_indices(state_indices, page_size)
506
+ elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
507
+ seq_len = len(decode_req.req.origin_input_ids)
508
+ kv_indices_full = self.req_to_token_pool.req_to_token[
509
+ decode_req.req.req_pool_idx, :seq_len
510
+ ]
511
+ state_indices = kv_indices_full.cpu().numpy()
512
+ state_indices = kv_to_page_indices(state_indices, page_size)
513
+ else:
514
+ state_indices = None
415
515
 
416
516
  decode_req.metadata_buffer_index = (
417
517
  self.req_to_metadata_buffer_idx_allocator.alloc()
418
518
  )
419
519
  assert decode_req.metadata_buffer_index is not None
420
- page_indices = kv_to_page_indices(
421
- kv_indices, self.token_to_kv_pool_allocator.page_size
520
+ page_indices = kv_to_page_indices(kv_indices, page_size)
521
+ decode_req.kv_receiver.init(
522
+ page_indices, decode_req.metadata_buffer_index, state_indices
422
523
  )
423
- decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
524
+ decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
424
525
  preallocated_reqs.append(decode_req)
425
526
  indices_to_remove.add(i)
527
+ decode_req.req.time_stats.decode_transfer_queue_entry_time = (
528
+ time.perf_counter()
529
+ )
530
+ decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
426
531
 
427
532
  self.queue = [
428
533
  entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -496,7 +601,10 @@ class DecodePreallocQueue:
496
601
 
497
602
  def _pre_alloc(self, req: Req) -> torch.Tensor:
498
603
  """Pre-allocate the memory for req_to_token and token_kv_pool"""
499
- req_pool_indices = self.req_to_token_pool.alloc(1)
604
+ if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
605
+ req_pool_indices = self.req_to_token_pool.alloc(1, [req])
606
+ else:
607
+ req_pool_indices = self.req_to_token_pool.alloc(1)
500
608
 
501
609
  assert (
502
610
  req_pool_indices is not None
@@ -516,11 +624,19 @@ class DecodePreallocQueue:
516
624
  dtype=torch.int64,
517
625
  device=self.token_to_kv_pool_allocator.device,
518
626
  ),
627
+ prefix_lens_cpu=torch.tensor(
628
+ [0],
629
+ dtype=torch.int64,
630
+ ),
519
631
  seq_lens=torch.tensor(
520
632
  [num_tokens],
521
633
  dtype=torch.int64,
522
634
  device=self.token_to_kv_pool_allocator.device,
523
635
  ),
636
+ seq_lens_cpu=torch.tensor(
637
+ [num_tokens],
638
+ dtype=torch.int64,
639
+ ),
524
640
  last_loc=torch.tensor(
525
641
  [-1],
526
642
  dtype=torch.int64,
@@ -596,8 +712,8 @@ class DecodeTransferQueue:
596
712
  self.scheduler.stream_output(
597
713
  [decode_req.req], decode_req.req.return_logprob
598
714
  )
599
- # unlock the kv cache or it will have memory leak
600
- self.tree_cache.cache_finished_req(decode_req.req)
715
+ # release pre-allocated kv cache, but don't insert into the tree since it's failed
716
+ self.tree_cache.cache_finished_req(decode_req.req, is_insert=False)
601
717
  indices_to_remove.add(i)
602
718
  if self.scheduler.enable_metrics:
603
719
  self.scheduler.metrics_collector.increment_transfer_failed_reqs()
@@ -607,16 +723,23 @@ class DecodeTransferQueue:
607
723
  idx = decode_req.metadata_buffer_index
608
724
  (
609
725
  output_id,
726
+ cached_tokens,
610
727
  output_token_logprobs_val,
611
728
  output_token_logprobs_idx,
612
729
  output_top_logprobs_val,
613
730
  output_top_logprobs_idx,
731
+ output_topk_p,
732
+ output_topk_index,
614
733
  output_hidden_states,
615
734
  ) = self.metadata_buffers.get_buf(idx)
616
735
 
617
736
  decode_req.req.output_ids.append(output_id[0].item())
737
+ decode_req.req.cached_tokens = cached_tokens[0].item()
618
738
  if not self.spec_algorithm.is_none():
739
+ decode_req.req.output_topk_p = output_topk_p
740
+ decode_req.req.output_topk_index = output_topk_index
619
741
  decode_req.req.hidden_states_tensor = output_hidden_states
742
+
620
743
  if decode_req.req.return_logprob:
621
744
  decode_req.req.output_token_logprobs_val.append(
622
745
  output_token_logprobs_val[0].item()
@@ -637,10 +760,17 @@ class DecodeTransferQueue:
637
760
 
638
761
  if hasattr(decode_req.kv_receiver, "clear"):
639
762
  decode_req.kv_receiver.clear()
763
+ decode_req.kv_receiver = None
764
+
765
+ indices_to_remove.add(i)
766
+ decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
640
767
 
641
768
  # special handling for sampling_params.max_new_tokens == 1
642
769
  if decode_req.req.sampling_params.max_new_tokens == 1:
643
770
  # finish immediately
771
+ decode_req.req.time_stats.forward_entry_time = (
772
+ decode_req.req.time_stats.completion_time
773
+ ) = time.perf_counter()
644
774
  decode_req.req.check_finished()
645
775
  self.scheduler.stream_output(
646
776
  [decode_req.req], decode_req.req.return_logprob
@@ -648,8 +778,6 @@ class DecodeTransferQueue:
648
778
  self.tree_cache.cache_finished_req(decode_req.req)
649
779
  else:
650
780
  transferred_reqs.append(decode_req.req)
651
-
652
- indices_to_remove.add(i)
653
781
  elif poll in [
654
782
  KVPoll.Bootstrapping,
655
783
  KVPoll.WaitingForInput,
@@ -662,6 +790,7 @@ class DecodeTransferQueue:
662
790
  for i in indices_to_remove:
663
791
  idx = self.queue[i].metadata_buffer_index
664
792
  assert idx != -1
793
+ self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
665
794
  self.req_to_metadata_buffer_idx_allocator.free(idx)
666
795
 
667
796
  self.queue = [
@@ -704,23 +833,27 @@ class SchedulerDisaggregationDecodeMixin:
704
833
  elif prepare_mlp_sync_flag:
705
834
  batch, _ = self._prepare_idle_batch_and_run(None)
706
835
 
707
- if batch is None and (
836
+ queue_size = (
708
837
  len(self.waiting_queue)
709
838
  + len(self.disagg_decode_transfer_queue.queue)
710
839
  + len(self.disagg_decode_prealloc_queue.queue)
711
- == 0
712
- ):
840
+ )
841
+ if self.server_args.disaggregation_decode_enable_offload_kvcache:
842
+ queue_size += len(self.decode_offload_manager.ongoing_offload)
843
+
844
+ if batch is None and queue_size == 0:
713
845
  self.self_check_during_idle()
714
846
 
715
847
  self.last_batch = batch
716
848
 
717
849
  @torch.no_grad()
718
850
  def event_loop_overlap_disagg_decode(self: Scheduler):
719
- result_queue = deque()
851
+ self.result_queue = deque()
720
852
  self.last_batch: Optional[ScheduleBatch] = None
721
853
  self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
722
854
 
723
855
  while True:
856
+
724
857
  recv_reqs = self.recv_requests()
725
858
  self.process_input_requests(recv_reqs)
726
859
  # polling and allocating kv cache
@@ -731,6 +864,7 @@ class SchedulerDisaggregationDecodeMixin:
731
864
 
732
865
  prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
733
866
 
867
+ batch_result = None
734
868
  if batch:
735
869
  # Generate fake extend output.
736
870
  if batch.forward_mode.is_extend():
@@ -739,51 +873,43 @@ class SchedulerDisaggregationDecodeMixin:
739
873
  batch.reqs, any(req.return_logprob for req in batch.reqs)
740
874
  )
741
875
  if prepare_mlp_sync_flag:
742
- batch_, result = self._prepare_idle_batch_and_run(
876
+ batch_, batch_result = self._prepare_idle_batch_and_run(
743
877
  None, delay_process=True
744
878
  )
745
879
  if batch_:
746
- result_queue.append((batch_.copy(), result))
880
+ self.result_queue.append((batch_.copy(), batch_result))
747
881
  last_batch_in_queue = True
748
882
  else:
749
883
  if prepare_mlp_sync_flag:
750
884
  self.prepare_mlp_sync_batch(batch)
751
- result = self.run_batch(batch)
752
- result_queue.append((batch.copy(), result))
753
-
754
- if (self.last_batch is None) or (not self.last_batch_in_queue):
755
- # Create a dummy first batch to start the pipeline for overlap schedule.
756
- # It is now used for triggering the sampling_info_done event.
757
- tmp_batch = ScheduleBatch(
758
- reqs=None,
759
- forward_mode=ForwardMode.DUMMY_FIRST,
760
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
761
- )
762
- self.set_next_batch_sampling_info_done(tmp_batch)
885
+ batch_result = self.run_batch(batch)
886
+ self.result_queue.append((batch.copy(), batch_result))
763
887
  last_batch_in_queue = True
764
888
 
765
889
  elif prepare_mlp_sync_flag:
766
- batch, result = self._prepare_idle_batch_and_run(
890
+ batch, batch_result = self._prepare_idle_batch_and_run(
767
891
  None, delay_process=True
768
892
  )
769
893
  if batch:
770
- result_queue.append((batch.copy(), result))
894
+ self.result_queue.append((batch.copy(), batch_result))
771
895
  last_batch_in_queue = True
772
896
 
773
897
  # Process the results of the previous batch but skip if the last batch is extend
774
898
  if self.last_batch and self.last_batch_in_queue:
775
- tmp_batch, tmp_result = result_queue.popleft()
776
- tmp_batch.next_batch_sampling_info = (
777
- self.tp_worker.cur_sampling_info if batch else None
778
- )
899
+ tmp_batch, tmp_result = self.result_queue.popleft()
779
900
  self.process_batch_result(tmp_batch, tmp_result)
780
901
 
781
- if batch is None and (
902
+ self.launch_batch_sample_if_needed(batch_result)
903
+
904
+ queue_size = (
782
905
  len(self.waiting_queue)
783
906
  + len(self.disagg_decode_transfer_queue.queue)
784
907
  + len(self.disagg_decode_prealloc_queue.queue)
785
- == 0
786
- ):
908
+ )
909
+ if self.server_args.disaggregation_decode_enable_offload_kvcache:
910
+ queue_size += len(self.decode_offload_manager.ongoing_offload)
911
+
912
+ if batch is None and queue_size == 0:
787
913
  self.self_check_during_idle()
788
914
 
789
915
  self.last_batch = batch
@@ -853,6 +979,7 @@ class SchedulerDisaggregationDecodeMixin:
853
979
  # we can only add at least `num_not_used_batch` new batch to the running queue
854
980
  if i < num_not_used_batch:
855
981
  can_run_list.append(req)
982
+ req.add_latency(RequestStage.DECODE_WAITING)
856
983
  req.init_next_round_input(self.tree_cache)
857
984
  else:
858
985
  waiting_queue.append(req)
@@ -861,6 +988,9 @@ class SchedulerDisaggregationDecodeMixin:
861
988
  if len(can_run_list) == 0:
862
989
  return None
863
990
 
991
+ for req in can_run_list:
992
+ req.time_stats.forward_entry_time = time.perf_counter()
993
+
864
994
  # construct a schedule batch with those requests and mark as decode
865
995
  new_batch = ScheduleBatch.init_new(
866
996
  can_run_list,
@@ -901,3 +1031,6 @@ class SchedulerDisaggregationDecodeMixin:
901
1031
  self.disagg_decode_transfer_queue.pop_transferred()
902
1032
  ) # the requests which kv has arrived
903
1033
  self.waiting_queue.extend(alloc_reqs)
1034
+
1035
+ if self.server_args.disaggregation_decode_enable_offload_kvcache:
1036
+ self.decode_offload_manager.check_offload_progress()
@@ -0,0 +1,185 @@
1
+ import logging
2
+ import threading
3
+ import time
4
+
5
+ import torch
6
+
7
+ from sglang.srt.managers.cache_controller import HiCacheController
8
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
9
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
10
+ from sglang.srt.mem_cache.memory_pool import (
11
+ MHATokenToKVPool,
12
+ MLATokenToKVPool,
13
+ ReqToTokenPool,
14
+ )
15
+ from sglang.srt.mem_cache.memory_pool_host import (
16
+ MHATokenToKVPoolHost,
17
+ MLATokenToKVPoolHost,
18
+ )
19
+ from sglang.srt.server_args import ServerArgs
20
+
21
logger = logging.getLogger(__name__)


class DecodeKVCacheOffloadManager:
    """Manage decode-side KV cache offloading lifecycle and operations.

    An offloaded request goes through two asynchronous stages, both advanced
    by ``check_offload_progress``:

    1. Device -> host: ``offload_kv_cache`` asks the ``HiCacheController`` to
       copy the request's page-aligned KV cache into a host memory pool. Once
       the copy is acknowledged, the device KV cache is released via
       ``tree_cache.cache_finished_req``.
    2. Host -> storage: the host copy is written to the storage backend,
       keyed by chained per-page hashes. Once that write is acknowledged, the
       host memory is freed.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_group: torch.distributed.ProcessGroup,
        tree_cache: BasePrefixCache,
        server_args: ServerArgs,
    ) -> None:
        """Create the host memory pool and cache controller used for offloading.

        Raises:
            ValueError: if the device KV cache is neither an ``MHATokenToKVPool``
                nor an ``MLATokenToKVPool``.
        """
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = server_args.page_size
        self.server_args = server_args
        # Incrementing counter used as the node id for device->host write acks.
        self.request_counter = 0
        self.tree_cache = tree_cache

        # Pick the host pool type matching the device KV layout (check order kept).
        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
        if isinstance(kv_cache, MHATokenToKVPool):
            host_pool_cls = MHATokenToKVPoolHost
        elif isinstance(kv_cache, MLATokenToKVPool):
            host_pool_cls = MLATokenToKVPoolHost
        else:
            raise ValueError("Unsupported KV cache type for decode offload")
        self.decode_host_mem_pool = host_pool_cls(
            kv_cache,
            server_args.hicache_ratio,
            server_args.hicache_size,
            self.page_size,
            server_args.hicache_mem_layout,
        )

        self.tp_group = tp_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            mem_pool_host=self.decode_host_mem_pool,
            page_size=self.page_size,
            tp_group=tp_group,
            io_backend=server_args.hicache_io_backend,
            load_cache_event=threading.Event(),
            storage_backend=server_args.hicache_storage_backend,
            model_name=server_args.served_model_name,
            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
        )

        # Device->host copies in flight: ack_id -> (req, host_indices, tokens, start_time).
        self.ongoing_offload = {}
        # Host->storage backups in flight: storage ack id -> (rid, host_indices, start_time).
        self.ongoing_backup = {}
        logger.info("Enable offload kv cache for decode side")

    def offload_kv_cache(self, req) -> bool:
        """Start an asynchronous device-to-host copy of a finished request's KV cache.

        Only whole pages are offloaded. The last entry of ``req.output_ids`` is
        excluded: it is the most recently sampled token, which has not been fed
        back through the model, so no KV entry exists for it yet.

        Returns:
            True if the copy was scheduled and registered in
            ``ongoing_offload``. In that case the device KV cache is released
            later by ``check_offload_progress``. False if nothing was scheduled.
        """
        if self.cache_controller is None or self.decode_host_mem_pool is None:
            return False

        # No req_to_token row -> nothing to offload. ``None`` must be rejected
        # too: indexing with None would add a dimension instead of selecting a row.
        if req.req_pool_idx is None or req.req_pool_idx == -1:
            return False

        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        if token_indices.dim() == 0 or token_indices.numel() == 0:
            logger.debug(
                "Request %s has invalid token_indices: %s", req.rid, token_indices
            )
            return False

        # Only tokens whose KV has actually been computed; see docstring.
        tokens = req.origin_input_ids + req.output_ids[:-1]
        aligned_len = (len(tokens) // self.page_size) * self.page_size
        if aligned_len == 0:
            return False

        token_indices = token_indices[:aligned_len]
        tokens = tokens[:aligned_len]

        # Asynchronously offload KV cache from device to host by cache controller
        self.request_counter += 1
        ack_id = self.request_counter
        host_indices = self.cache_controller.write(
            device_indices=token_indices.long(),
            node_id=ack_id,
        )
        if host_indices is None:
            logger.error("Not enough host memory for request %s", req.rid)
            return False

        self.ongoing_offload[ack_id] = (
            req,
            host_indices,
            tokens,
            time.perf_counter(),
        )
        return True

    def check_offload_progress(self):
        """Advance both offload stages by the number of acks every TP rank has seen.

        The write/backup ack counts are MIN-reduced across the TP group, so all
        ranks process the same number of completions in the same step.
        """
        cc = self.cache_controller

        qsizes = torch.tensor(
            [
                len(cc.ack_write_queue),
                cc.ack_backup_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_write, n_backup = map(int, qsizes.tolist())
        self._check_offload_progress(n_write)
        self._check_backup_progress(n_backup)

    def _check_offload_progress(self, finish_count):
        """Finish ``finish_count`` device->host writes: release device KV, start backup."""
        for _ in range(finish_count):
            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
            finish_event.synchronize()
            for ack_id in ack_list:
                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)

                # Release device
                self.tree_cache.cache_finished_req(req)

                # Trigger async backup from host to storage by cache controller
                self._trigger_backup(req.rid, host_indices, tokens, start_time)

    def _check_backup_progress(self, finish_count):
        """Finish ``finish_count`` host->storage backups and free their host memory."""
        for _ in range(finish_count):
            storage_operation = self.cache_controller.ack_backup_queue.get()
            ack_id = storage_operation.id
            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)

            # Release host memory
            self.decode_host_mem_pool.free(host_indices)

            logger.debug(
                "Finished backup request %s, free host memory, len:%d, cost time:%.2f seconds.",
                req_id,
                len(host_indices),
                time.perf_counter() - start_time,
            )

    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
        """Trigger async backup from host to storage by cache controller."""
        # Generate page hashes and write to storage
        page_hashes = self._compute_prefix_hash(tokens)
        ack_id = self.cache_controller.write_storage(
            host_indices,
            tokens,
            hash_value=page_hashes,
        )
        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)

    def _compute_prefix_hash(self, tokens):
        """Return one hash per page, each chained on the previous page's hash."""
        last_hash = ""
        page_hashes = []
        for offset in range(0, len(tokens), self.page_size):
            page_tokens = tokens[offset : offset + self.page_size]
            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
            page_hashes.append(last_hash)
        return page_hashes
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
76
76
  req_pool_indices, dtype=torch.int64, device=self.device
77
77
  )
78
78
  self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
79
+ self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
79
80
  self.orig_seq_lens = torch.tensor(
80
81
  seq_lens, dtype=torch.int32, device=self.device
81
82
  )
@@ -125,31 +126,39 @@ class ScheduleBatchDisaggregationDecodeMixin:
125
126
  req.grammar.finished = req.finished()
126
127
  self.output_ids = torch.tensor(self.output_ids, device=self.device)
127
128
 
128
- # Simulate the eagle run. We add mock data to hidden states for the
129
- # ease of implementation now meaning the first token will have acc rate
130
- # of 0.
131
- if not self.spec_algorithm.is_none():
129
+ # Simulate the eagle run.
130
+ if self.spec_algorithm.is_eagle():
132
131
 
133
132
  b = len(self.reqs)
134
- topk_p = torch.arange(
135
- b * server_args.speculative_eagle_topk,
136
- 0,
137
- -1,
138
- device=self.device,
139
- dtype=torch.float32,
133
+ topk = server_args.speculative_eagle_topk
134
+ topk_p = torch.stack(
135
+ [
136
+ torch.as_tensor(
137
+ req.output_topk_p[:topk],
138
+ device=self.device,
139
+ dtype=torch.float32,
140
+ )
141
+ for req in self.reqs
142
+ ],
143
+ dim=0,
140
144
  )
141
- topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
142
- topk_p /= b * server_args.speculative_eagle_topk
143
- topk_index = torch.arange(
144
- b * server_args.speculative_eagle_topk, device=self.device
145
+ topk_index = torch.stack(
146
+ [
147
+ torch.as_tensor(
148
+ req.output_topk_index[:topk],
149
+ device=self.device,
150
+ dtype=torch.int64,
151
+ )
152
+ for req in self.reqs
153
+ ],
154
+ dim=0,
145
155
  )
146
- topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
147
156
 
148
157
  hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
149
158
  hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
150
159
 
151
160
  # local import to avoid circular import
152
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
161
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
153
162
 
154
163
  spec_info = EagleDraftInput(
155
164
  topk_p=topk_p,
@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
48
48
  def send(
49
49
  self,
50
50
  kv_indices: npt.NDArray[np.int32],
51
+ state_indices: Optional[List[int]] = None,
51
52
  ):
52
53
  self.has_sent = True
53
- logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
54
+ logger.debug(
55
+ f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
56
+ )
54
57
 
55
58
  def failure_exception(self):
56
59
  raise Exception("Fake KVSender Exception")
@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
75
78
  logger.debug("FakeKVReceiver poll success")
76
79
  return KVPoll.Success
77
80
 
78
def init(
    self,
    kv_indices: list[int],
    aux_index: Optional[int] = None,
    state_indices: Optional[List[int]] = None,
):
    """Pretend to initialize the receiver: flag it as initialized and log the arguments."""
    self.has_init = True
    message = (
        f"FakeKVReceiver init with kv_indices: {kv_indices}, "
        f"aux_index: {aux_index}, state_indices: {state_indices}"
    )
    logger.debug(message)
83
91
 
84
92
  def failure_exception(self):