sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
  from __future__ import annotations

+ import enum
+
  # Copyright 2023-2024 SGLang Team
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -35,6 +37,7 @@ import copy
  import dataclasses
  import logging
  import threading
+ import time
  from enum import Enum, auto
  from http import HTTPStatus
  from itertools import chain
@@ -51,18 +54,18 @@ from sglang.srt.disaggregation.base import BaseKVSender
  from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
      ScheduleBatchDisaggregationDecodeMixin,
  )
+ from sglang.srt.disaggregation.utils import DisaggregationMode
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
- from sglang.srt.layers.moe import is_tbo_enabled
  from sglang.srt.mem_cache.allocator import (
      BaseTokenToKVPoolAllocator,
      SWATokenToKVPoolAllocator,
  )
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
- from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
+ from sglang.srt.mem_cache.radix_cache import RadixKey
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
- from sglang.srt.metrics.collector import TimeStats
+ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -71,8 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton

  if TYPE_CHECKING:
      from sglang.srt.configs.model_config import ModelConfig
-     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+     from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -87,6 +89,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
      "disable_flashinfer_cutlass_moe_fp4_allgather",
      "disable_radix_cache",
      "enable_dp_lm_head",
+     "enable_fp32_lm_head",
      "flashinfer_mxfp4_moe_precision",
      "enable_flashinfer_allreduce_fusion",
      "moe_dense_tp_size",
@@ -94,20 +97,24 @@ GLOBAL_SERVER_ARGS_KEYS = [
      "ep_num_redundant_experts",
      "enable_nan_detection",
      "flashinfer_mla_disable_ragged",
-     "max_micro_batch_size",
+     "pp_max_micro_batch_size",
      "disable_shared_experts_fusion",
      "sampling_backend",
      "speculative_accept_threshold_single",
      "speculative_accept_threshold_acc",
+     "speculative_attention_mode",
      "torchao_config",
      "triton_attention_reduce_in_fp32",
      "num_reserved_decode_tokens",
      "weight_loader_disable_mmap",
      "enable_multimodal",
      "enable_symm_mem",
-     "quantization",
      "enable_custom_logit_processor",
      "disaggregation_mode",
+     "enable_deterministic_inference",
+     "nsa_prefill",
+     "nsa_decode",
+     "multi_item_scoring_delimiter",
  ]

  # Put some global args for easy access
@@ -408,6 +415,23 @@ class MultimodalInputs:
      # other args would be kept intact


+ class RequestStage(str, enum.Enum):
+     # prefill
+     PREFILL_WAITING = "prefill_waiting"
+
+     # disaggregation prefill
+     PREFILL_PREPARE = "prefill_prepare"
+     PREFILL_BOOTSTRAP = "prefill_bootstrap"
+     PREFILL_FORWARD = "prefill_forward"
+     PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
+
+     # disaggregation decode
+     DECODE_PREPARE = "decode_prepare"
+     DECODE_BOOTSTRAP = "decode_bootstrap"
+     DECODE_WAITING = "decode_waiting"
+     DECODE_TRANSFERRED = "decode_transferred"
+
+
  class Req:
      """The input and output status of a request."""

@@ -432,8 +456,12 @@ class Req:
          bootstrap_host: Optional[str] = None,
          bootstrap_port: Optional[int] = None,
          bootstrap_room: Optional[int] = None,
+         disagg_mode: Optional[DisaggregationMode] = None,
          data_parallel_rank: Optional[int] = None,
          vocab_size: Optional[int] = None,
+         priority: Optional[int] = None,
+         metrics_collector: Optional[SchedulerMetricsCollector] = None,
+         extra_key: Optional[str] = None,
      ):
          # Input and output info
          self.rid = rid
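Because RequestStage subclasses both str and enum.Enum, each member's value can be used directly as a metrics label. A minimal sketch of that pattern (illustrative only; StageTimer is not part of sglang):

import enum
import time


class Stage(str, enum.Enum):
    # Mirrors the RequestStage idea above: the value doubles as a metric label.
    PREFILL_WAITING = "prefill_waiting"
    PREFILL_FORWARD = "prefill_forward"


class StageTimer:
    """Accumulate wall-clock time per stage, keyed by the enum value."""

    def __init__(self) -> None:
        self.last_tic = time.monotonic()
        self.durations: dict[str, float] = {}

    def observe(self, stage: Stage) -> None:
        now = time.monotonic()
        # str-valued members serialize without .value juggling in label maps
        self.durations[stage.value] = self.durations.get(stage.value, 0.0) + (
            now - self.last_tic
        )
        self.last_tic = now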
@@ -466,6 +494,14 @@ class Req:
          self.sampling_params = sampling_params
          self.custom_logit_processor = custom_logit_processor
          self.return_hidden_states = return_hidden_states
+
+         # extra key for classifying the request (e.g. cache_salt)
+         if lora_id is not None:
+             extra_key = (
+                 extra_key or ""
+             ) + lora_id  # lora_id is concatenated to the extra key
+
+         self.extra_key = extra_key
          self.lora_id = lora_id

          # Memory pool info
@@ -484,6 +520,7 @@ class Req:
          self.stream = stream
          self.eos_token_ids = eos_token_ids
          self.vocab_size = vocab_size
+         self.priority = priority

          # For incremental decoding
          # ----- | --------- read_ids -------|
@@ -503,7 +540,7 @@ class Req:

          # Prefix info
          # The indices to kv cache for the shared prefix.
-         self.prefix_indices: torch.Tensor = []
+         self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
          # Number of tokens to run prefill.
          self.extend_input_len = 0
          # The relative logprob_start_len in an extend batch
@@ -513,6 +550,8 @@ class Req:
          self.host_hit_length = 0
          # The node to lock until for swa radix tree lock ref
          self.swa_uuid_for_lock: Optional[int] = None
+         # The prefix length of the last prefix matching
+         self.last_matched_prefix_len: int = 0

          # Whether or not if it is chunked. It increments whenever
          # it is chunked, and decrement whenever chunked request is
@@ -561,7 +600,10 @@ class Req:
              # shape: (bs, k)
              self.output_top_logprobs_val = []
              self.output_top_logprobs_idx = []
-             self.output_token_ids_logprobs_val = []
+             # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
+             self.output_token_ids_logprobs_val: List[
+                 Union[List[float], torch.Tensor]
+             ] = []
              self.output_token_ids_logprobs_idx = []
          else:
              self.output_token_logprobs_val = self.output_token_logprobs_idx = (
@@ -571,6 +613,8 @@ class Req:
              ) = None
          self.hidden_states: List[List[float]] = []
          self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
+         self.output_topk_p = None
+         self.output_topk_index = None

          # Embedding (return values)
          self.embedding = None
@@ -588,10 +632,10 @@ class Req:
          self.spec_verify_ct = 0

          # For metrics
-         self.time_stats: TimeStats = TimeStats()
+         self.metrics_collector = metrics_collector
+         self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
          self.has_log_time_stats: bool = False
-         self.queue_time_start = None
-         self.queue_time_end = None
+         self.last_tic = time.monotonic()

          # For disaggregation
          self.bootstrap_host: str = bootstrap_host
@@ -619,6 +663,27 @@ class Req:
      def seqlen(self):
          return len(self.origin_input_ids) + len(self.output_ids)

+     @property
+     def is_prefill_only(self) -> bool:
+         """Check if this request is prefill-only (no token generation needed)."""
+         # NOTE: when spec is enabled, prefill_only optimizations are disabled
+         from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+
+         spec_alg = global_server_args_dict["speculative_algorithm"]
+         return self.sampling_params.max_new_tokens == 0 and (
+             spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
+         )
+
+     def add_latency(self, stage: RequestStage):
+         if self.metrics_collector is None:
+             return
+
+         now = time.monotonic()
+         self.metrics_collector.observe_per_stage_req_latency(
+             stage.value, now - self.last_tic
+         )
+         self.last_tic = now
+
      def extend_image_inputs(self, image_inputs):
          if self.multimodal_inputs is None:
              self.multimodal_inputs = image_inputs
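How the new hooks compose at runtime, as a hedged sketch (the call site below is hypothetical; only add_latency, last_tic, and RequestStage come from this diff): each call attributes the time since the previous stage boundary to the given stage, then resets last_tic, and is a no-op when no metrics_collector was injected.

# Hypothetical scheduler-side call site, for illustration only.
def admit_for_prefill(waiting_queue: list, batch: list) -> None:
    req = waiting_queue.pop(0)
    # Time spent queued is recorded under the "prefill_waiting" label,
    # and req.last_tic advances to now.
    req.add_latency(RequestStage.PREFILL_WAITING)
    batch.append(req)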
@@ -629,51 +694,27 @@ class Req:
      # Whether request reached finished condition
      return self.finished_reason is not None

-     def init_next_round_input(
-         self,
-         tree_cache: Optional[BasePrefixCache] = None,
-     ):
-         self.fill_ids = self.origin_input_ids + self.output_ids
-         if tree_cache is not None:
-             if isinstance(tree_cache, LoRARadixCache):
-                 (
-                     self.prefix_indices,
-                     self.last_node,
-                     self.last_host_node,
-                     self.host_hit_length,
-                 ) = tree_cache.match_prefix_with_lora_id(
-                     key=LoRAKey(
-                         lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
-                     ),
-                 )
-             else:
-                 (
-                     self.prefix_indices,
-                     self.last_node,
-                     self.last_host_node,
-                     self.host_hit_length,
-                 ) = tree_cache.match_prefix(
-                     key=self.adjust_max_prefix_ids(),
-                 )
-         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
-
-     def adjust_max_prefix_ids(self):
+     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
          self.fill_ids = self.origin_input_ids + self.output_ids
          input_len = len(self.fill_ids)
-
-         # FIXME: To work around some bugs in logprob computation, we need to ensure each
-         # request has at least one token. Later, we can relax this requirement and use `input_len`.
+         # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
          max_prefix_len = input_len - 1
-
-         if self.sampling_params.max_new_tokens > 0:
-             # Need at least one token to compute logits
-             max_prefix_len = min(max_prefix_len, input_len - 1)
-
          if self.return_logprob:
              max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
          max_prefix_len = max(max_prefix_len, 0)
-         return self.fill_ids[:max_prefix_len]
+         token_ids = self.fill_ids[:max_prefix_len]
+
+         if tree_cache is not None:
+             (
+                 self.prefix_indices,
+                 self.last_node,
+                 self.last_host_node,
+                 self.host_hit_length,
+             ) = tree_cache.match_prefix(
+                 key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
+             )
+             self.last_matched_prefix_len = len(self.prefix_indices)
+         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

      # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
      def init_incremental_detokenize(self):
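The rewrite above folds the old LoRA-specific branch into one keyed lookup: RadixKey carries an extra_key (the cache salt, plus the LoRA id concatenated in __init__), which partitions the prefix cache so requests in different namespaces never share entries. A toy model of those semantics, not the sglang radix tree:

from typing import Dict, List, Optional


class KeyedPrefixCache:
    """Toy sketch: one namespace per extra_key, so requests with different
    LoRA adapters or cache salts can never share cached prefixes."""

    def __init__(self) -> None:
        self._namespaces: Dict[Optional[str], List[List[int]]] = {}

    def insert(self, token_ids: List[int], extra_key: Optional[str] = None) -> None:
        self._namespaces.setdefault(extra_key, []).append(list(token_ids))

    def match_prefix(self, token_ids: List[int], extra_key: Optional[str] = None) -> int:
        # Longest shared prefix length, searched within the namespace only.
        best = 0
        for cached in self._namespaces.get(extra_key, []):
            n = 0
            for a, b in zip(cached, token_ids):
                if a != b:
                    break
                n += 1
            best = max(best, n)
        return best


cache = KeyedPrefixCache()
cache.insert([1, 2, 3, 4], extra_key="lora-A")
assert cache.match_prefix([1, 2, 3, 9], extra_key="lora-A") == 3
assert cache.match_prefix([1, 2, 3, 9], extra_key="lora-B") == 0  # isolated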
@@ -684,9 +725,15 @@ class Req:
              self.surr_offset = max(
                  self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
              )
+             self.surr_and_decode_ids = (
+                 self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
+             )
+             self.cur_decode_ids_len = len(self.output_ids)
+         else:
+             self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
+             self.cur_decode_ids_len = len(self.output_ids)

-         all_ids = self.origin_input_ids_unpadded + self.output_ids
-         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
+         return self.surr_and_decode_ids, self.read_offset - self.surr_offset

      def check_finished(self):
          if self.finished():
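The caching above replaces an O(sequence length) list concatenation per decode step with an O(new tokens) extend. A standalone sketch of the same amortization (simplified; field names mirror the diff):

class IncrementalIds:
    """Keep one growing id list instead of rebuilding prompt + output."""

    def __init__(self, prompt_ids: list, surr_offset: int) -> None:
        self.surr_and_decode_ids = list(prompt_ids[surr_offset:])
        self.cur_decode_ids_len = 0

    def step(self, output_ids: list) -> list:
        # Append only the tokens produced since the previous call.
        self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len:])
        self.cur_decode_ids_len = len(output_ids)
        return self.surr_and_decode_ids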
@@ -749,7 +796,7 @@ class Req:
              return

      def reset_for_retract(self):
-         self.prefix_indices = []
+         self.prefix_indices = torch.empty((0,), dtype=torch.int64)
          self.last_node = None
          self.swa_uuid_for_lock = None
          self.extend_input_len = 0
@@ -781,10 +828,10 @@ class Req:
              return

          if self.bootstrap_room is not None:
-             prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+             prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
          else:
-             prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
-         logger.info(f"{prefix}: {self.time_stats}")
+             prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
+         logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
          self.has_log_time_stats = True

      def set_finish_with_abort(self, error_msg: str):
@@ -807,10 +854,6 @@ class Req:
          )


- # Batch id
- bid = 0
-
-
  @dataclasses.dataclass
  class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      """Store all information of a batch on the scheduler."""
@@ -831,15 +874,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      # This is an optimization to reduce the overhead of the prefill check.
      batch_is_full: bool = False

-     # Events
-     launch_done: Optional[threading.Event] = None
-
      # For chunked prefill in PP
      chunked_req: Optional[Req] = None

      # Sampling info
      sampling_info: SamplingBatchInfo = None
-     next_batch_sampling_info: SamplingBatchInfo = None

      # Batched arguments to model runner
      input_ids: torch.Tensor = None  # shape: [b], int64
@@ -847,6 +886,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      token_type_ids: torch.Tensor = None  # shape: [b], int64
      req_pool_indices: torch.Tensor = None  # shape: [b], int64
      seq_lens: torch.Tensor = None  # shape: [b], int64
+     seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
      # The output locations of the KV cache
      out_cache_loc: torch.Tensor = None  # shape: [b], int64
      output_ids: torch.Tensor = None  # shape: [b], int64
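The seq_lens_cpu mirror exists because reading element values of a CUDA tensor on the host forces a device synchronization; building a CPU copy from the same Python list keeps scheduler bookkeeping (page counting, retraction) sync-free. A minimal illustration (the cuda device is an assumption of this sketch):

import torch

seq_lens_list = [5, 9, 12]
seq_lens_cpu = torch.tensor(seq_lens_list, dtype=torch.int64)
device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = seq_lens_cpu.to(device, non_blocking=True)  # device copy for kernels

page_size = 4
# Host-side math reads the CPU mirror; no GPU sync is triggered.
new_pages = int((seq_lens_cpu % page_size == 0).sum())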
@@ -902,7 +942,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

      # Speculative decoding
      spec_algorithm: SpeculativeAlgorithm = None
-     spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
+     # spec_info: Optional[SpecInput] = None
+     spec_info: Optional[SpecInput] = None

      # Whether to return hidden states
      return_hidden_states: bool = False
@@ -911,7 +952,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      is_prefill_only: bool = False

      # hicache pointer for synchronizing data loading from CPU to GPU
-     hicache_consumer_index: int = 0
+     hicache_consumer_index: int = -1

      @classmethod
      def init_new(
@@ -950,9 +991,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              device=req_to_token_pool.device,
              spec_algorithm=spec_algorithm,
              return_hidden_states=any(req.return_hidden_states for req in reqs),
-             is_prefill_only=all(
-                 req.sampling_params.max_new_tokens == 0 for req in reqs
-             ),
+             is_prefill_only=all(req.is_prefill_only for req in reqs),
              chunked_req=chunked_req,
          )

@@ -962,8 +1001,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      def is_empty(self):
          return len(self.reqs) == 0

-     def alloc_req_slots(self, num_reqs: int):
-         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+     def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
+         if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
+             req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
+         else:
+             req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
          if req_pool_indices is None:
              raise RuntimeError(
                  "alloc_req_slots runs out of memory. "
@@ -1000,7 +1042,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      def alloc_paged_token_slots_extend(
          self,
          prefix_lens: torch.Tensor,
+         prefix_lens_cpu: torch.Tensor,
          seq_lens: torch.Tensor,
+         seq_lens_cpu: torch.Tensor,
          last_loc: torch.Tensor,
          extend_num_tokens: int,
          backup_state: bool = False,
@@ -1008,7 +1052,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          # Over estimate the number of tokens: assume each request needs a new page.
          num_tokens = (
              extend_num_tokens
-             + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+             + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
          )
          self._evict_tree_cache_if_needed(num_tokens)

@@ -1016,7 +1060,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              state = self.token_to_kv_pool_allocator.backup_state()

          out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
-             prefix_lens, seq_lens, last_loc, extend_num_tokens
+             prefix_lens,
+             prefix_lens_cpu,
+             seq_lens,
+             seq_lens_cpu,
+             last_loc,
+             extend_num_tokens,
          )
          if out_cache_loc is None:
              error_msg = (
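Worked numbers for the over-estimate in alloc_paged_token_slots_extend (values invented for illustration): eviction is sized as if every request in the batch also opens one fresh page, so allocation cannot fail on a page boundary.

page_size = 16
extend_lens = [10, 20, 5]                 # new tokens per request
extend_num_tokens = sum(extend_lens)      # 35
num_tokens = extend_num_tokens + len(extend_lens) * page_size
assert num_tokens == 35 + 3 * 16 == 83    # evict against 83 tokens, not 35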
@@ -1035,6 +1084,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      def alloc_paged_token_slots_decode(
          self,
          seq_lens: torch.Tensor,
+         seq_lens_cpu: torch.Tensor,
          last_loc: torch.Tensor,
          backup_state: bool = False,
      ):
@@ -1045,7 +1095,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          if backup_state:
              state = self.token_to_kv_pool_allocator.backup_state()

-         out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
+         out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
+             seq_lens, seq_lens_cpu, last_loc
+         )
          if out_cache_loc is None:
              error_msg = (
                  f"Decode out of memory. Try to lower your batch size.\n"
@@ -1060,6 +1112,47 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          else:
              return out_cache_loc

+     def write_cache_indices(
+         self,
+         req_pool_indices: List[int],
+         prefix_lens: List[int],
+         seq_lens: List[int],
+         extend_lens: List[int],
+         out_cache_loc: torch.Tensor,
+         req_pool_indices_tensor: torch.Tensor,
+         prefix_lens_tensor: torch.Tensor,
+         seq_lens_tensor: torch.Tensor,
+         extend_lens_tensor: torch.Tensor,
+         prefix_tensors: list[torch.Tensor],
+     ):
+         if support_triton(global_server_args_dict.get("attention_backend")):
+             prefix_pointers = torch.tensor(
+                 [t.data_ptr() for t in prefix_tensors], device=self.device
+             )
+             # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+             write_req_to_token_pool_triton[(len(req_pool_indices),)](
+                 self.req_to_token_pool.req_to_token,
+                 req_pool_indices_tensor,
+                 prefix_pointers,
+                 prefix_lens_tensor,
+                 seq_lens_tensor,
+                 extend_lens_tensor,
+                 out_cache_loc,
+                 self.req_to_token_pool.req_to_token.shape[1],
+             )
+         else:
+             pt = 0
+             for i in range(len(req_pool_indices)):
+                 self.req_to_token_pool.write(
+                     (req_pool_indices[i], slice(0, prefix_lens[i])),
+                     prefix_tensors[i],
+                 )
+                 self.req_to_token_pool.write(
+                     (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                     out_cache_loc[pt : pt + extend_lens[i]],
+                 )
+                 pt += extend_lens[i]
+
      def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
          self.encoder_lens_cpu = []
          self.encoder_cached = []
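In miniature, what write_cache_indices maintains: req_to_token is a per-request page table from token position to KV-cache slot, and the Python fallback writes it as two slices per request (cached prefix first, then the newly allocated slots). A self-contained sketch with plain tensors:

import torch

req_to_token = torch.zeros((2, 16), dtype=torch.int64)   # 2 reqs, 16 positions
prefix_tensors = [torch.tensor([3, 4, 5]), torch.empty(0, dtype=torch.int64)]
prefix_lens = [3, 0]
seq_lens = [5, 2]
out_cache_loc = torch.tensor([10, 11, 12, 13])           # freshly allocated slots

pt = 0
for i in range(2):
    extend_len = seq_lens[i] - prefix_lens[i]
    req_to_token[i, : prefix_lens[i]] = prefix_tensors[i]  # cached prefix slice
    req_to_token[i, prefix_lens[i] : seq_lens[i]] = out_cache_loc[pt : pt + extend_len]
    pt += extend_len

assert req_to_token[0, :5].tolist() == [3, 4, 5, 10, 11]
assert req_to_token[1, :2].tolist() == [12, 13]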
@@ -1114,6 +1207,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
              self.device, non_blocking=True
          )
+         self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)

          if not decoder_out_cache_loc:
              self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
@@ -1136,10 +1230,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      def prepare_for_extend(self):
          self.forward_mode = ForwardMode.EXTEND

-         # Allocate req slots
-         bs = len(self.reqs)
-         req_pool_indices = self.alloc_req_slots(bs)
-
          # Init tensors
          reqs = self.reqs
          input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -1153,21 +1243,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              r.token_type_ids for r in reqs if r.token_type_ids is not None
          ]

-         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-             self.device, non_blocking=True
-         )
          input_ids_tensor = torch.tensor(
              list(chain.from_iterable(input_ids)), dtype=torch.int64
          ).to(self.device, non_blocking=True)
          seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
              self.device, non_blocking=True
          )
+         seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
          orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
              self.device, non_blocking=True
          )
          prefix_lens_tensor = torch.tensor(
              prefix_lens, dtype=torch.int64, device=self.device
          )
+         prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)

          token_type_ids_tensor = None
          if len(token_type_ids) > 0:
@@ -1177,7 +1266,49 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

          extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

-         # Copy prefix and do some basic check
+         # Allocate req slots
+         bs = len(self.reqs)
+         req_pool_indices = self.alloc_req_slots(bs, self.reqs)
+         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+             self.device, non_blocking=True
+         )
+
+         # Allocate memory
+         if self.token_to_kv_pool_allocator.page_size == 1:
+             out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+         else:
+             last_loc = [
+                 (
+                     r.prefix_indices[-1:]
+                     if len(r.prefix_indices) > 0
+                     else torch.tensor([-1], device=self.device)
+                 )
+                 for r in self.reqs
+             ]
+             out_cache_loc = self.alloc_paged_token_slots_extend(
+                 prefix_lens_tensor,
+                 prefix_lens_cpu_tensor,
+                 seq_lens_tensor,
+                 seq_lens_cpu,
+                 torch.cat(last_loc),
+                 extend_num_tokens,
+             )
+
+         # Write allocated tokens to req_to_token_pool
+         self.write_cache_indices(
+             req_pool_indices,
+             prefix_lens,
+             seq_lens,
+             extend_lens,
+             out_cache_loc,
+             req_pool_indices_tensor,
+             prefix_lens_tensor,
+             seq_lens_tensor,
+             extend_lens_tensor,
+             [r.prefix_indices for r in reqs],
+         )
+
+         # Set fields
          input_embeds = []
          extend_input_logprob_token_ids = []
          multimodal_inputs = []
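One detail of the reordered allocation above: with page_size > 1 the allocator needs each request's last occupied KV slot to decide whether the current page still has room, and requests with no cached prefix pass a -1 sentinel. A small illustration of the last_loc construction (values invented):

import torch

prefix_indices = [torch.tensor([7, 8, 9]), torch.empty(0, dtype=torch.int64)]
last_loc = torch.cat(
    [p[-1:] if len(p) > 0 else torch.tensor([-1]) for p in prefix_indices]
)
assert last_loc.tolist() == [9, -1]  # req 0 may extend its page; req 1 starts fresh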
@@ -1187,9 +1318,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              assert seq_len - pre_len == req.extend_input_len

              if pre_len > 0:
-                 self.req_to_token_pool.write(
-                     (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
-                 )
                  if isinstance(self.tree_cache, SWAChunkCache):
                      self.tree_cache.evict_swa(
                          req, pre_len, self.model_config.attention_chunk_size
@@ -1207,13 +1335,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              req.is_retracted = False

              # Compute the relative logprob_start_len in an extend batch
+             #
+             # Key variables:
+             # - logprob_start_len: Absolute position in full sequence where logprob computation begins
+             # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
+             # - extend_input_len: Number of tokens that need to be processed in this extend batch
+             #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
+             #   and prefix_indices are the cached/shared prefix tokens)
+             #
              if req.logprob_start_len >= pre_len:
-                 req.extend_logprob_start_len = min(
-                     req.logprob_start_len - pre_len,
-                     req.extend_input_len,
-                     req.seqlen - 1,
-                 )
+                 # Optimization for prefill-only requests: When we only need logprobs at
+                 # positions beyond the input sequence (to score next-token likelihood), skip all
+                 # input logprob computation during prefill since no generation will occur.
+                 if self.is_prefill_only and req.logprob_start_len == len(
+                     req.origin_input_ids
+                 ):
+                     # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
+                     req.extend_logprob_start_len = req.extend_input_len
+                 else:
+                     # Convert absolute logprob_start_len to relative extend_logprob_start_len
+                     #
+                     # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
+                     # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
+                     # This means: "compute logprobs from position 3 onwards in extend batch"
+                     req.extend_logprob_start_len = min(
+                         req.logprob_start_len - pre_len,
+                         req.extend_input_len,
+                         req.seqlen - 1,
+                     )
              else:
+                 # logprob_start_len is before the current extend batch, so start from beginning
                  req.extend_logprob_start_len = 0

          if self.return_logprob:
@@ -1261,23 +1412,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          else:
              extend_input_logprob_token_ids = None

-         # Allocate memory
-         if self.token_to_kv_pool_allocator.page_size == 1:
-             out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-         else:
-             last_loc = get_last_loc(
-                 self.req_to_token_pool.req_to_token,
-                 req_pool_indices_tensor,
-                 prefix_lens_tensor,
-             )
-             out_cache_loc = self.alloc_paged_token_slots_extend(
-                 prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
-             )
-
-         # Set fields
          self.input_ids = input_ids_tensor
          self.req_pool_indices = req_pool_indices_tensor
          self.seq_lens = seq_lens_tensor
+         self.seq_lens_cpu = seq_lens_cpu
          self.orig_seq_lens = orig_seq_lens_tensor
          self.out_cache_loc = out_cache_loc
          self.input_embeds = (
@@ -1306,28 +1444,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.extend_lens = extend_lens
          self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

-         # Write to req_to_token_pool
-         if support_triton(global_server_args_dict.get("attention_backend")):
-             # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-
-             write_req_to_token_pool_triton[(bs,)](
-                 self.req_to_token_pool.req_to_token,
-                 req_pool_indices_tensor,
-                 prefix_lens_tensor,
-                 seq_lens_tensor,
-                 extend_lens_tensor,
-                 out_cache_loc,
-                 self.req_to_token_pool.req_to_token.shape[1],
-             )
-         else:
-             pt = 0
-             for i in range(bs):
-                 self.req_to_token_pool.write(
-                     (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                     out_cache_loc[pt : pt + extend_lens[i]],
-                 )
-                 pt += extend_lens[i]
-
          if self.model_config.is_encoder_decoder:
              self.prepare_encoder_info_extend(input_ids, seq_lens)

@@ -1372,21 +1488,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          # TODO (lianmin): Revisit this. It should be seq_len - 1
          self.extend_logprob_start_lens.extend([0] * running_bs)

-     def new_page_count_next_decode(self):
+     def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
          page_size = self.token_to_kv_pool_allocator.page_size
+         requests = (
+             self.reqs
+             if selected_indices is None
+             else [self.reqs[i] for i in selected_indices]
+         )
          if page_size == 1:
-             return len(self.reqs)
+             return len(requests)
          # In the decoding phase, the length of a request's KV cache should be
          # the total length of the request minus 1
          return (
-             sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+             sum(1 for req in requests if req.seqlen % page_size == 0)
              if self.enable_overlap
-             else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+             else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
          )

-     def check_decode_mem(self, buf_multiplier=1):
+     def check_decode_mem(
+         self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
+     ):
          num_tokens = (
-             self.new_page_count_next_decode()
+             self.new_page_count_next_decode(selected_indices)
              * buf_multiplier
              * self.token_to_kv_pool_allocator.page_size
          )
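A worked example for new_page_count_next_decode (numbers invented): with page_size 4 and overlap scheduling disabled, a request needs a fresh page exactly when its current length, seqlen - 1, fills the last page completely.

page_size = 4
seqlens = [5, 8, 9, 12]   # per-request seqlen including the token being decoded
new_pages = sum(1 for s in seqlens if (s - 1) % page_size == 0)
assert new_pages == 2     # (5-1) % 4 == 0 and (9-1) % 4 == 0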
@@ -1412,34 +1535,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              reverse=True,
          )

-         def get_required_tokens(num_reqs: int):
-             headroom_for_spec_decode = 0
-             if server_args.speculative_algorithm:
-                 headroom_for_spec_decode += (
-                     num_reqs
-                     * server_args.speculative_eagle_topk
-                     * server_args.speculative_num_steps
-                     + num_reqs * server_args.speculative_num_draft_tokens
-                 )
-             return (
-                 num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
-             )
-
-         def _get_available_size():
-             if self.is_hybrid:
-                 return min(
-                     self.token_to_kv_pool_allocator.full_available_size(),
-                     self.token_to_kv_pool_allocator.swa_available_size(),
-                 )
-             else:
-                 return self.token_to_kv_pool_allocator.available_size()
-
          retracted_reqs = []
-         seq_lens_cpu = self.seq_lens.cpu().numpy()
          first_iter = True
-         while (
-             _get_available_size() < get_required_tokens(len(sorted_indices))
-             or first_iter
+         while first_iter or (
+             not self.check_decode_mem(selected_indices=sorted_indices)
          ):
              if len(sorted_indices) == 1:
                  # Corner case: only one request left
@@ -1463,41 +1562,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              idx = sorted_indices.pop()
              req = self.reqs[idx]
              retracted_reqs.append(req)
-
-             if server_args.disaggregation_mode == "decode":
-                 req.offload_kv_cache(
-                     self.req_to_token_pool, self.token_to_kv_pool_allocator
-                 )
-
-             if isinstance(self.tree_cache, ChunkCache):
-                 # ChunkCache does not have eviction
-                 token_indices = self.req_to_token_pool.req_to_token[
-                     req.req_pool_idx, : seq_lens_cpu[idx]
-                 ]
-                 self.token_to_kv_pool_allocator.free(token_indices)
-                 self.req_to_token_pool.free(req.req_pool_idx)
-             else:
-                 # TODO: apply more fine-grained retraction
-                 last_uncached_pos = (
-                     len(req.prefix_indices) // server_args.page_size
-                 ) * server_args.page_size
-                 token_indices = self.req_to_token_pool.req_to_token[
-                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
-                 ]
-                 self.token_to_kv_pool_allocator.free(token_indices)
-                 self.req_to_token_pool.free(req.req_pool_idx)
-
-             # release the last node
-             if self.is_hybrid:
-                 self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
-             else:
-                 self.tree_cache.dec_lock_ref(req.last_node)
-
-             # NOTE(lsyin): we should use the newly evictable memory instantly.
-             num_tokens = len(sorted_indices) * global_config.retract_decode_steps
-             self._evict_tree_cache_if_needed(num_tokens)
-
-             req.reset_for_retract()
+             self.release_req(idx, len(sorted_indices), server_args)

          if len(retracted_reqs) == 0:
              # Corner case: only one request left
@@ -1516,7 +1581,45 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         ) / total_max_new_tokens
         new_estimate_ratio = min(1.0, new_estimate_ratio)
 
-        return retracted_reqs, new_estimate_ratio
+        return retracted_reqs, new_estimate_ratio, []
+
+    def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
+        req = self.reqs[idx]
+        seq_lens_cpu = self.seq_lens_cpu.numpy()
+
+        if server_args.disaggregation_mode == "decode":
+            req.offload_kv_cache(
+                self.req_to_token_pool, self.token_to_kv_pool_allocator
+            )
+        if isinstance(self.tree_cache, ChunkCache):
+            # ChunkCache does not have eviction
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, : seq_lens_cpu[idx]
+            ]
+            self.token_to_kv_pool_allocator.free(token_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+        else:
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = (
+                len(req.prefix_indices) // server_args.page_size
+            ) * server_args.page_size
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
+            ]
+            self.token_to_kv_pool_allocator.free(token_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+
+            # release the last node
+            if self.is_hybrid:
+                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
+            else:
+                self.tree_cache.dec_lock_ref(req.last_node)
+
+        # NOTE(lsyin): we should use the newly evictable memory instantly.
+        num_tokens = remaing_req_count * global_config.retract_decode_steps
+        self._evict_tree_cache_if_needed(num_tokens)
+
+        req.reset_for_retract()
 
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
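The retraction body moves into release_req largely unchanged. One detail worth illustrating is the page-aligned free boundary: slots below last_uncached_pos belong to fully cached pages still referenced by the tree cache, so only [last_uncached_pos, seq_len) is returned to the allocator. A toy sketch of that computation (values are illustrative):

page_size = 4
prefix_len = 10        # tokens of this request covered by the cached prefix
seq_len = 17           # sequence length at retraction time

# Round the cached prefix down to a whole number of pages.
last_uncached_pos = (prefix_len // page_size) * page_size
print(last_uncached_pos)                      # -> 8: pages [0,4) and [4,8) stay cached
print(list(range(last_uncached_pos, seq_len)))  # -> slots 8..16 go back to the pool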
@@ -1526,6 +1629,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
         self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
@@ -1540,7 +1644,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)
 
-        if self.spec_algorithm.is_eagle():
+        if (
+            self.spec_algorithm.is_eagle()
+            or self.spec_algorithm.is_standalone()
+            or self.spec_algorithm.is_ngram()
+        ):
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
@@ -1581,10 +1689,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.seq_lens_cpu = self.seq_lens_cpu + 1
             self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.seq_lens_cpu.add_(1)
             self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
 
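Throughout this file, seq_lens_cpu is maintained as a host-side mirror of seq_lens, updated in lockstep so downstream readers no longer need the blocking self.seq_lens.cpu() transfer that the old code performed. A minimal sketch of the pattern (illustrative tensors only):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([7, 12, 31], dtype=torch.int64, device=device)
seq_lens_cpu = seq_lens.cpu()

# Normal decode step: advance both in place, with no device sync.
seq_lens.add_(1)
seq_lens_cpu.add_(1)

# Overlap mode instead rebinds out of place (seq_lens_cpu = seq_lens_cpu + 1)
# so a kernel still reading the old tensor is unaffected.
assert seq_lens_cpu.tolist() == [8, 13, 32]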
@@ -1603,7 +1713,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_pool_indices, self.seq_lens - 2
         ]
         self.out_cache_loc = self.alloc_paged_token_slots_decode(
-            self.seq_lens, last_loc
+            self.seq_lens, self.seq_lens_cpu, last_loc
        )
 
         self.req_to_token_pool.write(
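Why `self.seq_lens - 2` in the context above: seq_lens was already advanced for the token being decoded, so the most recently written KV slot of each request sits at 0-based position seq_lens - 2; the paged allocator uses that slot to tell whether the new token still fits in the request's current page. An illustrative shape check:

import torch

req_to_token = torch.arange(32).reshape(4, 8)  # toy request -> slot map
req_pool_indices = torch.tensor([0, 2])
seq_lens = torch.tensor([5, 3])                # lengths after add_(1)

# Gather each request's previously written slot.
last_loc = req_to_token[req_pool_indices, seq_lens - 2]
print(last_loc)  # tensor([ 3, 17]): slots of the previously decoded tokens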
@@ -1649,6 +1759,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
@@ -1666,7 +1777,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         self.sampling_info.filter_batch(keep_indices, keep_indices_device)
         if self.spec_info:
-            self.spec_info.filter_batch(keep_indices_device)
+            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
+                has_been_filtered = False
+            else:
+                has_been_filtered = True
+            self.spec_info.filter_batch(
+                new_indices=keep_indices_device,
+                has_been_filtered=has_been_filtered,
+            )
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
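The has_been_filtered flag records whether the batch tensors were already filtered before spec_info sees the new indices; it is False only when a chunked request had to be excluded, in which case spec_info must apply the index filter itself. A sketch of the condition as a one-liner (hypothetical helper name, for illustration only):

def spec_filter_already_applied(chunked_req_to_exclude) -> bool:
    # Equivalent to the if/else in the hunk above.
    return not (chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0)

print(spec_filter_already_applied(None))  # -> True
print(spec_filter_already_applied([0]))   # -> False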
@@ -1682,6 +1800,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
         self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
@@ -1725,15 +1844,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.sampling_info.grammars = None
 
         seq_lens_cpu = (
-            seq_lens_cpu_cache
-            if seq_lens_cpu_cache is not None
-            else self.seq_lens.cpu()
+            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )
 
-        global bid
-        bid += 1
         return ModelWorkerBatch(
-            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -1779,7 +1893,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
-            launch_done=self.launch_done,
+            is_prefill_only=self.is_prefill_only,
         )
 
     def copy(self):
@@ -1852,8 +1966,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
 @dataclasses.dataclass
 class ModelWorkerBatch:
-    # The batch id
-    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
@@ -1914,19 +2026,25 @@ class ModelWorkerBatch:
 
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+
+    spec_info: Optional[SpecInput] = None
+
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1
+
+    # Overlap scheduler related
+    delay_sample_launch: bool = False
 
-    # Overlap event
-    launch_done: Optional[threading.Event] = None
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
 
 
 @triton.jit
 def write_req_to_token_pool_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices,
+    prefix_tensors,
     pre_lens,
     seq_lens,
     extend_lens,
@@ -1939,6 +2057,19 @@ def write_req_to_token_pool_triton(
     req_pool_index = tl.load(req_pool_indices + pid)
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)
+    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
+
+    # write prefix
+    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < pre_len
+        value = tl.load(prefix_tensor + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
+            value,
+            mask=mask,
+        )
 
     # NOTE: This can be slow for large bs
     cumsum_start = tl.cast(0, tl.int64)
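The new prefix-writing loop dereferences a per-request pointer (prefix_tensors holds one address per request) and copies pre_len cached slot indices into the request's row of req_to_token, BLOCK_SIZE elements at a time, with a mask guarding the ragged tail. A pure-Python/NumPy equivalent of what one program instance computes (illustrative only, not the Triton API):

import numpy as np

def write_prefix(req_to_token, req_pool_index, prefix, block_size):
    pre_len = len(prefix)
    num_loop = -(-pre_len // block_size)  # ceiling division, like tl.cdiv
    for i in range(num_loop):
        offset = np.arange(block_size) + i * block_size
        mask = offset < pre_len           # guard the ragged tail
        req_to_token[req_pool_index, offset[mask]] = prefix[offset[mask]]

req_to_token = np.zeros((4, 16), dtype=np.int64)
write_prefix(req_to_token, 2, np.array([5, 9, 13], dtype=np.int64), block_size=2)
print(req_to_token[2, :4])  # -> [ 5  9 13  0]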