sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import enum
4
+
3
5
  # Copyright 2023-2024 SGLang Team
4
6
  # Licensed under the Apache License, Version 2.0 (the "License");
5
7
  # you may not use this file except in compliance with the License.
@@ -35,6 +37,7 @@ import copy
35
37
  import dataclasses
36
38
  import logging
37
39
  import threading
40
+ import time
38
41
  from enum import Enum, auto
39
42
  from http import HTTPStatus
40
43
  from itertools import chain
@@ -51,18 +54,18 @@ from sglang.srt.disaggregation.base import BaseKVSender
51
54
  from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
52
55
  ScheduleBatchDisaggregationDecodeMixin,
53
56
  )
57
+ from sglang.srt.disaggregation.utils import DisaggregationMode
54
58
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
55
- from sglang.srt.layers.moe import is_tbo_enabled
56
59
  from sglang.srt.mem_cache.allocator import (
57
60
  BaseTokenToKVPoolAllocator,
58
61
  SWATokenToKVPoolAllocator,
59
62
  )
60
63
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
61
64
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
62
- from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
63
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
65
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
66
+ from sglang.srt.mem_cache.radix_cache import RadixKey
64
67
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
65
- from sglang.srt.metrics.collector import TimeStats
68
+ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
66
69
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
67
70
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
68
71
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -71,8 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
71
74
 
72
75
  if TYPE_CHECKING:
73
76
  from sglang.srt.configs.model_config import ModelConfig
74
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
75
- from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
77
+ from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
76
78
 
77
79
  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
78
80
 
@@ -87,6 +89,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
87
89
  "disable_flashinfer_cutlass_moe_fp4_allgather",
88
90
  "disable_radix_cache",
89
91
  "enable_dp_lm_head",
92
+ "enable_fp32_lm_head",
90
93
  "flashinfer_mxfp4_moe_precision",
91
94
  "enable_flashinfer_allreduce_fusion",
92
95
  "moe_dense_tp_size",
@@ -99,15 +102,18 @@ GLOBAL_SERVER_ARGS_KEYS = [
99
102
  "sampling_backend",
100
103
  "speculative_accept_threshold_single",
101
104
  "speculative_accept_threshold_acc",
105
+ "speculative_attention_mode",
102
106
  "torchao_config",
103
107
  "triton_attention_reduce_in_fp32",
104
108
  "num_reserved_decode_tokens",
105
109
  "weight_loader_disable_mmap",
106
110
  "enable_multimodal",
107
111
  "enable_symm_mem",
108
- "quantization",
109
112
  "enable_custom_logit_processor",
110
113
  "disaggregation_mode",
114
+ "enable_deterministic_inference",
115
+ "nsa_prefill",
116
+ "nsa_decode",
111
117
  ]
112
118
 
113
119
  # Put some global args for easy access
@@ -408,6 +414,23 @@ class MultimodalInputs:
408
414
  # other args would be kept intact
409
415
 
410
416
 
417
+ class RequestStage(str, enum.Enum):
418
+ # prefill
419
+ PREFILL_WAITING = "prefill_waiting"
420
+
421
+ # disaggregation prefill
422
+ PREFILL_PREPARE = "prefill_prepare"
423
+ PREFILL_BOOTSTRAP = "prefill_bootstrap"
424
+ PREFILL_FORWARD = "prefill_forward"
425
+ PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
426
+
427
+ # disaggregation decode
428
+ DECODE_PREPARE = "decode_prepare"
429
+ DECODE_BOOTSTRAP = "decode_bootstrap"
430
+ DECODE_WAITING = "decode_waiting"
431
+ DECODE_TRANSFERRED = "decode_transferred"
432
+
433
+
411
434
  class Req:
412
435
  """The input and output status of a request."""
413
436
 
@@ -432,8 +455,12 @@ class Req:
432
455
  bootstrap_host: Optional[str] = None,
433
456
  bootstrap_port: Optional[int] = None,
434
457
  bootstrap_room: Optional[int] = None,
458
+ disagg_mode: Optional[DisaggregationMode] = None,
435
459
  data_parallel_rank: Optional[int] = None,
436
460
  vocab_size: Optional[int] = None,
461
+ priority: Optional[int] = None,
462
+ metrics_collector: Optional[SchedulerMetricsCollector] = None,
463
+ extra_key: Optional[str] = None,
437
464
  ):
438
465
  # Input and output info
439
466
  self.rid = rid
@@ -466,6 +493,14 @@ class Req:
466
493
  self.sampling_params = sampling_params
467
494
  self.custom_logit_processor = custom_logit_processor
468
495
  self.return_hidden_states = return_hidden_states
496
+
497
+ # extra key for classifying the request (e.g. cache_salt)
498
+ if lora_id is not None:
499
+ extra_key = (
500
+ extra_key or ""
501
+ ) + lora_id # lora_id is concatenated to the extra key
502
+
503
+ self.extra_key = extra_key
469
504
  self.lora_id = lora_id
470
505
 
471
506
  # Memory pool info
@@ -484,6 +519,7 @@ class Req:
484
519
  self.stream = stream
485
520
  self.eos_token_ids = eos_token_ids
486
521
  self.vocab_size = vocab_size
522
+ self.priority = priority
487
523
 
488
524
  # For incremental decoding
489
525
  # ----- | --------- read_ids -------|
@@ -513,6 +549,8 @@ class Req:
513
549
  self.host_hit_length = 0
514
550
  # The node to lock until for swa radix tree lock ref
515
551
  self.swa_uuid_for_lock: Optional[int] = None
552
+ # The prefix length of the last prefix matching
553
+ self.last_matched_prefix_len: int = 0
516
554
 
517
555
  # Whether or not if it is chunked. It increments whenever
518
556
  # it is chunked, and decrement whenever chunked request is
@@ -561,7 +599,10 @@ class Req:
561
599
  # shape: (bs, k)
562
600
  self.output_top_logprobs_val = []
563
601
  self.output_top_logprobs_idx = []
564
- self.output_token_ids_logprobs_val = []
602
+ # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
603
+ self.output_token_ids_logprobs_val: List[
604
+ Union[List[float], torch.Tensor]
605
+ ] = []
565
606
  self.output_token_ids_logprobs_idx = []
566
607
  else:
567
608
  self.output_token_logprobs_val = self.output_token_logprobs_idx = (
@@ -571,6 +612,8 @@ class Req:
571
612
  ) = None
572
613
  self.hidden_states: List[List[float]] = []
573
614
  self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
615
+ self.output_topk_p = None
616
+ self.output_topk_index = None
574
617
 
575
618
  # Embedding (return values)
576
619
  self.embedding = None
@@ -588,10 +631,10 @@ class Req:
588
631
  self.spec_verify_ct = 0
589
632
 
590
633
  # For metrics
591
- self.time_stats: TimeStats = TimeStats()
634
+ self.metrics_collector = metrics_collector
635
+ self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
592
636
  self.has_log_time_stats: bool = False
593
- self.queue_time_start = None
594
- self.queue_time_end = None
637
+ self.last_tic = time.monotonic()
595
638
 
596
639
  # For disaggregation
597
640
  self.bootstrap_host: str = bootstrap_host
@@ -619,6 +662,25 @@ class Req:
619
662
  def seqlen(self):
620
663
  return len(self.origin_input_ids) + len(self.output_ids)
621
664
 
665
+ @property
666
+ def is_prefill_only(self) -> bool:
667
+ """Check if this request is prefill-only (no token generation needed)."""
668
+ # NOTE: when spec is enabled, prefill_only optimizations are disabled
669
+ return (
670
+ self.sampling_params.max_new_tokens == 0
671
+ and global_server_args_dict["speculative_algorithm"] is None
672
+ )
673
+
674
+ def add_latency(self, stage: RequestStage):
675
+ if self.metrics_collector is None:
676
+ return
677
+
678
+ now = time.monotonic()
679
+ self.metrics_collector.observe_per_stage_req_latency(
680
+ stage.value, now - self.last_tic
681
+ )
682
+ self.last_tic = now
683
+
622
684
  def extend_image_inputs(self, image_inputs):
623
685
  if self.multimodal_inputs is None:
624
686
  self.multimodal_inputs = image_inputs
@@ -635,26 +697,17 @@ class Req:
635
697
  ):
636
698
  self.fill_ids = self.origin_input_ids + self.output_ids
637
699
  if tree_cache is not None:
638
- if isinstance(tree_cache, LoRARadixCache):
639
- (
640
- self.prefix_indices,
641
- self.last_node,
642
- self.last_host_node,
643
- self.host_hit_length,
644
- ) = tree_cache.match_prefix_with_lora_id(
645
- key=LoRAKey(
646
- lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
647
- ),
648
- )
649
- else:
650
- (
651
- self.prefix_indices,
652
- self.last_node,
653
- self.last_host_node,
654
- self.host_hit_length,
655
- ) = tree_cache.match_prefix(
656
- key=self.adjust_max_prefix_ids(),
657
- )
700
+ (
701
+ self.prefix_indices,
702
+ self.last_node,
703
+ self.last_host_node,
704
+ self.host_hit_length,
705
+ ) = tree_cache.match_prefix(
706
+ key=RadixKey(
707
+ token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
708
+ ),
709
+ )
710
+ self.last_matched_prefix_len = len(self.prefix_indices)
658
711
  self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
659
712
 
660
713
  def adjust_max_prefix_ids(self):
@@ -684,9 +737,15 @@ class Req:
684
737
  self.surr_offset = max(
685
738
  self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
686
739
  )
740
+ self.surr_and_decode_ids = (
741
+ self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
742
+ )
743
+ self.cur_decode_ids_len = len(self.output_ids)
744
+ else:
745
+ self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
746
+ self.cur_decode_ids_len = len(self.output_ids)
687
747
 
688
- all_ids = self.origin_input_ids_unpadded + self.output_ids
689
- return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
748
+ return self.surr_and_decode_ids, self.read_offset - self.surr_offset
690
749
 
691
750
  def check_finished(self):
692
751
  if self.finished():
@@ -781,10 +840,10 @@ class Req:
781
840
  return
782
841
 
783
842
  if self.bootstrap_room is not None:
784
- prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
843
+ prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
785
844
  else:
786
- prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
787
- logger.info(f"{prefix}: {self.time_stats}")
845
+ prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
846
+ logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
788
847
  self.has_log_time_stats = True
789
848
 
790
849
  def set_finish_with_abort(self, error_msg: str):
@@ -807,10 +866,6 @@ class Req:
807
866
  )
808
867
 
809
868
 
810
- # Batch id
811
- bid = 0
812
-
813
-
814
869
  @dataclasses.dataclass
815
870
  class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
816
871
  """Store all information of a batch on the scheduler."""
@@ -847,6 +902,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
847
902
  token_type_ids: torch.Tensor = None # shape: [b], int64
848
903
  req_pool_indices: torch.Tensor = None # shape: [b], int64
849
904
  seq_lens: torch.Tensor = None # shape: [b], int64
905
+ seq_lens_cpu: torch.Tensor = None # shape: [b], int64
850
906
  # The output locations of the KV cache
851
907
  out_cache_loc: torch.Tensor = None # shape: [b], int64
852
908
  output_ids: torch.Tensor = None # shape: [b], int64
@@ -902,7 +958,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
902
958
 
903
959
  # Speculative decoding
904
960
  spec_algorithm: SpeculativeAlgorithm = None
905
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
961
+ # spec_info: Optional[SpecInput] = None
962
+ spec_info: Optional[SpecInput] = None
906
963
 
907
964
  # Whether to return hidden states
908
965
  return_hidden_states: bool = False
@@ -911,7 +968,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
911
968
  is_prefill_only: bool = False
912
969
 
913
970
  # hicache pointer for synchronizing data loading from CPU to GPU
914
- hicache_consumer_index: int = 0
971
+ hicache_consumer_index: int = -1
915
972
 
916
973
  @classmethod
917
974
  def init_new(
@@ -950,9 +1007,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
950
1007
  device=req_to_token_pool.device,
951
1008
  spec_algorithm=spec_algorithm,
952
1009
  return_hidden_states=any(req.return_hidden_states for req in reqs),
953
- is_prefill_only=all(
954
- req.sampling_params.max_new_tokens == 0 for req in reqs
955
- ),
1010
+ is_prefill_only=all(req.is_prefill_only for req in reqs),
956
1011
  chunked_req=chunked_req,
957
1012
  )
958
1013
 
@@ -962,8 +1017,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
962
1017
  def is_empty(self):
963
1018
  return len(self.reqs) == 0
964
1019
 
965
- def alloc_req_slots(self, num_reqs: int):
966
- req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
1020
+ def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
1021
+ if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
1022
+ req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
1023
+ else:
1024
+ req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
967
1025
  if req_pool_indices is None:
968
1026
  raise RuntimeError(
969
1027
  "alloc_req_slots runs out of memory. "
@@ -1000,7 +1058,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1000
1058
  def alloc_paged_token_slots_extend(
1001
1059
  self,
1002
1060
  prefix_lens: torch.Tensor,
1061
+ prefix_lens_cpu: torch.Tensor,
1003
1062
  seq_lens: torch.Tensor,
1063
+ seq_lens_cpu: torch.Tensor,
1004
1064
  last_loc: torch.Tensor,
1005
1065
  extend_num_tokens: int,
1006
1066
  backup_state: bool = False,
@@ -1008,7 +1068,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1008
1068
  # Over estimate the number of tokens: assume each request needs a new page.
1009
1069
  num_tokens = (
1010
1070
  extend_num_tokens
1011
- + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
1071
+ + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
1012
1072
  )
1013
1073
  self._evict_tree_cache_if_needed(num_tokens)
1014
1074
 
@@ -1016,7 +1076,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1016
1076
  state = self.token_to_kv_pool_allocator.backup_state()
1017
1077
 
1018
1078
  out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
1019
- prefix_lens, seq_lens, last_loc, extend_num_tokens
1079
+ prefix_lens,
1080
+ prefix_lens_cpu,
1081
+ seq_lens,
1082
+ seq_lens_cpu,
1083
+ last_loc,
1084
+ extend_num_tokens,
1020
1085
  )
1021
1086
  if out_cache_loc is None:
1022
1087
  error_msg = (
@@ -1035,6 +1100,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1035
1100
  def alloc_paged_token_slots_decode(
1036
1101
  self,
1037
1102
  seq_lens: torch.Tensor,
1103
+ seq_lens_cpu: torch.Tensor,
1038
1104
  last_loc: torch.Tensor,
1039
1105
  backup_state: bool = False,
1040
1106
  ):
@@ -1045,7 +1111,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1045
1111
  if backup_state:
1046
1112
  state = self.token_to_kv_pool_allocator.backup_state()
1047
1113
 
1048
- out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
1114
+ out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
1115
+ seq_lens, seq_lens_cpu, last_loc
1116
+ )
1049
1117
  if out_cache_loc is None:
1050
1118
  error_msg = (
1051
1119
  f"Decode out of memory. Try to lower your batch size.\n"
@@ -1114,6 +1182,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1114
1182
  self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
1115
1183
  self.device, non_blocking=True
1116
1184
  )
1185
+ self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
1117
1186
 
1118
1187
  if not decoder_out_cache_loc:
1119
1188
  self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
@@ -1138,7 +1207,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1138
1207
 
1139
1208
  # Allocate req slots
1140
1209
  bs = len(self.reqs)
1141
- req_pool_indices = self.alloc_req_slots(bs)
1210
+ req_pool_indices = self.alloc_req_slots(bs, self.reqs)
1142
1211
 
1143
1212
  # Init tensors
1144
1213
  reqs = self.reqs
@@ -1162,12 +1231,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1162
1231
  seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
1163
1232
  self.device, non_blocking=True
1164
1233
  )
1234
+ seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
1165
1235
  orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
1166
1236
  self.device, non_blocking=True
1167
1237
  )
1168
1238
  prefix_lens_tensor = torch.tensor(
1169
1239
  prefix_lens, dtype=torch.int64, device=self.device
1170
1240
  )
1241
+ prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
1171
1242
 
1172
1243
  token_type_ids_tensor = None
1173
1244
  if len(token_type_ids) > 0:
@@ -1207,13 +1278,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1207
1278
  req.is_retracted = False
1208
1279
 
1209
1280
  # Compute the relative logprob_start_len in an extend batch
1281
+ #
1282
+ # Key variables:
1283
+ # - logprob_start_len: Absolute position in full sequence where logprob computation begins
1284
+ # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
1285
+ # - extend_input_len: Number of tokens that need to be processed in this extend batch
1286
+ # (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
1287
+ # and prefix_indices are the cached/shared prefix tokens)
1288
+ #
1210
1289
  if req.logprob_start_len >= pre_len:
1211
- req.extend_logprob_start_len = min(
1212
- req.logprob_start_len - pre_len,
1213
- req.extend_input_len,
1214
- req.seqlen - 1,
1215
- )
1290
+ # Optimization for prefill-only requests: When we only need logprobs at
1291
+ # positions beyond the input sequence (to score next-token likelihood), skip all
1292
+ # input logprob computation during prefill since no generation will occur.
1293
+ if self.is_prefill_only and req.logprob_start_len == len(
1294
+ req.origin_input_ids
1295
+ ):
1296
+ # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
1297
+ req.extend_logprob_start_len = req.extend_input_len
1298
+ else:
1299
+ # Convert absolute logprob_start_len to relative extend_logprob_start_len
1300
+ #
1301
+ # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
1302
+ # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
1303
+ # This means: "compute logprobs from position 3 onwards in extend batch"
1304
+ req.extend_logprob_start_len = min(
1305
+ req.logprob_start_len - pre_len,
1306
+ req.extend_input_len,
1307
+ req.seqlen - 1,
1308
+ )
1216
1309
  else:
1310
+ # logprob_start_len is before the current extend batch, so start from beginning
1217
1311
  req.extend_logprob_start_len = 0
1218
1312
 
1219
1313
  if self.return_logprob:
@@ -1271,13 +1365,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1271
1365
  prefix_lens_tensor,
1272
1366
  )
1273
1367
  out_cache_loc = self.alloc_paged_token_slots_extend(
1274
- prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
1368
+ prefix_lens_tensor,
1369
+ prefix_lens_cpu_tensor,
1370
+ seq_lens_tensor,
1371
+ seq_lens_cpu,
1372
+ last_loc,
1373
+ extend_num_tokens,
1275
1374
  )
1276
1375
 
1277
1376
  # Set fields
1278
1377
  self.input_ids = input_ids_tensor
1279
1378
  self.req_pool_indices = req_pool_indices_tensor
1280
1379
  self.seq_lens = seq_lens_tensor
1380
+ self.seq_lens_cpu = seq_lens_cpu
1281
1381
  self.orig_seq_lens = orig_seq_lens_tensor
1282
1382
  self.out_cache_loc = out_cache_loc
1283
1383
  self.input_embeds = (
@@ -1372,21 +1472,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1372
1472
  # TODO (lianmin): Revisit this. It should be seq_len - 1
1373
1473
  self.extend_logprob_start_lens.extend([0] * running_bs)
1374
1474
 
1375
- def new_page_count_next_decode(self):
1475
+ def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
1376
1476
  page_size = self.token_to_kv_pool_allocator.page_size
1477
+ requests = (
1478
+ self.reqs
1479
+ if selected_indices is None
1480
+ else [self.reqs[i] for i in selected_indices]
1481
+ )
1377
1482
  if page_size == 1:
1378
- return len(self.reqs)
1483
+ return len(requests)
1379
1484
  # In the decoding phase, the length of a request's KV cache should be
1380
1485
  # the total length of the request minus 1
1381
1486
  return (
1382
- sum(1 for req in self.reqs if req.seqlen % page_size == 0)
1487
+ sum(1 for req in requests if req.seqlen % page_size == 0)
1383
1488
  if self.enable_overlap
1384
- else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
1489
+ else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
1385
1490
  )
1386
1491
 
1387
- def check_decode_mem(self, buf_multiplier=1):
1492
+ def check_decode_mem(
1493
+ self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
1494
+ ):
1388
1495
  num_tokens = (
1389
- self.new_page_count_next_decode()
1496
+ self.new_page_count_next_decode(selected_indices)
1390
1497
  * buf_multiplier
1391
1498
  * self.token_to_kv_pool_allocator.page_size
1392
1499
  )
@@ -1412,34 +1519,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1412
1519
  reverse=True,
1413
1520
  )
1414
1521
 
1415
- def get_required_tokens(num_reqs: int):
1416
- headroom_for_spec_decode = 0
1417
- if server_args.speculative_algorithm:
1418
- headroom_for_spec_decode += (
1419
- num_reqs
1420
- * server_args.speculative_eagle_topk
1421
- * server_args.speculative_num_steps
1422
- + num_reqs * server_args.speculative_num_draft_tokens
1423
- )
1424
- return (
1425
- num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
1426
- )
1427
-
1428
- def _get_available_size():
1429
- if self.is_hybrid:
1430
- return min(
1431
- self.token_to_kv_pool_allocator.full_available_size(),
1432
- self.token_to_kv_pool_allocator.swa_available_size(),
1433
- )
1434
- else:
1435
- return self.token_to_kv_pool_allocator.available_size()
1436
-
1437
1522
  retracted_reqs = []
1438
- seq_lens_cpu = self.seq_lens.cpu().numpy()
1439
1523
  first_iter = True
1440
- while (
1441
- _get_available_size() < get_required_tokens(len(sorted_indices))
1442
- or first_iter
1524
+ while first_iter or (
1525
+ not self.check_decode_mem(selected_indices=sorted_indices)
1443
1526
  ):
1444
1527
  if len(sorted_indices) == 1:
1445
1528
  # Corner case: only one request left
@@ -1463,41 +1546,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1463
1546
  idx = sorted_indices.pop()
1464
1547
  req = self.reqs[idx]
1465
1548
  retracted_reqs.append(req)
1466
-
1467
- if server_args.disaggregation_mode == "decode":
1468
- req.offload_kv_cache(
1469
- self.req_to_token_pool, self.token_to_kv_pool_allocator
1470
- )
1471
-
1472
- if isinstance(self.tree_cache, ChunkCache):
1473
- # ChunkCache does not have eviction
1474
- token_indices = self.req_to_token_pool.req_to_token[
1475
- req.req_pool_idx, : seq_lens_cpu[idx]
1476
- ]
1477
- self.token_to_kv_pool_allocator.free(token_indices)
1478
- self.req_to_token_pool.free(req.req_pool_idx)
1479
- else:
1480
- # TODO: apply more fine-grained retraction
1481
- last_uncached_pos = (
1482
- len(req.prefix_indices) // server_args.page_size
1483
- ) * server_args.page_size
1484
- token_indices = self.req_to_token_pool.req_to_token[
1485
- req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
1486
- ]
1487
- self.token_to_kv_pool_allocator.free(token_indices)
1488
- self.req_to_token_pool.free(req.req_pool_idx)
1489
-
1490
- # release the last node
1491
- if self.is_hybrid:
1492
- self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
1493
- else:
1494
- self.tree_cache.dec_lock_ref(req.last_node)
1495
-
1496
- # NOTE(lsyin): we should use the newly evictable memory instantly.
1497
- num_tokens = len(sorted_indices) * global_config.retract_decode_steps
1498
- self._evict_tree_cache_if_needed(num_tokens)
1499
-
1500
- req.reset_for_retract()
1549
+ self.release_req(idx, len(sorted_indices), server_args)
1501
1550
 
1502
1551
  if len(retracted_reqs) == 0:
1503
1552
  # Corner case: only one request left
@@ -1516,7 +1565,45 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1516
1565
  ) / total_max_new_tokens
1517
1566
  new_estimate_ratio = min(1.0, new_estimate_ratio)
1518
1567
 
1519
- return retracted_reqs, new_estimate_ratio
1568
+ return retracted_reqs, new_estimate_ratio, []
1569
+
1570
+ def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
1571
+ req = self.reqs[idx]
1572
+ seq_lens_cpu = self.seq_lens_cpu.numpy()
1573
+
1574
+ if server_args.disaggregation_mode == "decode":
1575
+ req.offload_kv_cache(
1576
+ self.req_to_token_pool, self.token_to_kv_pool_allocator
1577
+ )
1578
+ if isinstance(self.tree_cache, ChunkCache):
1579
+ # ChunkCache does not have eviction
1580
+ token_indices = self.req_to_token_pool.req_to_token[
1581
+ req.req_pool_idx, : seq_lens_cpu[idx]
1582
+ ]
1583
+ self.token_to_kv_pool_allocator.free(token_indices)
1584
+ self.req_to_token_pool.free(req.req_pool_idx)
1585
+ else:
1586
+ # TODO: apply more fine-grained retraction
1587
+ last_uncached_pos = (
1588
+ len(req.prefix_indices) // server_args.page_size
1589
+ ) * server_args.page_size
1590
+ token_indices = self.req_to_token_pool.req_to_token[
1591
+ req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
1592
+ ]
1593
+ self.token_to_kv_pool_allocator.free(token_indices)
1594
+ self.req_to_token_pool.free(req.req_pool_idx)
1595
+
1596
+ # release the last node
1597
+ if self.is_hybrid:
1598
+ self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
1599
+ else:
1600
+ self.tree_cache.dec_lock_ref(req.last_node)
1601
+
1602
+ # NOTE(lsyin): we should use the newly evictable memory instantly.
1603
+ num_tokens = remaing_req_count * global_config.retract_decode_steps
1604
+ self._evict_tree_cache_if_needed(num_tokens)
1605
+
1606
+ req.reset_for_retract()
1520
1607
 
1521
1608
  def prepare_encoder_info_decode(self):
1522
1609
  # Reset the encoder cached status
@@ -1526,6 +1613,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1526
1613
  self.forward_mode = ForwardMode.IDLE
1527
1614
  self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
1528
1615
  self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
1616
+ self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
1529
1617
  self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
1530
1618
  self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
1531
1619
  self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
@@ -1540,7 +1628,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1540
1628
  self.forward_mode = ForwardMode.DECODE
1541
1629
  bs = len(self.reqs)
1542
1630
 
1543
- if self.spec_algorithm.is_eagle():
1631
+ if (
1632
+ self.spec_algorithm.is_eagle()
1633
+ or self.spec_algorithm.is_standalone()
1634
+ or self.spec_algorithm.is_ngram()
1635
+ ):
1544
1636
  # if spec decoding is used, the decode batch is prepared inside
1545
1637
  # `forward_batch_speculative_generation` after running draft models.
1546
1638
  return
@@ -1581,10 +1673,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1581
1673
  if self.enable_overlap:
1582
1674
  # Do not use in-place operations in the overlap mode
1583
1675
  self.seq_lens = self.seq_lens + 1
1676
+ self.seq_lens_cpu = self.seq_lens_cpu + 1
1584
1677
  self.orig_seq_lens = self.orig_seq_lens + 1
1585
1678
  else:
1586
1679
  # A faster in-place version
1587
1680
  self.seq_lens.add_(1)
1681
+ self.seq_lens_cpu.add_(1)
1588
1682
  self.orig_seq_lens.add_(1)
1589
1683
  self.seq_lens_sum += bs
1590
1684
 
@@ -1603,7 +1697,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1603
1697
  self.req_pool_indices, self.seq_lens - 2
1604
1698
  ]
1605
1699
  self.out_cache_loc = self.alloc_paged_token_slots_decode(
1606
- self.seq_lens, last_loc
1700
+ self.seq_lens, self.seq_lens_cpu, last_loc
1607
1701
  )
1608
1702
 
1609
1703
  self.req_to_token_pool.write(
@@ -1649,6 +1743,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1649
1743
  self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
1650
1744
  self.req_pool_indices = self.req_pool_indices[keep_indices_device]
1651
1745
  self.seq_lens = self.seq_lens[keep_indices_device]
1746
+ self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
1652
1747
  self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
1653
1748
  self.out_cache_loc = None
1654
1749
  self.seq_lens_sum = self.seq_lens.sum().item()
@@ -1666,7 +1761,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1666
1761
 
1667
1762
  self.sampling_info.filter_batch(keep_indices, keep_indices_device)
1668
1763
  if self.spec_info:
1669
- self.spec_info.filter_batch(keep_indices_device)
1764
+ if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
1765
+ has_been_filtered = False
1766
+ else:
1767
+ has_been_filtered = True
1768
+ self.spec_info.filter_batch(
1769
+ new_indices=keep_indices_device,
1770
+ has_been_filtered=has_been_filtered,
1771
+ )
1670
1772
 
1671
1773
  def merge_batch(self, other: "ScheduleBatch"):
1672
1774
  # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1682,6 +1784,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1682
1784
  [self.req_pool_indices, other.req_pool_indices]
1683
1785
  )
1684
1786
  self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
1787
+ self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
1685
1788
  self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
1686
1789
  self.out_cache_loc = None
1687
1790
  self.seq_lens_sum += other.seq_lens_sum
@@ -1725,15 +1828,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1725
1828
  self.sampling_info.grammars = None
1726
1829
 
1727
1830
  seq_lens_cpu = (
1728
- seq_lens_cpu_cache
1729
- if seq_lens_cpu_cache is not None
1730
- else self.seq_lens.cpu()
1831
+ seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
1731
1832
  )
1732
1833
 
1733
- global bid
1734
- bid += 1
1735
1834
  return ModelWorkerBatch(
1736
- bid=bid,
1737
1835
  forward_mode=self.forward_mode,
1738
1836
  input_ids=self.input_ids,
1739
1837
  req_pool_indices=self.req_pool_indices,
@@ -1780,6 +1878,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1780
1878
  ),
1781
1879
  extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
1782
1880
  launch_done=self.launch_done,
1881
+ is_prefill_only=self.is_prefill_only,
1783
1882
  )
1784
1883
 
1785
1884
  def copy(self):
@@ -1852,8 +1951,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1852
1951
 
1853
1952
  @dataclasses.dataclass
1854
1953
  class ModelWorkerBatch:
1855
- # The batch id
1856
- bid: int
1857
1954
  # The forward mode
1858
1955
  forward_mode: ForwardMode
1859
1956
  # The input ids
@@ -1914,14 +2011,19 @@ class ModelWorkerBatch:
1914
2011
 
1915
2012
  # Speculative decoding
1916
2013
  spec_algorithm: SpeculativeAlgorithm = None
1917
- spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
2014
+
2015
+ spec_info: Optional[SpecInput] = None
2016
+
1918
2017
  # If set, the output of the batch contains the hidden states of the run.
1919
2018
  capture_hidden_mode: CaptureHiddenMode = None
1920
- hicache_consumer_index: int = 0
2019
+ hicache_consumer_index: int = -1
1921
2020
 
1922
2021
  # Overlap event
1923
2022
  launch_done: Optional[threading.Event] = None
1924
2023
 
2024
+ # Whether this batch is prefill-only (no token generation needed)
2025
+ is_prefill_only: bool = False
2026
+
1925
2027
 
1926
2028
  @triton.jit
1927
2029
  def write_req_to_token_pool_triton(