sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py
@@ -36,7 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
  import copy
  import dataclasses
  import logging
- import threading
+ import re
  import time
  from enum import Enum, auto
  from http import HTTPStatus
@@ -45,10 +45,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

  import numpy as np
  import torch
- import triton
- import triton.language as tl

- from sglang.global_config import global_config
  from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
  from sglang.srt.disaggregation.base import BaseKVSender
  from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
@@ -56,68 +53,36 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
  )
  from sglang.srt.disaggregation.utils import DisaggregationMode
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+ from sglang.srt.environ import envs
  from sglang.srt.mem_cache.allocator import (
  BaseTokenToKVPoolAllocator,
  SWATokenToKVPoolAllocator,
  )
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
- from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
- from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
+ from sglang.srt.mem_cache.chunk_cache import SWAChunkCache
+ from sglang.srt.mem_cache.common import (
+ alloc_for_decode,
+ alloc_for_extend,
+ evict_from_tree_cache,
+ )
+ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
  from sglang.srt.mem_cache.radix_cache import RadixKey
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.sampling.sampling_params import SamplingParams
- from sglang.srt.server_args import ServerArgs
- from sglang.srt.utils import flatten_nested_list, support_triton
+ from sglang.srt.server_args import ServerArgs, get_global_server_args
+ from sglang.srt.utils import flatten_nested_list

  if TYPE_CHECKING:
  from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
  from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

- GLOBAL_SERVER_ARGS_KEYS = [
- "attention_backend",
- "mm_attention_backend",
- "debug_tensor_dump_inject",
- "debug_tensor_dump_output_folder",
- "chunked_prefill_size",
- "device",
- "disable_chunked_prefix_cache",
- "disable_flashinfer_cutlass_moe_fp4_allgather",
- "disable_radix_cache",
- "enable_dp_lm_head",
- "enable_fp32_lm_head",
- "flashinfer_mxfp4_moe_precision",
- "enable_flashinfer_allreduce_fusion",
- "moe_dense_tp_size",
- "ep_dispatch_algorithm",
- "ep_num_redundant_experts",
- "enable_nan_detection",
- "flashinfer_mla_disable_ragged",
- "max_micro_batch_size",
- "disable_shared_experts_fusion",
- "sampling_backend",
- "speculative_accept_threshold_single",
- "speculative_accept_threshold_acc",
- "speculative_attention_mode",
- "torchao_config",
- "triton_attention_reduce_in_fp32",
- "num_reserved_decode_tokens",
- "weight_loader_disable_mmap",
- "enable_multimodal",
- "enable_symm_mem",
- "enable_custom_logit_processor",
- "disaggregation_mode",
- "enable_deterministic_inference",
- "nsa_prefill",
- "nsa_decode",
- ]
-
- # Put some global args for easy access
- global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}

  logger = logging.getLogger(__name__)

@@ -154,6 +119,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
  }


+ class FINISHED_MATCHED_REGEX(BaseFinishReason):
+ def __init__(self, matched: str):
+ super().__init__()
+ self.matched = matched
+
+ def to_json(self):
+ return {
+ "type": "stop", # to match OpenAI API's return value
+ "matched": self.matched,
+ }
+
+
  class FINISH_LENGTH(BaseFinishReason):
  def __init__(self, length: int):
  super().__init__()
@@ -461,6 +438,7 @@ class Req:
  priority: Optional[int] = None,
  metrics_collector: Optional[SchedulerMetricsCollector] = None,
  extra_key: Optional[str] = None,
+ http_worker_ipc: Optional[str] = None,
  ):
  # Input and output info
  self.rid = rid
@@ -484,6 +462,9 @@ class Req:
  # The length of KV that have been removed in local attention chunked prefill
  self.evicted_seqlen_local = 0

+ # For multi-http worker
+ self.http_worker_ipc = http_worker_ipc
+
  # Sampling info
  if isinstance(sampling_params.custom_params, dict):
  sampling_params = copy.copy(sampling_params)
@@ -505,10 +486,13 @@ class Req:

  # Memory pool info
  self.req_pool_idx: Optional[int] = None
+ self.mamba_pool_idx: Optional[torch.Tensor] = None # shape (1)

  # Check finish
  self.tokenizer = None
  self.finished_reason = None
+ # finished position (in output_ids), used when checking stop conditions with speculative decoding
+ self.finished_len = None
  # Whether this request has finished output
  self.finished_output = None
  # If we want to abort the request in the middle of the event loop, set this to true
@@ -539,7 +523,7 @@ class Req:

  # Prefix info
  # The indices to kv cache for the shared prefix.
- self.prefix_indices: torch.Tensor = []
+ self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
  # Number of tokens to run prefill.
  self.extend_input_len = 0
  # The relative logprob_start_len in an extend batch
@@ -630,6 +614,10 @@ class Req:
  # This is used to compute the average acceptance length per request.
  self.spec_verify_ct = 0

+ # The number of accepted tokens in speculative decoding for this request.
+ # This is used to compute the acceptance rate and average acceptance length per request.
+ self.spec_accepted_tokens = 0
+
  # For metrics
  self.metrics_collector = metrics_collector
  self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
@@ -666,10 +654,16 @@ class Req:
  def is_prefill_only(self) -> bool:
  """Check if this request is prefill-only (no token generation needed)."""
  # NOTE: when spec is enabled, prefill_only optimizations are disabled
- return (
- self.sampling_params.max_new_tokens == 0
- and global_server_args_dict["speculative_algorithm"] is None
- )
+
+ spec_alg = get_global_server_args().speculative_algorithm
+ return self.sampling_params.max_new_tokens == 0 and spec_alg is None
+
+ @property
+ def output_ids_through_stop(self) -> List[int]:
+ """Get the output ids through the stop condition. Stop position is included."""
+ if self.finished_len is not None:
+ return self.output_ids[: self.finished_len]
+ return self.output_ids

  def add_latency(self, stage: RequestStage):
  if self.metrics_collector is None:
@@ -691,11 +685,16 @@ class Req:
  # Whether request reached finished condition
  return self.finished_reason is not None

- def init_next_round_input(
- self,
- tree_cache: Optional[BasePrefixCache] = None,
- ):
+ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
  self.fill_ids = self.origin_input_ids + self.output_ids
+ input_len = len(self.fill_ids)
+ # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
+ max_prefix_len = input_len - 1
+ if self.return_logprob:
+ max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+ max_prefix_len = max(max_prefix_len, 0)
+ token_ids = self.fill_ids[:max_prefix_len]
+
  if tree_cache is not None:
  (
  self.prefix_indices,
@@ -703,51 +702,146 @@ class Req:
  self.last_host_node,
  self.host_hit_length,
  ) = tree_cache.match_prefix(
- key=RadixKey(
- token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
+ key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
+ **(
+ {"req": self, "cow_mamba": True}
+ if isinstance(tree_cache, MambaRadixCache)
+ else {}
  ),
  )
  self.last_matched_prefix_len = len(self.prefix_indices)
  self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

- def adjust_max_prefix_ids(self):
- self.fill_ids = self.origin_input_ids + self.output_ids
- input_len = len(self.fill_ids)
-
- # FIXME: To work around some bugs in logprob computation, we need to ensure each
- # request has at least one token. Later, we can relax this requirement and use `input_len`.
- max_prefix_len = input_len - 1
-
- if self.sampling_params.max_new_tokens > 0:
- # Need at least one token to compute logits
- max_prefix_len = min(max_prefix_len, input_len - 1)
-
- if self.return_logprob:
- max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
- max_prefix_len = max(max_prefix_len, 0)
- return self.fill_ids[:max_prefix_len]
-
  # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
  def init_incremental_detokenize(self):
  first_iter = self.surr_offset is None or self.read_offset is None

+ output_ids = self.output_ids_through_stop
+
  if first_iter:
  self.read_offset = len(self.origin_input_ids_unpadded)
  self.surr_offset = max(
  self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
  )
  self.surr_and_decode_ids = (
- self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
+ self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
  )
- self.cur_decode_ids_len = len(self.output_ids)
+ self.cur_decode_ids_len = len(output_ids)
  else:
- self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
- self.cur_decode_ids_len = len(self.output_ids)
+ self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
+ self.cur_decode_ids_len = len(output_ids)

  return self.surr_and_decode_ids, self.read_offset - self.surr_offset

- def check_finished(self):
+ def tail_str(self) -> str:
+ # Check stop strings and stop regex patterns together
+ if (
+ len(self.sampling_params.stop_strs) > 0
+ or len(self.sampling_params.stop_regex_strs) > 0
+ ):
+ max_len_tail_str = max(
+ self.sampling_params.stop_str_max_len + 1,
+ self.sampling_params.stop_regex_max_len + 1,
+ )
+
+ tail_len = min((max_len_tail_str + 1), len(self.output_ids))
+ return self.tokenizer.decode(self.output_ids[-tail_len:])
+
+ def check_match_stop_str_prefix(self) -> bool:
+ """
+ Check if the suffix of tail_str overlaps with any stop_str prefix
+ """
+ if not self.sampling_params.stop_strs:
+ return False
+
+ tail_str = self.tail_str()
+
+ # Early return if tail_str is empty
+ if not tail_str:
+ return False
+
+ for stop_str in self.sampling_params.stop_strs:
+ if not stop_str:
+ continue
+ # Check if stop_str is contained in tail_str (fastest check first)
+ if stop_str in tail_str:
+ return True
+
+ # Check if tail_str suffix matches stop_str prefix
+ # Only check if stop_str is not empty, it's for stream output
+ min_len = min(len(tail_str), len(stop_str))
+ for i in range(1, min_len + 1):
+ if tail_str[-i:] == stop_str[:i]:
+ return True
+
+ return False
+
+ def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
+ if self.sampling_params.ignore_eos:
+ return False
+
+ # Check stop token ids
+ matched_eos = False
+
+ for i, token_id in enumerate(new_accepted_tokens):
+ if self.sampling_params.stop_token_ids:
+ matched_eos |= token_id in self.sampling_params.stop_token_ids
+ if self.eos_token_ids:
+ matched_eos |= token_id in self.eos_token_ids
+ if self.tokenizer is not None:
+ matched_eos |= token_id == self.tokenizer.eos_token_id
+ if self.tokenizer.additional_stop_token_ids:
+ matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
+ if matched_eos:
+ self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
+ matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
+ self.finished_len = matched_pos + 1
+ return True
+
+ return False
+
+ def _check_str_based_finish(self):
+ if (
+ len(self.sampling_params.stop_strs) > 0
+ or len(self.sampling_params.stop_regex_strs) > 0
+ ):
+ tail_str = self.tail_str()
+
+ # Check stop strings
+ if len(self.sampling_params.stop_strs) > 0:
+ for stop_str in self.sampling_params.stop_strs:
+ if stop_str in tail_str or stop_str in self.decoded_text:
+ self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
+ return True
+
+ # Check stop regex
+ if len(self.sampling_params.stop_regex_strs) > 0:
+ for stop_regex_str in self.sampling_params.stop_regex_strs:
+ if re.search(stop_regex_str, tail_str):
+ self.finished_reason = FINISHED_MATCHED_REGEX(
+ matched=stop_regex_str
+ )
+ return True
+
+ return False
+
+ def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
+ for i, token_id in enumerate(new_accepted_tokens):
+ if token_id > self.vocab_size or token_id < 0:
+ offset = len(self.output_ids) - len(new_accepted_tokens) + i
+ if self.sampling_params.stop_token_ids:
+ self.output_ids[offset] = next(
+ iter(self.sampling_params.stop_token_ids)
+ )
+ if self.eos_token_ids:
+ self.output_ids[offset] = next(iter(self.eos_token_ids))
+ self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+ self.finished_len = offset + 1
+ return True
+
+ return False
+
+ def check_finished(self, new_accepted_len: int = 1):
  if self.finished():
  return

@@ -761,6 +855,7 @@ class Req:
  self.finished_reason = FINISH_LENGTH(
  length=self.sampling_params.max_new_tokens
  )
+ self.finished_len = self.sampling_params.max_new_tokens
  return

  if self.grammar is not None:
@@ -768,47 +863,19 @@ class Req:
  self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
  return

- last_token_id = self.output_ids[-1]
+ new_accepted_tokens = self.output_ids[-new_accepted_len:]

- if not self.sampling_params.ignore_eos:
- matched_eos = False
-
- # Check stop token ids
- if self.sampling_params.stop_token_ids:
- matched_eos = last_token_id in self.sampling_params.stop_token_ids
- if self.eos_token_ids:
- matched_eos |= last_token_id in self.eos_token_ids
- if self.tokenizer is not None:
- matched_eos |= last_token_id == self.tokenizer.eos_token_id
- if self.tokenizer.additional_stop_token_ids:
- matched_eos |= (
- last_token_id in self.tokenizer.additional_stop_token_ids
- )
- if matched_eos:
- self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
- return
-
- if last_token_id > self.vocab_size or last_token_id < 0:
- if self.sampling_params.stop_token_ids:
- self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
- if self.eos_token_ids:
- self.output_ids[-1] = next(iter(self.eos_token_ids))
- self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+ if self._check_token_based_finish(new_accepted_tokens):
  return

- # Check stop strings
- if len(self.sampling_params.stop_strs) > 0:
- tail_str = self.tokenizer.decode(
- self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
- )
+ if self._check_vocab_boundary_finish(new_accepted_tokens):
+ return

- for stop_str in self.sampling_params.stop_strs:
- if stop_str in tail_str or stop_str in self.decoded_text:
- self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
- return
+ if self._check_str_based_finish():
+ return

  def reset_for_retract(self):
- self.prefix_indices = []
+ self.prefix_indices = torch.empty((0,), dtype=torch.int64)
  self.last_node = None
  self.swa_uuid_for_lock = None
  self.extend_input_len = 0
@@ -818,7 +885,7 @@ class Req:
  self.temp_input_top_logprobs_idx = None
  self.extend_logprob_start_len = 0
  self.is_chunked = 0
- self.req_pool_idx = None
+ self.mamba_pool_idx = None
  self.already_computed = 0

  def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
@@ -886,15 +953,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  # This is an optimization to reduce the overhead of the prefill check.
  batch_is_full: bool = False

- # Events
- launch_done: Optional[threading.Event] = None
-
  # For chunked prefill in PP
  chunked_req: Optional[Req] = None

  # Sampling info
  sampling_info: SamplingBatchInfo = None
- next_batch_sampling_info: SamplingBatchInfo = None

  # Batched arguments to model runner
  input_ids: torch.Tensor = None # shape: [b], int64
@@ -1017,117 +1080,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  def is_empty(self):
  return len(self.reqs) == 0

- def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
- if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
- req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
- else:
- req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
- if req_pool_indices is None:
- raise RuntimeError(
- "alloc_req_slots runs out of memory. "
- "Please set a smaller number for `--max-running-requests`. "
- f"{self.req_to_token_pool.available_size()=}, "
- f"{num_reqs=}, "
- )
- return req_pool_indices
-
- def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
- self._evict_tree_cache_if_needed(num_tokens)
-
- if backup_state:
- state = self.token_to_kv_pool_allocator.backup_state()
-
- out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
- if out_cache_loc is None:
- phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
- error_msg = (
- f"{phase_str} out of memory. Try to lower your batch size.\n"
- f"Try to allocate {num_tokens} tokens.\n"
- f"{self._available_and_evictable_str()}"
- )
- logger.error(error_msg)
- if self.tree_cache is not None:
- self.tree_cache.pretty_print()
- raise RuntimeError(error_msg)
-
- if backup_state:
- return out_cache_loc, state
- else:
- return out_cache_loc
-
- def alloc_paged_token_slots_extend(
- self,
- prefix_lens: torch.Tensor,
- prefix_lens_cpu: torch.Tensor,
- seq_lens: torch.Tensor,
- seq_lens_cpu: torch.Tensor,
- last_loc: torch.Tensor,
- extend_num_tokens: int,
- backup_state: bool = False,
- ):
- # Over estimate the number of tokens: assume each request needs a new page.
- num_tokens = (
- extend_num_tokens
- + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
- )
- self._evict_tree_cache_if_needed(num_tokens)
-
- if backup_state:
- state = self.token_to_kv_pool_allocator.backup_state()
-
- out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
- prefix_lens,
- prefix_lens_cpu,
- seq_lens,
- seq_lens_cpu,
- last_loc,
- extend_num_tokens,
- )
- if out_cache_loc is None:
- error_msg = (
- f"Prefill out of memory. Try to lower your batch size.\n"
- f"Try to allocate {extend_num_tokens} tokens.\n"
- f"{self._available_and_evictable_str()}"
- )
- logger.error(error_msg)
- raise RuntimeError(error_msg)
-
- if backup_state:
- return out_cache_loc, state
- else:
- return out_cache_loc
-
- def alloc_paged_token_slots_decode(
- self,
- seq_lens: torch.Tensor,
- seq_lens_cpu: torch.Tensor,
- last_loc: torch.Tensor,
- backup_state: bool = False,
- ):
- # Over estimate the number of tokens: assume each request needs a new page.
- num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
- self._evict_tree_cache_if_needed(num_tokens)
-
- if backup_state:
- state = self.token_to_kv_pool_allocator.backup_state()
-
- out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
- seq_lens, seq_lens_cpu, last_loc
- )
- if out_cache_loc is None:
- error_msg = (
- f"Decode out of memory. Try to lower your batch size.\n"
- f"Try to allocate {len(seq_lens)} tokens.\n"
- f"{self._available_and_evictable_str()}"
- )
- logger.error(error_msg)
- raise RuntimeError(error_msg)
-
- if backup_state:
- return out_cache_loc, state
- else:
- return out_cache_loc
-
  def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
  self.encoder_lens_cpu = []
  self.encoder_cached = []
@@ -1205,10 +1157,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  def prepare_for_extend(self):
  self.forward_mode = ForwardMode.EXTEND

- # Allocate req slots
- bs = len(self.reqs)
- req_pool_indices = self.alloc_req_slots(bs, self.reqs)
-
  # Init tensors
  reqs = self.reqs
  input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -1222,9 +1170,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  r.token_type_ids for r in reqs if r.token_type_ids is not None
  ]

- req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
- self.device, non_blocking=True
- )
  input_ids_tensor = torch.tensor(
  list(chain.from_iterable(input_ids)), dtype=torch.int64
  ).to(self.device, non_blocking=True)
@@ -1235,10 +1180,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
  self.device, non_blocking=True
  )
- prefix_lens_tensor = torch.tensor(
- prefix_lens, dtype=torch.int64, device=self.device
- )
- prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)

  token_type_ids_tensor = None
  if len(token_type_ids) > 0:
@@ -1246,9 +1187,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  sum(token_type_ids, []), dtype=torch.int64
  ).to(self.device, non_blocking=True)

- extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
+ # Set batch fields needed by alloc_for_extend
+ self.prefix_lens = prefix_lens
+ self.extend_lens = extend_lens
+ self.seq_lens = seq_lens_tensor
+ self.seq_lens_cpu = seq_lens_cpu
+ self.extend_num_tokens = extend_num_tokens
+
+ # Allocate memory
+ out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
+ self
+ )

- # Copy prefix and do some basic check
+ # Set fields
  input_embeds = []
  extend_input_logprob_token_ids = []
  multimodal_inputs = []
@@ -1257,15 +1208,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  req.req_pool_idx = req_pool_indices[i]
  assert seq_len - pre_len == req.extend_input_len

- if pre_len > 0:
- self.req_to_token_pool.write(
- (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
- )
- if isinstance(self.tree_cache, SWAChunkCache):
- self.tree_cache.evict_swa(
- req, pre_len, self.model_config.attention_chunk_size
- )
-
  # If input_embeds are available, store them
  if req.input_embeds is not None:
  # If req.input_embeds is already a list, append its content directly
@@ -1355,29 +1297,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  else:
  extend_input_logprob_token_ids = None

- # Allocate memory
- if self.token_to_kv_pool_allocator.page_size == 1:
- out_cache_loc = self.alloc_token_slots(extend_num_tokens)
- else:
- last_loc = get_last_loc(
- self.req_to_token_pool.req_to_token,
- req_pool_indices_tensor,
- prefix_lens_tensor,
- )
- out_cache_loc = self.alloc_paged_token_slots_extend(
- prefix_lens_tensor,
- prefix_lens_cpu_tensor,
- seq_lens_tensor,
- seq_lens_cpu,
- last_loc,
- extend_num_tokens,
- )
-
- # Set fields
  self.input_ids = input_ids_tensor
  self.req_pool_indices = req_pool_indices_tensor
- self.seq_lens = seq_lens_tensor
- self.seq_lens_cpu = seq_lens_cpu
  self.orig_seq_lens = orig_seq_lens_tensor
  self.out_cache_loc = out_cache_loc
  self.input_embeds = (
@@ -1401,33 +1322,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

  self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
- self.extend_num_tokens = extend_num_tokens
- self.prefix_lens = prefix_lens
- self.extend_lens = extend_lens
  self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

- # Write to req_to_token_pool
- if support_triton(global_server_args_dict.get("attention_backend")):
- # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-
- write_req_to_token_pool_triton[(bs,)](
- self.req_to_token_pool.req_to_token,
- req_pool_indices_tensor,
- prefix_lens_tensor,
- seq_lens_tensor,
- extend_lens_tensor,
- out_cache_loc,
- self.req_to_token_pool.req_to_token.shape[1],
- )
- else:
- pt = 0
- for i in range(bs):
- self.req_to_token_pool.write(
- (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
- out_cache_loc[pt : pt + extend_lens[i]],
- )
- pt += extend_lens[i]
-
  if self.model_config.is_encoder_decoder:
  self.prepare_encoder_info_extend(input_ids, seq_lens)

@@ -1498,7 +1394,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  * self.token_to_kv_pool_allocator.page_size
  )

- self._evict_tree_cache_if_needed(num_tokens)
+ evict_from_tree_cache(self.tree_cache, num_tokens)
  return self._is_available_size_sufficient(num_tokens)

  def retract_decode(self, server_args: ServerArgs):
@@ -1546,6 +1442,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  idx = sorted_indices.pop()
  req = self.reqs[idx]
  retracted_reqs.append(req)
+ # release memory and don't insert into the tree because we need the space instantly
  self.release_req(idx, len(sorted_indices), server_args)

  if len(retracted_reqs) == 0:
@@ -1561,47 +1458,27 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

  new_estimate_ratio = (
- total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
- ) / total_max_new_tokens
+ total_decoded_tokens
+ + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
+ ) / (
+ total_max_new_tokens + 1
+ ) # avoid zero division
  new_estimate_ratio = min(1.0, new_estimate_ratio)

  return retracted_reqs, new_estimate_ratio, []

  def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
  req = self.reqs[idx]
- seq_lens_cpu = self.seq_lens_cpu.numpy()

  if server_args.disaggregation_mode == "decode":
  req.offload_kv_cache(
  self.req_to_token_pool, self.token_to_kv_pool_allocator
  )
- if isinstance(self.tree_cache, ChunkCache):
- # ChunkCache does not have eviction
- token_indices = self.req_to_token_pool.req_to_token[
- req.req_pool_idx, : seq_lens_cpu[idx]
- ]
- self.token_to_kv_pool_allocator.free(token_indices)
- self.req_to_token_pool.free(req.req_pool_idx)
- else:
- # TODO: apply more fine-grained retraction
- last_uncached_pos = (
- len(req.prefix_indices) // server_args.page_size
- ) * server_args.page_size
- token_indices = self.req_to_token_pool.req_to_token[
- req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
- ]
- self.token_to_kv_pool_allocator.free(token_indices)
- self.req_to_token_pool.free(req.req_pool_idx)
-
- # release the last node
- if self.is_hybrid:
- self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
- else:
- self.tree_cache.dec_lock_ref(req.last_node)
-
- # NOTE(lsyin): we should use the newly evictable memory instantly.
- num_tokens = remaing_req_count * global_config.retract_decode_steps
- self._evict_tree_cache_if_needed(num_tokens)
+ # TODO (csy): for preempted requests, we may want to insert into the tree
+ self.tree_cache.cache_finished_req(req, is_insert=False)
+ # NOTE(lsyin): we should use the newly evictable memory instantly.
+ num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
+ evict_from_tree_cache(self.tree_cache, num_tokens)

  req.reset_for_retract()

@@ -1624,15 +1501,21 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.model_config.vocab_size,
  )

+ @property
+ def is_v2_eagle(self):
+ # FIXME: finally deprecate is_v2_eagle
+ return self.enable_overlap and self.spec_algorithm.is_eagle()
+
  def prepare_for_decode(self):
  self.forward_mode = ForwardMode.DECODE
  bs = len(self.reqs)

- if (
- self.spec_algorithm.is_eagle()
- or self.spec_algorithm.is_standalone()
- or self.spec_algorithm.is_ngram()
- ):
+ if self.is_v2_eagle:
+ # TODO(spec-v2): all v2 spec should go through this path
+ draft_input: EagleDraftInput = self.spec_info
+ draft_input.prepare_for_decode(self)
+
+ if not self.spec_algorithm.is_none():
  # if spec decoding is used, the decode batch is prepared inside
  # `forward_batch_speculative_generation` after running draft models.
  return
@@ -1665,11 +1548,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.output_ids = None

  if self.model_config.is_encoder_decoder:
- locs = self.encoder_lens + self.seq_lens
  self.prepare_encoder_info_decode()
- else:
- locs = self.seq_lens.clone()

+ # Allocate memory
+ self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
+
+ # Update seq_lens after allocation
  if self.enable_overlap:
  # Do not use in-place operations in the overlap mode
  self.seq_lens = self.seq_lens + 1
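The comment above ("Do not use in-place operations in the overlap mode") matters because, with overlap scheduling, an in-flight consumer may still hold the previous seq_lens tensor; rebinding with `x = x + 1` leaves that older tensor untouched, whereas an in-place add_ would mutate it underneath the consumer. A generic torch illustration of the distinction (not sglang code):

# Out-of-place vs. in-place update of a tensor that someone else still references.
import torch

seq_lens = torch.tensor([5, 7])
snapshot = seq_lens            # e.g. captured by an in-flight batch

seq_lens = seq_lens + 1        # out-of-place: snapshot still sees [5, 7]
assert snapshot.tolist() == [5, 7]

seq_lens.add_(1)               # in-place on the *new* tensor is fine
assert seq_lens.tolist() == [7, 9]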
@@ -1682,33 +1566,21 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.orig_seq_lens.add_(1)
  self.seq_lens_sum += bs

- # free memory
- if isinstance(self.tree_cache, SWAChunkCache):
- for req in self.reqs:
- self.tree_cache.evict_swa(
- req, req.seqlen - 1, self.model_config.attention_chunk_size
- )
-
- # Allocate memory
- if self.token_to_kv_pool_allocator.page_size == 1:
- self.out_cache_loc = self.alloc_token_slots(bs)
- else:
- last_loc = self.req_to_token_pool.req_to_token[
- self.req_pool_indices, self.seq_lens - 2
- ]
- self.out_cache_loc = self.alloc_paged_token_slots_decode(
- self.seq_lens, self.seq_lens_cpu, last_loc
- )
-
- self.req_to_token_pool.write(
- (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
- )
+ def maybe_wait_verify_done(self):
+ if self.is_v2_eagle:
+ draft_input: EagleDraftInput = self.spec_info
+ if draft_input.verify_done is not None:
+ draft_input.verify_done.synchronize()

  def filter_batch(
  self,
  chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
  keep_indices: Optional[List[int]] = None,
  ):
+ # FIXME(lsyin): used here to get the correct seq_lens
+ # The batch has been launched but we need it verified to get correct next batch info
+ self.maybe_wait_verify_done()
+
  if keep_indices is None:
  if isinstance(chunked_req_to_exclude, Req):
  chunked_req_to_exclude = [chunked_req_to_exclude]
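maybe_wait_verify_done above (and its call at the top of filter_batch) blocks the host until the verification step that produced the next seq_lens has finished; verify_done appears to be an event recorded after the EAGLE verify work. A generic record/synchronize sketch of that pattern, assuming a CUDA event (illustrative only, not sglang code):

# Record an event after asynchronous GPU work, then synchronize before the
# host consumes the results -- the same wait that maybe_wait_verify_done performs.
import torch

if torch.cuda.is_available():
    verify_done = torch.cuda.Event()
    seq_lens = torch.zeros(4, device="cuda", dtype=torch.int64)
    seq_lens += 3                   # stands in for the async verify kernel
    verify_done.record()            # mark the point at which results are ready
    verify_done.synchronize()       # host waits here before reading
    print(seq_lens.cpu().tolist())  # now safe to consume on the host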
@@ -1771,6 +1643,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  )

  def merge_batch(self, other: "ScheduleBatch"):
+ # NOTE: in v2 eagle mode, we do not need wait verify here because
+ # 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future
+ # 2) other batch is always decode, which is finished in previous step
+
  # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
  # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
  # needs to be called with pre-merged Batch.reqs.
@@ -1877,7 +1753,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  )
  ),
  extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
- launch_done=self.launch_done,
  is_prefill_only=self.is_prefill_only,
  )
 
@@ -1885,6 +1760,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  # Only contain fields that will be used by process_batch_result
  return ScheduleBatch(
  reqs=self.reqs,
+ req_to_token_pool=self.req_to_token_pool,
+ req_pool_indices=self.req_pool_indices,
  model_config=self.model_config,
  forward_mode=self.forward_mode,
  out_cache_loc=self.out_cache_loc,
@@ -1896,26 +1773,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
  is_extend_in_batch=self.is_extend_in_batch,
  is_prefill_only=self.is_prefill_only,
+ seq_lens_cpu=self.seq_lens_cpu,
+ enable_overlap=self.enable_overlap,
  )

- def _evict_tree_cache_if_needed(self, num_tokens: int):
- if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
- return
-
- if self.is_hybrid:
- full_available_size = self.token_to_kv_pool_allocator.full_available_size()
- swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
-
- if full_available_size < num_tokens or swa_available_size < num_tokens:
- if self.tree_cache is not None:
- full_num_tokens = max(0, num_tokens - full_available_size)
- swa_num_tokens = max(0, num_tokens - swa_available_size)
- self.tree_cache.evict(full_num_tokens, swa_num_tokens)
- else:
- if self.token_to_kv_pool_allocator.available_size() < num_tokens:
- if self.tree_cache is not None:
- self.tree_cache.evict(num_tokens)
-
  def _is_available_size_sufficient(self, num_tokens: int) -> bool:
  if self.is_hybrid:
  return (
@@ -1925,23 +1786,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  else:
  return self.token_to_kv_pool_allocator.available_size() >= num_tokens

- def _available_and_evictable_str(self) -> str:
- if self.is_hybrid:
- full_available_size = self.token_to_kv_pool_allocator.full_available_size()
- swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
- full_evictable_size = self.tree_cache.full_evictable_size()
- swa_evictable_size = self.tree_cache.swa_evictable_size()
- return (
- f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
- f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
- f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
- f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
- )
- else:
- available_size = self.token_to_kv_pool_allocator.available_size()
- evictable_size = self.tree_cache.evictable_size()
- return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
-
  def __str__(self):
  return (
  f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -2018,119 +1862,5 @@ class ModelWorkerBatch:
  capture_hidden_mode: CaptureHiddenMode = None
  hicache_consumer_index: int = -1

- # Overlap event
- launch_done: Optional[threading.Event] = None
-
  # Whether this batch is prefill-only (no token generation needed)
  is_prefill_only: bool = False
-
-
- @triton.jit
- def write_req_to_token_pool_triton(
- req_to_token_ptr, # [max_batch, max_context_len]
- req_pool_indices,
- pre_lens,
- seq_lens,
- extend_lens,
- out_cache_loc,
- req_to_token_ptr_stride: tl.constexpr,
- ):
- BLOCK_SIZE: tl.constexpr = 512
- pid = tl.program_id(0)
-
- req_pool_index = tl.load(req_pool_indices + pid)
- pre_len = tl.load(pre_lens + pid)
- seq_len = tl.load(seq_lens + pid)
-
- # NOTE: This can be slow for large bs
- cumsum_start = tl.cast(0, tl.int64)
- for i in range(pid):
- cumsum_start += tl.load(extend_lens + i)
-
- num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
- for i in range(num_loop):
- offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
- mask = offset < (seq_len - pre_len)
- value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
- tl.store(
- req_to_token_ptr
- + req_pool_index * req_to_token_ptr_stride
- + offset
- + pre_len,
- value,
- mask=mask,
- )
-
-
- def get_last_loc(
- req_to_token: torch.Tensor,
- req_pool_indices_tensor: torch.Tensor,
- prefix_lens_tensor: torch.Tensor,
- ) -> torch.Tensor:
- if (
- global_server_args_dict["attention_backend"] != "ascend"
- and global_server_args_dict["attention_backend"] != "torch_native"
- ):
- impl = get_last_loc_triton
- else:
- impl = get_last_loc_torch
-
- return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
-
-
- def get_last_loc_torch(
- req_to_token: torch.Tensor,
- req_pool_indices_tensor: torch.Tensor,
- prefix_lens_tensor: torch.Tensor,
- ) -> torch.Tensor:
- return torch.where(
- prefix_lens_tensor > 0,
- req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
- torch.full_like(prefix_lens_tensor, -1),
- )
-
-
- @triton.jit
- def get_last_loc_kernel(
- req_to_token,
- req_pool_indices_tensor,
- prefix_lens_tensor,
- result,
- num_tokens,
- req_to_token_stride,
- BLOCK_SIZE: tl.constexpr,
- ):
- pid = tl.program_id(0)
- offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
- mask = offset < num_tokens
-
- prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
- req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
-
- token_mask = prefix_lens > 0
- token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
- tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
-
- tl.store(result + offset, tokens, mask=mask)
-
-
- def get_last_loc_triton(
- req_to_token: torch.Tensor,
- req_pool_indices_tensor: torch.Tensor,
- prefix_lens_tensor: torch.Tensor,
- ) -> torch.Tensor:
- BLOCK_SIZE = 256
- num_tokens = prefix_lens_tensor.shape[0]
- result = torch.empty_like(prefix_lens_tensor)
- grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
-
- get_last_loc_kernel[grid](
- req_to_token,
- req_pool_indices_tensor,
- prefix_lens_tensor,
- result,
- num_tokens,
- req_to_token.stride(0),
- BLOCK_SIZE,
- )
- return result