sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -22,8 +22,8 @@ The radix tree data structure for managing the KV cache.
22
22
  import heapq
23
23
  import time
24
24
  from collections import defaultdict
25
- from functools import partial
26
- from typing import TYPE_CHECKING, List, Optional
25
+ from functools import lru_cache, partial
26
+ from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
27
27
 
28
28
  import torch
29
29
 
@@ -34,12 +34,44 @@ from sglang.srt.disaggregation.kv_events import (
34
34
  )
35
35
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
36
36
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
37
+ from sglang.srt.mem_cache.evict_policy import (
38
+ EvictionStrategy,
39
+ FIFOStrategy,
40
+ FILOStrategy,
41
+ LFUStrategy,
42
+ LRUStrategy,
43
+ MRUStrategy,
44
+ )
37
45
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
38
46
 
39
47
  if TYPE_CHECKING:
40
48
  from sglang.srt.managers.schedule_batch import Req
41
49
 
42
50
 
51
class RadixKey:
    """Key stored in the radix tree: a token id sequence plus an optional extra key.

    The extra key (e.g. lora_id, cache_salt) is carried along unchanged
    whenever the key is indexed or sliced.
    """

    def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
        # extra key (e.g. lora_id, cache_salt)
        self.extra_key = extra_key
        # token ids sequence
        self.token_ids = token_ids

    def __len__(self) -> int:
        """Return the number of token ids held by this key."""
        return len(self.token_ids)

    def __iter__(self) -> Iterator[int]:
        """Iterate over the raw token ids."""
        return iter(self.token_ids)

    def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
        """Return a new RadixKey for ``idx``.

        A slice yields the sliced token ids; an integer index yields a
        one-element key. The extra key is preserved in both cases.
        """
        selected = (
            self.token_ids[idx] if isinstance(idx, slice) else [self.token_ids[idx]]
        )
        return RadixKey(selected, self.extra_key)

    def __repr__(self) -> str:
        # Only the first ten token ids are shown; longer keys get a "..." suffix.
        preview = self.token_ids[:10]
        return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})"
73
+
74
+
43
75
  class TreeNode:
44
76
 
45
77
  counter = 0
@@ -47,10 +79,11 @@ class TreeNode:
47
79
  def __init__(self, id: Optional[int] = None):
48
80
  self.children = defaultdict(TreeNode)
49
81
  self.parent: TreeNode = None
50
- self.key: List[int] = None
82
+ self.key: RadixKey = None
51
83
  self.value: Optional[torch.Tensor] = None
52
84
  self.lock_ref = 0
53
85
  self.last_access_time = time.monotonic()
86
+ self.creation_time = time.monotonic()
54
87
 
55
88
  self.hit_count = 0
56
89
  # indicating the node is locked to protect from eviction
@@ -89,31 +122,68 @@ class TreeNode:
89
122
  return None
90
123
  return self.hash_value[-1]
91
124
 
125
def get_prefix_hash_values(self, node: "TreeNode") -> List[str]:
    """Return the hash values of ``node`` and its ancestors, root-most first.

    The walk follows ``parent`` links starting at ``node`` and stops at the
    first node that is ``None`` or whose ``hash_value`` is ``None``; nothing
    above that point is included.

    The result is computed on every call and returned as a fresh list.
    A process-wide ``lru_cache`` on this method would keep ``self`` and
    ``node`` alive, could hand back a stale (and shared, mutable) list after
    a node's ``hash_value`` changes, and with a single slot would be evicted
    by the method's own recursion.

    Args:
        node: The node whose prefix hashes are collected; may be ``None``.

    Returns:
        The concatenated ``hash_value`` entries, ordered from the highest
        included ancestor down to ``node``. Empty if ``node`` is ``None`` or
        has no hash values.
    """
    # Collect per-node hash lists from the leaf upwards.
    segments = []
    current = node
    while current is not None and current.hash_value is not None:
        segments.append(current.hash_value)
        current = current.parent

    # Reverse so ancestors come first, matching prefix order.
    prefix: List[str] = []
    for segment in reversed(segments):
        prefix.extend(segment)
    return prefix
131
+
92
132
  def __lt__(self, other: "TreeNode"):
93
133
  return self.last_access_time < other.last_access_time
94
134
 
95
135
 
96
- def _key_match_page_size1(key0: List, key1: List):
136
+ def _check_extra_key(key0: RadixKey, key1: RadixKey):
137
+ if key0.extra_key != key1.extra_key:
138
+ raise ValueError(
139
+ f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
140
+ )
141
+
142
+
143
def _key_match_page_size1(key0: "RadixKey", key1: "RadixKey"):
    """Return the length of the common token prefix of two keys.

    Both keys must share the same ``extra_key``; ``_check_extra_key`` raises
    ``ValueError`` otherwise.
    """
    _check_extra_key(key0, key1)
    matched = 0
    for left, right in zip(key0.token_ids, key1.token_ids):
        if left != right:
            return matched
        matched += 1
    return matched
103
151
 
104
152
 
105
- def _key_match_paged(key0: List, key1: List, page_size: int):
153
def _key_match_paged(key0: "RadixKey", key1: "RadixKey", page_size: int):
    """Return the common prefix length of two keys, compared page by page.

    Token ids are compared in chunks of ``page_size``; the result advances by
    ``page_size`` for every chunk that is equal in both keys. Both keys must
    share the same ``extra_key`` (``_check_extra_key`` raises ``ValueError``
    otherwise).

    NOTE(review): if the shorter length is not a multiple of ``page_size``
    and the trailing partial chunks are equal, the returned value can exceed
    that length -- presumably callers pass page-aligned keys; confirm.
    """
    _check_extra_key(key0, key1)
    limit = min(len(key0), len(key1))
    ids0, ids1 = key0.token_ids, key1.token_ids

    offset = 0
    while (
        offset < limit
        and ids0[offset : offset + page_size] == ids1[offset : offset + page_size]
    ):
        offset += page_size

    return offset
115
164
 
116
165
 
166
def get_child_key(key: "RadixKey", page_size: int = 1):
    """Build the lookup key under which a child node is stored.

    With ``page_size == 1`` the plain key is the first token id; otherwise it
    is a tuple of the first ``page_size`` token ids. When the key carries an
    ``extra_key``, the result is the pair ``(extra_key, plain_key)`` instead.
    """
    tokens = key.token_ids
    plain_key = tokens[0] if page_size == 1 else tuple(tokens[:page_size])
    if key.extra_key is not None:
        return (key.extra_key, plain_key)
    return plain_key
175
+
176
+
177
+ def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
178
+ # EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
179
+ # [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
180
+ if len(tokens) < 2:
181
+ return []
182
+ if isinstance(tokens[0], tuple):
183
+ return tokens
184
+ return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
185
+
186
+
117
187
  class RadixCache(BasePrefixCache):
118
188
  def __init__(
119
189
  self,
@@ -122,6 +192,8 @@ class RadixCache(BasePrefixCache):
122
192
  page_size: int,
123
193
  disable: bool = False,
124
194
  enable_kv_cache_events: bool = False,
195
+ eviction_policy: str = "lru",
196
+ is_eagle: bool = False,
125
197
  ):
126
198
  self.req_to_token_pool = req_to_token_pool
127
199
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -129,6 +201,7 @@ class RadixCache(BasePrefixCache):
129
201
  self.disable = disable
130
202
  self.enable_kv_cache_events = enable_kv_cache_events
131
203
  self.kv_event_queue = []
204
+ self.is_eagle = is_eagle
132
205
 
133
206
  if self.token_to_kv_pool_allocator:
134
207
  self.device = self.token_to_kv_pool_allocator.device
@@ -137,17 +210,37 @@ class RadixCache(BasePrefixCache):
137
210
 
138
211
  if self.page_size == 1:
139
212
  self.key_match_fn = _key_match_page_size1
140
- self.get_child_key_fn = lambda key: key[0]
213
+ self.get_child_key_fn = get_child_key
141
214
  else:
142
215
  self.key_match_fn = partial(_key_match_paged, page_size=page_size)
143
- self.get_child_key_fn = lambda key: tuple(key[:page_size])
216
+ self.get_child_key_fn = partial(get_child_key, page_size=page_size)
217
+
218
+ if is_eagle:
219
+ self.key_convert_fn = _convert_to_bigram_key
220
+ else:
221
+ self.key_convert_fn = lambda key: key
222
+
223
+ if eviction_policy.lower() == "lru":
224
+ self.eviction_strategy: EvictionStrategy = LRUStrategy()
225
+ elif eviction_policy.lower() == "lfu":
226
+ self.eviction_strategy: EvictionStrategy = LFUStrategy()
227
+ elif eviction_policy.lower() == "fifo":
228
+ self.eviction_strategy: EvictionStrategy = FIFOStrategy()
229
+ elif eviction_policy.lower() == "mru":
230
+ self.eviction_strategy: EvictionStrategy = MRUStrategy()
231
+ elif eviction_policy.lower() == "filo":
232
+ self.eviction_strategy: EvictionStrategy = FILOStrategy()
233
+ else:
234
+ raise ValueError(
235
+ f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu', 'fifo', 'mru', 'filo'."
236
+ )
144
237
  self.reset()
145
238
 
146
239
  ##### Public API #####
147
240
 
148
241
  def reset(self):
149
242
  self.root_node = TreeNode()
150
- self.root_node.key = []
243
+ self.root_node.key = RadixKey(token_ids=[], extra_key=None)
151
244
  self.root_node.value = []
152
245
  self.root_node.host_value = []
153
246
  self.root_node.lock_ref = 1
@@ -155,18 +248,47 @@ class RadixCache(BasePrefixCache):
155
248
  self.protected_size_ = 0
156
249
  self._record_all_cleared_event()
157
250
 
158
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
159
- """Find the matching prefix from the radix tree.
251
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
252
+ """Find the longest cached prefix of ``key`` in the radix tree.
253
+
254
+ The logical namespace for prefix matching is determined by both the
255
+ token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
256
+ Entries that share identical leading token ids but have *different*
257
+ ``extra_key`` values are intentionally kept disjoint and never share
258
+ prefix nodes. This is useful to:
259
+
260
+ * Isolate KV cache lines for different LoRA / adapter IDs.
261
+ * Separate requests that intentionally should not share state (e.g.,
262
+ different sampling salt, cache version, or retrieval augmentation
263
+ context) by supplying a distinct ``extra_key``.
264
+
160
265
  Args:
161
- key: A list of token IDs to find a matching prefix.
266
+ key (RadixKey): The lookup key containing a list of token ids and an
267
+ optional ``extra_key`` namespace tag. If ``page_size > 1`` the
268
+ length is internally truncated to a multiple of ``page_size``
269
+ before matching. Passing an empty key returns an empty result
270
+ with the root as the last node.
271
+ **kwargs: Reserved for future extensions (ignored currently).
272
+
162
273
  Returns:
163
- A tuple of a tensor of matching prefix token IDs and
164
- the last node that contains the prefix values. Note that
165
- this API can modify the internal state of the Radix tree.
166
- The last node create a new child if the prefix is shorter
167
- than the last node's value.
274
+ MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
275
+ the concatenated KV cache indices corresponding to the longest
276
+ cached prefix (may be length 0). ``last_device_node`` and
277
+ ``last_host_node`` (currently the same) are the tree node objects
278
+ representing the terminal node of the matched prefix. This method
279
+ may mutate internal structure by splitting an existing node if the
280
+ match ends inside a stored segment.
281
+
282
+ Internal updates:
283
+ * Refreshes access metadata (timestamps) used by the
284
+ configured eviction strategy.
285
+ * If the lookup ends inside a stored segment the node is split once
286
+ to expose a precise boundary; this structural refinement improves
287
+ subsequent match efficiency and does not duplicate data.
168
288
  """
169
- if self.disable or len(key) == 0:
289
+ key.token_ids = self.key_convert_fn(key.token_ids)
290
+
291
+ def empty_match_result():
170
292
  return MatchResult(
171
293
  device_indices=torch.empty(
172
294
  (0,),
@@ -177,10 +299,16 @@ class RadixCache(BasePrefixCache):
177
299
  last_host_node=self.root_node,
178
300
  )
179
301
 
302
+ if self.disable or len(key) == 0:
303
+ return empty_match_result()
304
+
180
305
  if self.page_size != 1:
181
306
  page_aligned_len = len(key) // self.page_size * self.page_size
182
307
  key = key[:page_aligned_len]
183
308
 
309
+ if len(key) == 0:
310
+ return empty_match_result()
311
+
184
312
  value, last_node = self._match_prefix_helper(self.root_node, key)
185
313
  if value:
186
314
  value = torch.cat(value)
@@ -192,47 +320,77 @@ class RadixCache(BasePrefixCache):
192
320
  last_host_node=last_node,
193
321
  )
194
322
 
195
- def insert(self, key: List, value=None, chunked=False):
323
+ def insert(self, key: RadixKey, value=None, chunked=False):
196
324
  if self.disable:
197
325
  return 0
198
326
 
327
+ key.token_ids = self.key_convert_fn(key.token_ids)
328
+
199
329
  if value is None:
200
- value = [x for x in key]
330
+ value = torch.tensor(key.token_ids, dtype=torch.int64)
331
+
332
+ if self.is_eagle:
333
+ # Make sure the value len equal to the EAGLE bigram key len
334
+ value = value[: len(key)]
335
+
201
336
  return self._insert_helper(self.root_node, key, value)
202
337
 
203
- def cache_finished_req(self, req: Req):
338
+ def cache_finished_req(self, req: Req, is_insert: bool = True):
204
339
  """Cache request when it finishes."""
340
+ all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
205
341
  if self.disable:
206
342
  kv_indices = self.req_to_token_pool.req_to_token[
207
- req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
343
+ req.req_pool_idx, :all_token_len
208
344
  ]
209
345
  self.token_to_kv_pool_allocator.free(kv_indices)
210
346
  self.req_to_token_pool.free(req.req_pool_idx)
211
347
  return
212
348
 
213
- token_ids = (req.origin_input_ids + req.output_ids)[:-1]
349
+ token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
350
+ # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
351
+ # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
352
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
214
353
  kv_indices = self.req_to_token_pool.req_to_token[
215
- req.req_pool_idx, : len(token_ids)
354
+ req.req_pool_idx, :all_token_len
216
355
  ]
217
356
 
218
357
  if self.page_size != 1:
219
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
358
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
220
359
  page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
221
360
  dtype=torch.int64, copy=True
222
361
  )
223
- self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
224
362
  else:
225
- page_aligned_len = len(kv_indices)
363
+ page_aligned_len = actual_kv_len
226
364
  page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
227
365
 
228
- # Radix Cache takes one ref in memory pool
229
- new_prefix_len = self.insert(
230
- token_ids[:page_aligned_len], page_aligned_kv_indices
231
- )
232
- self.token_to_kv_pool_allocator.free(
233
- kv_indices[len(req.prefix_indices) : new_prefix_len]
366
+ page_aligned_token_len = (
367
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
234
368
  )
235
369
 
370
+ old_prefix_len = len(req.prefix_indices)
371
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
372
+ # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
373
+ # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
374
+ old_prefix_len -= 1
375
+
376
+ # Radix Cache takes one ref in memory pool
377
+ if is_insert:
378
+ new_prefix_len = self.insert(
379
+ RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
380
+ page_aligned_kv_indices,
381
+ )
382
+ # Free the duplicates that were already in the tree
383
+ self.token_to_kv_pool_allocator.free(
384
+ kv_indices[old_prefix_len:new_prefix_len]
385
+ )
386
+ else:
387
+ self.token_to_kv_pool_allocator.free(
388
+ kv_indices[old_prefix_len:page_aligned_len]
389
+ )
390
+
391
+ # free the unaligned tail
392
+ self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
393
+
236
394
  # Remove req slot release the cache lock
237
395
  self.req_to_token_pool.free(req.req_pool_idx)
238
396
  self.dec_lock_ref(req.last_node)
@@ -243,45 +401,75 @@ class RadixCache(BasePrefixCache):
243
401
  return
244
402
 
245
403
  token_ids = req.fill_ids
404
+ all_token_len = len(token_ids)
405
+ # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
406
+ # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
407
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
246
408
  kv_indices = self.req_to_token_pool.req_to_token[
247
- req.req_pool_idx, : len(token_ids)
409
+ req.req_pool_idx, :all_token_len
248
410
  ]
249
411
 
250
412
  if self.page_size != 1:
251
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
413
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
252
414
  page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
253
415
  dtype=torch.int64, copy=True
254
416
  )
255
417
  else:
256
- page_aligned_len = len(kv_indices)
418
+ page_aligned_len = actual_kv_len
257
419
  page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
258
- page_aligned_token_ids = token_ids[:page_aligned_len]
420
+
421
+ # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
422
+ page_aligned_token_len = (
423
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
424
+ )
425
+ page_aligned_token_ids = token_ids[:page_aligned_token_len]
426
+
427
+ old_prefix_len = len(req.prefix_indices)
428
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
429
+ # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
430
+ # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
431
+ old_prefix_len -= 1
259
432
 
260
433
  # Radix Cache takes one ref in memory pool
261
434
  new_prefix_len = self.insert(
262
- page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
263
- )
264
- self.token_to_kv_pool_allocator.free(
265
- kv_indices[len(req.prefix_indices) : new_prefix_len]
435
+ RadixKey(page_aligned_token_ids, req.extra_key),
436
+ page_aligned_kv_indices,
437
+ chunked=chunked,
266
438
  )
439
+ self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
267
440
 
268
441
  # The prefix indices could be updated, reuse it
269
- new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
442
+ new_indices, new_last_node, _, _ = self.match_prefix(
443
+ RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
444
+ )
270
445
  self.req_to_token_pool.write(
271
- (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
272
- new_indices[len(req.prefix_indices) :],
446
+ (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
447
+ new_indices[old_prefix_len:],
273
448
  )
274
449
 
450
+ # The last_matched_prefix_len is not always equal to len(req.prefix_indices)
451
+ # since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
452
+ # It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
453
+ # So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
454
+ req.last_matched_prefix_len = len(new_indices)
455
+
275
456
  self.dec_lock_ref(req.last_node)
276
457
  self.inc_lock_ref(new_last_node)
277
458
 
278
459
  # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
279
460
  if self.page_size != 1:
461
+ # Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
280
462
  req.prefix_indices = torch.cat(
281
463
  [new_indices, kv_indices[len(new_indices) :]]
282
464
  )
283
465
  else:
284
- req.prefix_indices = new_indices
466
+ if self.is_eagle:
467
+ # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
468
+ req.prefix_indices = torch.cat(
469
+ [new_indices, kv_indices[actual_kv_len:]]
470
+ )
471
+ else:
472
+ req.prefix_indices = new_indices
285
473
  req.last_node = new_last_node
286
474
 
287
475
  def pretty_print(self):
@@ -296,11 +484,14 @@ class RadixCache(BasePrefixCache):
296
484
  return
297
485
 
298
486
  leaves = self._collect_leaves()
299
- heapq.heapify(leaves)
487
+ eviction_heap = [
488
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
489
+ ]
490
+ heapq.heapify(eviction_heap)
300
491
 
301
492
  num_evicted = 0
302
- while num_evicted < num_tokens and len(leaves):
303
- x = heapq.heappop(leaves)
493
+ while num_evicted < num_tokens and len(eviction_heap):
494
+ _priority, x = heapq.heappop(eviction_heap)
304
495
 
305
496
  if x == self.root_node:
306
497
  break
@@ -312,7 +503,8 @@ class RadixCache(BasePrefixCache):
312
503
  self._delete_leaf(x)
313
504
 
314
505
  if len(x.parent.children) == 0:
315
- heapq.heappush(leaves, x.parent)
506
+ new_priority = self.eviction_strategy.get_priority(x.parent)
507
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
316
508
 
317
509
  self._record_remove_event(x)
318
510
 
@@ -323,9 +515,9 @@ class RadixCache(BasePrefixCache):
323
515
  delta = 0
324
516
  while node != self.root_node:
325
517
  if node.lock_ref == 0:
326
- self.evictable_size_ -= len(node.value)
327
- self.protected_size_ += len(node.value)
328
- delta -= len(node.value)
518
+ self.evictable_size_ -= len(node.key)
519
+ self.protected_size_ += len(node.key)
520
+ delta -= len(node.key)
329
521
  node.lock_ref += 1
330
522
  node = node.parent
331
523
  return delta
@@ -337,9 +529,9 @@ class RadixCache(BasePrefixCache):
337
529
  delta = 0
338
530
  while node != self.root_node:
339
531
  if node.lock_ref == 1:
340
- self.evictable_size_ += len(node.value)
341
- self.protected_size_ -= len(node.value)
342
- delta += len(node.value)
532
+ self.evictable_size_ += len(node.key)
533
+ self.protected_size_ -= len(node.key)
534
+ delta += len(node.key)
343
535
  node.lock_ref -= 1
344
536
  node = node.parent
345
537
  return delta
@@ -364,7 +556,7 @@ class RadixCache(BasePrefixCache):
364
556
 
365
557
  ##### Internal Helper Functions #####
366
558
 
367
- def _match_prefix_helper(self, node: TreeNode, key: List):
559
+ def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
368
560
  node.last_access_time = time.monotonic()
369
561
 
370
562
  child_key = self.get_child_key_fn(key)
@@ -389,7 +581,7 @@ class RadixCache(BasePrefixCache):
389
581
 
390
582
  return value, node
391
583
 
392
- def _split_node(self, key, child: TreeNode, split_len: int):
584
+ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
393
585
  # new_node -> child
394
586
  self._record_remove_event(child)
395
587
  new_node = TreeNode()
@@ -408,7 +600,7 @@ class RadixCache(BasePrefixCache):
408
600
 
409
601
  return new_node
410
602
 
411
- def _insert_helper(self, node: TreeNode, key: List, value):
603
+ def _insert_helper(self, node: TreeNode, key: RadixKey, value):
412
604
  node.last_access_time = time.monotonic()
413
605
  if len(key) == 0:
414
606
  return 0
@@ -437,7 +629,7 @@ class RadixCache(BasePrefixCache):
437
629
  new_node.key = key
438
630
  new_node.value = value
439
631
  node.children[child_key] = new_node
440
- self.evictable_size_ += len(value)
632
+ self.evictable_size_ += len(key)
441
633
  self._record_store_event(new_node)
442
634
  return total_prefix_length
443
635
 
@@ -449,7 +641,7 @@ class RadixCache(BasePrefixCache):
449
641
  print(
450
642
  " " * current_indent,
451
643
  len(current_node.key),
452
- current_node.key[:10],
644
+ current_node.key.token_ids[:10],
453
645
  f"r={current_node.lock_ref}",
454
646
  )
455
647
  for key, child in current_node.children.items():
@@ -501,11 +693,11 @@ class RadixCache(BasePrefixCache):
501
693
  last_page_start = (
502
694
  (len(node.parent.key) - 1) // self.page_size
503
695
  ) * self.page_size
504
- parent_parent_tokens = node.parent.key[last_page_start:]
696
+ parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
505
697
  parent_block_hash = hash(tuple(parent_parent_tokens))
506
698
 
507
699
  for start in range(0, len(node.key), self.page_size):
508
- page_tokens = node.key[start : start + self.page_size]
700
+ page_tokens = node.key.token_ids[start : start + self.page_size]
509
701
  if not page_tokens:
510
702
  continue
511
703
 
@@ -528,7 +720,7 @@ class RadixCache(BasePrefixCache):
528
720
  # One BlockRemoved per chunk.
529
721
  if self.enable_kv_cache_events:
530
722
  for start in range(0, len(node.key), self.page_size):
531
- page_tokens = node.key[start : start + self.page_size]
723
+ page_tokens = node.key.token_ids[start : start + self.page_size]
532
724
  if not page_tokens:
533
725
  continue
534
726
  block_hash = hash(tuple(page_tokens))
@@ -554,19 +746,12 @@ class RadixCache(BasePrefixCache):
554
746
if __name__ == "__main__":
    # Manual smoke test: build a small tree, print it, then run one lookup.
    tree = RadixCache(None, None, page_size=1, disable=False)

    # Example token id sequences (as lists of ints)
    demo_sequences = [
        [1, 2, 3],
        [1, 2, 3],
        [1, 2, 4, 5],
        [1, 2, 4, 5, 6, 7],
        [8, 9, 10, 11, 12],
    ]
    for token_ids in demo_sequences:
        tree.insert(RadixKey(token_ids=token_ids, extra_key=None))
    tree.pretty_print()

    print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))