sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py

@@ -21,10 +21,11 @@ Life cycle of a request in the decode server
 from __future__ import annotations

 import logging
+import time
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union

 import torch
 from torch.distributed import ProcessGroup
@@ -45,7 +46,7 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
@@ -218,8 +219,10 @@ class DecodePreallocQueue:

         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
-        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
-        kv_manager = kv_manager_class(
+        kv_manager_class: Type[BaseKVManager] = get_kv_class(
+            self.transfer_backend, KVClassType.MANAGER
+        )
+        kv_manager: BaseKVManager = kv_manager_class(
             kv_args,
             DisaggregationMode.DECODE,
             self.scheduler.server_args,
@@ -248,9 +251,10 @@ class DecodePreallocQueue:
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
-            data_parallel_rank=req.data_parallel_rank,
+            prefill_dp_rank=req.data_parallel_rank,
         )

+        req.add_latency(RequestStage.DECODE_PREPARE)
         self.queue.append(
             DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
         )
@@ -419,8 +423,13 @@ class DecodePreallocQueue:
                 kv_indices, self.token_to_kv_pool_allocator.page_size
             )
             decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
+
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
+            decode_req.req.time_stats.decode_transfer_queue_entry_time = (
+                time.perf_counter()
+            )
+            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)

         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
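The hunks above thread per-stage timing through the decode path: `req.add_latency(RequestStage.DECODE_PREPARE)` after the KV receiver is created, and a `time.perf_counter()` timestamp plus `RequestStage.DECODE_BOOTSTRAP` once preallocation succeeds. A minimal, self-contained sketch of that kind of stage bookkeeping is below; the `Stage` names and the `StageTimer` helper are illustrative stand-ins, not sglang's actual `RequestStage`/`time_stats` API.

import time
from enum import Enum, auto


class Stage(Enum):
    # Illustrative stage names; sglang's RequestStage enum differs.
    PREPARE = auto()
    BOOTSTRAP = auto()
    TRANSFERRED = auto()


class StageTimer:
    """Accumulate per-stage wall-clock latency for one request."""

    def __init__(self):
        self._last = time.perf_counter()
        self.latency = {}

    def add_latency(self, stage: Stage):
        now = time.perf_counter()
        self.latency[stage] = self.latency.get(stage, 0.0) + (now - self._last)
        self._last = now


timer = StageTimer()
# ... do the prepare work ...
timer.add_latency(Stage.PREPARE)
# ... wait for KV bootstrap / transfer ...
timer.add_latency(Stage.BOOTSTRAP)
print(timer.latency)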
@@ -514,11 +523,19 @@ class DecodePreallocQueue:
                 dtype=torch.int64,
                 device=self.token_to_kv_pool_allocator.device,
             ),
+            prefix_lens_cpu=torch.tensor(
+                [0],
+                dtype=torch.int64,
+            ),
             seq_lens=torch.tensor(
                 [num_tokens],
                 dtype=torch.int64,
                 device=self.token_to_kv_pool_allocator.device,
             ),
+            seq_lens_cpu=torch.tensor(
+                [num_tokens],
+                dtype=torch.int64,
+            ),
             last_loc=torch.tensor(
                 [-1],
                 dtype=torch.int64,
@@ -605,16 +622,23 @@ class DecodeTransferQueue:
                 idx = decode_req.metadata_buffer_index
                 (
                     output_id,
+                    cached_tokens,
                     output_token_logprobs_val,
                     output_token_logprobs_idx,
                     output_top_logprobs_val,
                     output_top_logprobs_idx,
+                    output_topk_p,
+                    output_topk_index,
                     output_hidden_states,
                 ) = self.metadata_buffers.get_buf(idx)

                 decode_req.req.output_ids.append(output_id[0].item())
+                decode_req.req.cached_tokens = cached_tokens[0].item()
                 if not self.spec_algorithm.is_none():
+                    decode_req.req.output_topk_p = output_topk_p
+                    decode_req.req.output_topk_index = output_topk_index
                     decode_req.req.hidden_states_tensor = output_hidden_states
+
                 if decode_req.req.return_logprob:
                     decode_req.req.output_token_logprobs_val.append(
                         output_token_logprobs_val[0].item()
@@ -635,10 +659,17 @@ class DecodeTransferQueue:

                 if hasattr(decode_req.kv_receiver, "clear"):
                     decode_req.kv_receiver.clear()
+                decode_req.kv_receiver = None
+
+                indices_to_remove.add(i)
+                decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()

                 # special handling for sampling_params.max_new_tokens == 1
                 if decode_req.req.sampling_params.max_new_tokens == 1:
                     # finish immediately
+                    decode_req.req.time_stats.forward_entry_time = (
+                        decode_req.req.time_stats.completion_time
+                    ) = time.perf_counter()
                     decode_req.req.check_finished()
                     self.scheduler.stream_output(
                         [decode_req.req], decode_req.req.return_logprob
@@ -646,8 +677,6 @@
                     self.tree_cache.cache_finished_req(decode_req.req)
                 else:
                     transferred_reqs.append(decode_req.req)
-
-                indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
                 KVPoll.WaitingForInput,
@@ -660,6 +689,7 @@
         for i in indices_to_remove:
             idx = self.queue[i].metadata_buffer_index
             assert idx != -1
+            self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
             self.req_to_metadata_buffer_idx_allocator.free(idx)

         self.queue = [
@@ -702,23 +732,28 @@ class SchedulerDisaggregationDecodeMixin:
             elif prepare_mlp_sync_flag:
                 batch, _ = self._prepare_idle_batch_and_run(None)

-            if batch is None and (
+            queue_size = (
                 len(self.waiting_queue)
                 + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
-                == 0
-            ):
+            )
+            if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+            if batch is None and queue_size == 0:
                 self.self_check_during_idle()

             self.last_batch = batch

     @torch.no_grad()
     def event_loop_overlap_disagg_decode(self: Scheduler):
-        result_queue = deque()
+        self.result_queue = deque()
         self.last_batch: Optional[ScheduleBatch] = None
         self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             # polling and allocating kv cache
@@ -741,23 +776,13 @@
                         None, delay_process=True
                     )
                     if batch_:
-                        result_queue.append((batch_.copy(), result))
+                        self.result_queue.append((batch_.copy(), result))
                         last_batch_in_queue = True
                 else:
                     if prepare_mlp_sync_flag:
                         self.prepare_mlp_sync_batch(batch)
                     result = self.run_batch(batch)
-                    result_queue.append((batch.copy(), result))
-
-                    if (self.last_batch is None) or (not self.last_batch_in_queue):
-                        # Create a dummy first batch to start the pipeline for overlap schedule.
-                        # It is now used for triggering the sampling_info_done event.
-                        tmp_batch = ScheduleBatch(
-                            reqs=None,
-                            forward_mode=ForwardMode.DUMMY_FIRST,
-                            next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                        )
-                        self.set_next_batch_sampling_info_done(tmp_batch)
+                    self.result_queue.append((batch.copy(), result))
                     last_batch_in_queue = True

             elif prepare_mlp_sync_flag:
@@ -765,23 +790,23 @@
                     None, delay_process=True
                 )
                 if batch:
-                    result_queue.append((batch.copy(), result))
+                    self.result_queue.append((batch.copy(), result))
                     last_batch_in_queue = True

             # Process the results of the previous batch but skip if the last batch is extend
             if self.last_batch and self.last_batch_in_queue:
-                tmp_batch, tmp_result = result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
+                tmp_batch, tmp_result = self.result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)

-            if batch is None and (
+            queue_size = (
                 len(self.waiting_queue)
                 + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
-                == 0
-            ):
+            )
+            if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+            if batch is None and queue_size == 0:
                 self.self_check_during_idle()

             self.last_batch = batch
@@ -851,6 +876,7 @@
             # we can only add at least `num_not_used_batch` new batch to the running queue
             if i < num_not_used_batch:
                 can_run_list.append(req)
+                req.add_latency(RequestStage.DECODE_WAITING)
                 req.init_next_round_input(self.tree_cache)
             else:
                 waiting_queue.append(req)
@@ -859,6 +885,9 @@
         if len(can_run_list) == 0:
             return None

+        for req in can_run_list:
+            req.time_stats.forward_entry_time = time.perf_counter()
+
         # construct a schedule batch with those requests and mark as decode
         new_batch = ScheduleBatch.init_new(
             can_run_list,
@@ -884,9 +913,21 @@
             # if there are still retracted requests, we do not allocate new requests
             return

-        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
-        self.disagg_decode_transfer_queue.extend(req_conns)
-        alloc_reqs = (
-            self.disagg_decode_transfer_queue.pop_transferred()
-        )  # the requests which kv has arrived
-        self.waiting_queue.extend(alloc_reqs)
+        if not hasattr(self, "polling_count"):
+            self.polling_count = 0
+            self.polling_interval = (
+                self.server_args.disaggregation_decode_polling_interval
+            )
+
+        self.polling_count = (self.polling_count + 1) % self.polling_interval
+
+        if self.polling_count % self.polling_interval == 0:
+            req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
+            self.disagg_decode_transfer_queue.extend(req_conns)
+            alloc_reqs = (
+                self.disagg_decode_transfer_queue.pop_transferred()
+            )  # the requests which kv has arrived
+            self.waiting_queue.extend(alloc_reqs)
+
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            self.decode_offload_manager.check_offload_progress()
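The last hunk gates the prealloc/transfer queue draining so it only runs on every `disaggregation_decode_polling_interval`-th scheduler pass, and additionally checks offload progress whenever decode-side KV-cache offload is enabled. A standalone sketch of the same modulo gate, with illustrative class and variable names:

class PollingGate:
    """Run an expensive poll only once every `interval` calls."""

    def __init__(self, interval: int):
        self.interval = max(1, interval)
        self.count = 0

    def should_poll(self) -> bool:
        # Mirrors the diff: advance the counter modulo the interval,
        # then poll only when it wraps back to zero.
        self.count = (self.count + 1) % self.interval
        return self.count % self.interval == 0


gate = PollingGate(interval=4)
# Over 8 scheduler ticks, the queues would be drained on ticks 4 and 8.
print([gate.should_poll() for _ in range(8)])

With `interval=4` the expensive drain fires on every fourth tick; the second modulo mirrors the diff and is effectively a no-op.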
sglang/srt/disaggregation/decode_kvcache_offload_manager.py (new file)

@@ -0,0 +1,185 @@
+import logging
+import threading
+import time
+
+import torch
+
+from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.mem_cache.memory_pool_host import (
+    MHATokenToKVPoolHost,
+    MLATokenToKVPoolHost,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+
+class DecodeKVCacheOffloadManager:
+    """Manage decode-side KV cache offloading lifecycle and operations."""
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+        tp_group: torch.distributed.ProcessGroup,
+        tree_cache: BasePrefixCache,
+        server_args: ServerArgs,
+    ) -> None:
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = server_args.page_size
+        self.server_args = server_args
+        self.request_counter = 0
+        self.tree_cache = tree_cache
+        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
+        if isinstance(kv_cache, MHATokenToKVPool):
+            self.decode_host_mem_pool = MHATokenToKVPoolHost(
+                kv_cache,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+            )
+        elif isinstance(kv_cache, MLATokenToKVPool):
+            self.decode_host_mem_pool = MLATokenToKVPoolHost(
+                kv_cache,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+            )
+        else:
+            raise ValueError("Unsupported KV cache type for decode offload")
+
+        self.tp_group = tp_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
+        self.cache_controller = HiCacheController(
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            mem_pool_host=self.decode_host_mem_pool,
+            page_size=self.page_size,
+            tp_group=tp_group,
+            io_backend=server_args.hicache_io_backend,
+            load_cache_event=threading.Event(),
+            storage_backend=server_args.hicache_storage_backend,
+            model_name=server_args.served_model_name,
+            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+        )
+
+        self.ongoing_offload = {}
+        self.ongoing_backup = {}
+        logger.info("Enable offload kv cache for decode side")
+
+    def offload_kv_cache(self, req) -> bool:
+        """Offload a finished request's KV cache to storage."""
+
+        if self.cache_controller is None or self.decode_host_mem_pool is None:
+            return False
+
+        if req.req_pool_idx == -1:
+            return False
+
+        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
+        if token_indices.dim() == 0 or token_indices.numel() == 0:
+            logger.debug(
+                f"Request {req.rid} has invalid token_indices: {token_indices}"
+            )
+            return False
+
+        tokens = req.origin_input_ids + req.output_ids
+        aligned_len = (len(tokens) // self.page_size) * self.page_size
+        if aligned_len == 0:
+            return False
+
+        token_indices = token_indices[:aligned_len]
+        tokens = tokens[:aligned_len]
+
+        # Asynchronously offload KV cache from device to host by cache controller
+        self.request_counter += 1
+        ack_id = self.request_counter
+        host_indices = self.cache_controller.write(
+            device_indices=token_indices.long(),
+            node_id=ack_id,
+        )
+        if host_indices is None:
+            logger.error(f"Not enough host memory for request {req.rid}")
+            return False
+
+        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
+        return True
+
+    def check_offload_progress(self):
+        """Check the progress of offload from device to host and backup from host to storage."""
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                len(cc.ack_write_queue),
+                cc.ack_backup_queue.qsize(),
+            ],
+            dtype=torch.int,
+        )
+        if self.tp_world_size > 1:
+            torch.distributed.all_reduce(
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
+            )
+
+        n_write, n_backup = map(int, qsizes.tolist())
+        self._check_offload_progress(n_write)
+        self._check_backup_progress(n_backup)
+
+    def _check_offload_progress(self, finish_count):
+        """Check the progress of offload from device to host."""
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
+
+                # Release device
+                self.tree_cache.cache_finished_req(req)
+
+                # Trigger async backup from host to storage by cache controller
+                self._trigger_backup(req.rid, host_indices, tokens, start_time)
+            finish_count -= 1
+
+    def _check_backup_progress(self, finish_count):
+        """Check the progress of backup from host to storage."""
+        for _ in range(finish_count):
+            storage_operation = self.cache_controller.ack_backup_queue.get()
+            ack_id = storage_operation.id
+            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
+
+            # Release host memory
+            self.decode_host_mem_pool.free(host_indices)
+
+            logger.debug(
+                f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
+            )
+
+    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
+        """Trigger async backup from host to storage by cache controller."""
+
+        # Generate page hashes and write to storage
+        page_hashes = self._compute_prefix_hash(tokens)
+        ack_id = self.cache_controller.write_storage(
+            host_indices,
+            tokens,
+            hash_value=page_hashes,
+        )
+        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
+
+    def _compute_prefix_hash(self, tokens):
+        last_hash = ""
+        page_hashes = []
+        for offset in range(0, len(tokens), self.page_size):
+            page_tokens = tokens[offset : offset + self.page_size]
+            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
+            page_hashes.append(last_hash)
+        return page_hashes
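`_compute_prefix_hash` chains hashes page by page, so each page's hash also encodes every page before it; a stored page can therefore only be reused when the entire preceding prefix matches. Below is a standalone sketch of that chaining idea using `hashlib.sha256`; the real manager delegates hashing to `cache_controller.get_hash_str`, so the exact hash function here is an assumption.

import hashlib
from typing import List


def chained_page_hashes(tokens: List[int], page_size: int) -> List[str]:
    """Hash each full page together with the hash of everything before it."""
    last_hash = ""
    hashes = []
    aligned_len = (len(tokens) // page_size) * page_size  # drop the partial last page
    for offset in range(0, aligned_len, page_size):
        page = tokens[offset : offset + page_size]
        payload = (last_hash + ",".join(map(str, page))).encode()
        last_hash = hashlib.sha256(payload).hexdigest()
        hashes.append(last_hash)
    return hashes


# Identical prefixes yield identical leading hashes; a divergence changes all later ones.
print(chained_page_hashes([1, 2, 3, 4, 5, 6, 7, 8], page_size=4))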
sglang/srt/disaggregation/decode_schedule_batch_mixin.py

@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         self.orig_seq_lens = torch.tensor(
             seq_lens, dtype=torch.int32, device=self.device
         )
@@ -110,7 +111,10 @@
             if req.grammar is not None:
                 # FIXME: this try-except block is for handling unexpected xgrammar issue.
                 try:
-                    req.grammar.accept_token(req.output_ids[-1])
+                    # if it is not None, then the grammar is from a retracted request, and we should not
+                    # accept the token as it's already accepted
+                    if req.grammar.current_token is None:
+                        req.grammar.accept_token(req.output_ids[-1])
                 except ValueError as e:
                     # Grammar accept_token can raise ValueError if the token is not in the grammar.
                     # This can happen if the grammar is not set correctly or the token is invalid.
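The new guard only calls `accept_token` when the grammar does not already hold a token accepted before the request was retracted, so a replayed token is never fed to the matcher twice. A generic illustration of that idempotence guard follows; the `ToyGrammar` class is purely illustrative and does not reflect xgrammar's API.

class ToyGrammar:
    """Tracks the last accepted token so a replayed accept becomes a no-op."""

    def __init__(self):
        self.current_token = None
        self.accepted = []

    def accept_token(self, token_id: int):
        self.accepted.append(token_id)
        self.current_token = token_id


grammar = ToyGrammar()
token_id = 42

# First pass (before retraction): the token is accepted and remembered.
if grammar.current_token is None:
    grammar.accept_token(token_id)

# Replay after retraction: current_token is set, so the duplicate accept is skipped.
if grammar.current_token is None:
    grammar.accept_token(token_id)

print(grammar.accepted)  # [42]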
@@ -122,31 +126,39 @@
                 req.grammar.finished = req.finished()
         self.output_ids = torch.tensor(self.output_ids, device=self.device)

-        # Simulate the eagle run. We add mock data to hidden states for the
-        # ease of implementation now meaning the first token will have acc rate
-        # of 0.
-        if not self.spec_algorithm.is_none():
+        # Simulate the eagle run.
+        if self.spec_algorithm.is_eagle():

             b = len(self.reqs)
-            topk_p = torch.arange(
-                b * server_args.speculative_eagle_topk,
-                0,
-                -1,
-                device=self.device,
-                dtype=torch.float32,
+            topk = server_args.speculative_eagle_topk
+            topk_p = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_p[:topk],
+                        device=self.device,
+                        dtype=torch.float32,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
-            topk_p /= b * server_args.speculative_eagle_topk
-            topk_index = torch.arange(
-                b * server_args.speculative_eagle_topk, device=self.device
+            topk_index = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_index[:topk],
+                        device=self.device,
+                        dtype=torch.int64,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

             hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
             hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

             # local import to avoid circular import
-            from sglang.srt.speculative.eagle_utils import EagleDraftInput
+            from sglang.srt.speculative.eagle_info import EagleDraftInput

             spec_info = EagleDraftInput(
                 topk_p=topk_p,
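Instead of synthesizing placeholder top-k data with `torch.arange`, the batch now rebuilds the EAGLE draft input from the real `output_topk_p` / `output_topk_index` values carried over per request, stacking one row per request into `(batch, topk)` tensors. A minimal CPU-only illustration of that reshaping with dummy data:

import torch

# Per-request top-k probabilities and token ids, e.g. as received from the prefill side.
reqs = [
    {"topk_p": [0.5, 0.3, 0.2], "topk_index": [11, 7, 42]},
    {"topk_p": [0.9, 0.05, 0.05], "topk_index": [3, 8, 19]},
]
topk = 2  # plays the role of speculative_eagle_topk

topk_p = torch.stack(
    [torch.as_tensor(r["topk_p"][:topk], dtype=torch.float32) for r in reqs], dim=0
)
topk_index = torch.stack(
    [torch.as_tensor(r["topk_index"][:topk], dtype=torch.int64) for r in reqs], dim=0
)
print(topk_p.shape, topk_index.shape)  # torch.Size([2, 2]) torch.Size([2, 2])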
sglang/srt/disaggregation/fake/conn.py

@@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver):
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.has_init = False