sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396) hide show
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ import heapq
23
23
  import time
24
24
  from collections import defaultdict
25
25
  from functools import partial
26
- from typing import TYPE_CHECKING, List, Optional
26
+ from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
27
27
 
28
28
  import torch
29
29
 
@@ -34,12 +34,37 @@ from sglang.srt.disaggregation.kv_events import (
34
34
  )
35
35
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
36
36
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
37
+ from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
37
38
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
38
39
 
39
40
  if TYPE_CHECKING:
40
41
  from sglang.srt.managers.schedule_batch import Req
41
42
 
42
43
 
44
+ class RadixKey:
45
+
46
+ def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
47
+ # token ids sequence
48
+ self.token_ids = token_ids
49
+ # extra key (e.g. lora_id, cache_salt)
50
+ self.extra_key = extra_key
51
+
52
+ def __len__(self) -> int:
53
+ return len(self.token_ids)
54
+
55
+ def __iter__(self) -> Iterator[int]:
56
+ return iter(self.token_ids)
57
+
58
+ def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
59
+ if isinstance(idx, slice):
60
+ return RadixKey(self.token_ids[idx], self.extra_key)
61
+ return RadixKey([self.token_ids[idx]], self.extra_key)
62
+
63
+ def __repr__(self) -> str:
64
+ preview = self.token_ids[:10]
65
+ return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})"
66
+
67
+
43
68
  class TreeNode:
44
69
 
45
70
  counter = 0
@@ -47,14 +72,12 @@ class TreeNode:
47
72
  def __init__(self, id: Optional[int] = None):
48
73
  self.children = defaultdict(TreeNode)
49
74
  self.parent: TreeNode = None
50
- self.key: List[int] = None
75
+ self.key: RadixKey = None
51
76
  self.value: Optional[torch.Tensor] = None
52
77
  self.lock_ref = 0
53
78
  self.last_access_time = time.monotonic()
54
79
 
55
80
  self.hit_count = 0
56
- # indicating the node is loading KV cache from host
57
- self.loading = False
58
81
  # indicating the node is locked to protect from eviction
59
82
  # incremented when the node is referenced by a storage operation
60
83
  self.host_ref_counter = 0
@@ -95,27 +118,57 @@ class TreeNode:
95
118
  return self.last_access_time < other.last_access_time
96
119
 
97
120
 
98
- def _key_match_page_size1(key0: List, key1: List):
121
+ def _check_extra_key(key0: RadixKey, key1: RadixKey):
122
+ if key0.extra_key != key1.extra_key:
123
+ raise ValueError(
124
+ f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
125
+ )
126
+
127
+
128
+ def _key_match_page_size1(key0: RadixKey, key1: RadixKey):
129
+ _check_extra_key(key0, key1)
99
130
  i = 0
100
- for k0, k1 in zip(key0, key1):
131
+ for k0, k1 in zip(key0.token_ids, key1.token_ids):
101
132
  if k0 != k1:
102
133
  break
103
134
  i += 1
104
135
  return i
105
136
 
106
137
 
107
- def _key_match_paged(key0: List, key1: List, page_size: int):
138
+ def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int):
139
+ _check_extra_key(key0, key1)
108
140
  min_len = min(len(key0), len(key1))
109
141
 
110
142
  i = 0
111
143
  while i < min_len:
112
- if key0[i : i + page_size] != key1[i : i + page_size]:
144
+ if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]:
113
145
  break
114
146
  i += page_size
115
147
 
116
148
  return i
117
149
 
118
150
 
151
+ def get_child_key(key: RadixKey, page_size: int = 1):
152
+ if page_size == 1:
153
+ plain_key = key.token_ids[0]
154
+ else:
155
+ plain_key = tuple(key.token_ids[:page_size])
156
+ if key.extra_key is None:
157
+ return plain_key
158
+ else:
159
+ return (key.extra_key, plain_key)
160
+
161
+
162
+ def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
163
+ # EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
164
+ # [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
165
+ if len(tokens) < 2:
166
+ return []
167
+ if isinstance(tokens[0], tuple):
168
+ return tokens
169
+ return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
170
+
171
+
119
172
  class RadixCache(BasePrefixCache):
120
173
  def __init__(
121
174
  self,
@@ -124,6 +177,8 @@ class RadixCache(BasePrefixCache):
124
177
  page_size: int,
125
178
  disable: bool = False,
126
179
  enable_kv_cache_events: bool = False,
180
+ eviction_policy: str = "lru",
181
+ is_eagle: bool = False,
127
182
  ):
128
183
  self.req_to_token_pool = req_to_token_pool
129
184
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -131,6 +186,7 @@ class RadixCache(BasePrefixCache):
131
186
  self.disable = disable
132
187
  self.enable_kv_cache_events = enable_kv_cache_events
133
188
  self.kv_event_queue = []
189
+ self.is_eagle = is_eagle
134
190
 
135
191
  if self.token_to_kv_pool_allocator:
136
192
  self.device = self.token_to_kv_pool_allocator.device
@@ -139,17 +195,31 @@ class RadixCache(BasePrefixCache):
139
195
 
140
196
  if self.page_size == 1:
141
197
  self.key_match_fn = _key_match_page_size1
142
- self.get_child_key_fn = lambda key: key[0]
198
+ self.get_child_key_fn = get_child_key
143
199
  else:
144
200
  self.key_match_fn = partial(_key_match_paged, page_size=page_size)
145
- self.get_child_key_fn = lambda key: tuple(key[:page_size])
201
+ self.get_child_key_fn = partial(get_child_key, page_size=page_size)
202
+
203
+ if is_eagle:
204
+ self.key_convert_fn = _convert_to_bigram_key
205
+ else:
206
+ self.key_convert_fn = lambda key: key
207
+
208
+ if eviction_policy.lower() == "lru":
209
+ self.eviction_strategy: EvictionStrategy = LRUStrategy()
210
+ elif eviction_policy.lower() == "lfu":
211
+ self.eviction_strategy: EvictionStrategy = LFUStrategy()
212
+ else:
213
+ raise ValueError(
214
+ f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
215
+ )
146
216
  self.reset()
147
217
 
148
218
  ##### Public API #####
149
219
 
150
220
  def reset(self):
151
221
  self.root_node = TreeNode()
152
- self.root_node.key = []
222
+ self.root_node.key = RadixKey(token_ids=[], extra_key=None)
153
223
  self.root_node.value = []
154
224
  self.root_node.host_value = []
155
225
  self.root_node.lock_ref = 1
@@ -157,18 +227,47 @@ class RadixCache(BasePrefixCache):
157
227
  self.protected_size_ = 0
158
228
  self._record_all_cleared_event()
159
229
 
160
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
161
- """Find the matching prefix from the radix tree.
230
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
231
+ """Find the longest cached prefix of ``key`` in the radix tree.
232
+
233
+ The logical namespace for prefix matching is determined by both the
234
+ token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
235
+ Entries that share identical leading token ids but have *different*
236
+ ``extra_key`` values are intentionally kept disjoint and never share
237
+ prefix nodes. This is useful to:
238
+
239
+ * Isolate KV cache lines for different LoRA / adapter IDs.
240
+ * Separate requests that intentionally should not share state (e.g.,
241
+ different sampling salt, cache version, or retrieval augmentation
242
+ context) by supplying a distinct ``extra_key``.
243
+
162
244
  Args:
163
- key: A list of token IDs to find a matching prefix.
245
+ key (RadixKey): The lookup key containing a list of token ids and an
246
+ optional ``extra_key`` namespace tag. If ``page_size > 1`` the
247
+ length is internally truncated to a multiple of ``page_size``
248
+ before matching. Passing an empty key returns an empty result
249
+ with the root as the last node.
250
+ **kwargs: Reserved for future extensions (ignored currently).
251
+
164
252
  Returns:
165
- A tuple of a tensor of matching prefix token IDs and
166
- the last node that contains the prefix values. Note that
167
- this API can modify the internal state of the Radix tree.
168
- The last node create a new child if the prefix is shorter
169
- than the last node's value.
253
+ MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
254
+ the concatenated KV cache indices corresponding to the longest
255
+ cached prefix (may be length 0). ``last_device_node`` and
256
+ ``last_host_node`` (currently the same) are the tree node objects
257
+ representing the terminal node of the matched prefix. This method
258
+ may mutate internal structure by splitting an existing node if the
259
+ match ends inside a stored segment.
260
+
261
+ Internal updates:
262
+ * Refreshes access metadata (timestamps) used by the
263
+ configured eviction strategy.
264
+ * If the lookup ends inside a stored segment the node is split once
265
+ to expose a precise boundary; this structural refinement improves
266
+ subsequent match efficiency and does not duplicate data.
170
267
  """
171
- if self.disable or len(key) == 0:
268
+ key.token_ids = self.key_convert_fn(key.token_ids)
269
+
270
+ def empty_match_result():
172
271
  return MatchResult(
173
272
  device_indices=torch.empty(
174
273
  (0,),
@@ -179,10 +278,16 @@ class RadixCache(BasePrefixCache):
179
278
  last_host_node=self.root_node,
180
279
  )
181
280
 
281
+ if self.disable or len(key) == 0:
282
+ return empty_match_result()
283
+
182
284
  if self.page_size != 1:
183
285
  page_aligned_len = len(key) // self.page_size * self.page_size
184
286
  key = key[:page_aligned_len]
185
287
 
288
+ if len(key) == 0:
289
+ return empty_match_result()
290
+
186
291
  value, last_node = self._match_prefix_helper(self.root_node, key)
187
292
  if value:
188
293
  value = torch.cat(value)
@@ -194,12 +299,19 @@ class RadixCache(BasePrefixCache):
194
299
  last_host_node=last_node,
195
300
  )
196
301
 
197
- def insert(self, key: List, value=None, chunked=False):
302
+ def insert(self, key: RadixKey, value=None, chunked=False):
198
303
  if self.disable:
199
304
  return 0
200
305
 
306
+ key.token_ids = self.key_convert_fn(key.token_ids)
307
+
201
308
  if value is None:
202
- value = [x for x in key]
309
+ value = torch.tensor(key.token_ids, dtype=torch.int64)
310
+
311
+ if self.is_eagle:
312
+ # Make sure the value length equals the EAGLE bigram key length
313
+ value = value[: len(key)]
314
+
203
315
  return self._insert_helper(self.root_node, key, value)
204
316
 
205
317
  def cache_finished_req(self, req: Req):
@@ -213,27 +325,42 @@ class RadixCache(BasePrefixCache):
213
325
  return
214
326
 
215
327
  token_ids = (req.origin_input_ids + req.output_ids)[:-1]
328
+ all_token_len = len(token_ids)
329
+ # For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length is one less than the number of tokens (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
330
+ # The corresponding KV length must therefore also be reduced by 1. The result is actual_kv_len, which is used for the later calculations and slicing.
331
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
216
332
  kv_indices = self.req_to_token_pool.req_to_token[
217
- req.req_pool_idx, : len(token_ids)
333
+ req.req_pool_idx, :all_token_len
218
334
  ]
219
335
 
220
336
  if self.page_size != 1:
221
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
337
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
222
338
  page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
223
339
  dtype=torch.int64, copy=True
224
340
  )
225
341
  self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
226
342
  else:
227
- page_aligned_len = len(kv_indices)
343
+ page_aligned_len = actual_kv_len
228
344
  page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
345
+ if self.is_eagle:
346
+ self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
347
+
348
+ page_aligned_token_len = (
349
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
350
+ )
351
+
352
+ old_prefix_len = len(req.prefix_indices)
353
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
354
+ # In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
355
+ # Subtract 1 here so that the KV of the unmatched token is freed correctly, avoiding a memory leak.
356
+ old_prefix_len -= 1
229
357
 
230
358
  # Radix Cache takes one ref in memory pool
231
359
  new_prefix_len = self.insert(
232
- token_ids[:page_aligned_len], page_aligned_kv_indices
233
- )
234
- self.token_to_kv_pool_allocator.free(
235
- kv_indices[len(req.prefix_indices) : new_prefix_len]
360
+ RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
361
+ page_aligned_kv_indices,
236
362
  )
363
+ self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
237
364
 
238
365
  # Remove the req slot and release the cache lock
239
366
  self.req_to_token_pool.free(req.req_pool_idx)
@@ -245,45 +372,75 @@ class RadixCache(BasePrefixCache):
245
372
  return
246
373
 
247
374
  token_ids = req.fill_ids
375
+ all_token_len = len(token_ids)
376
+ # For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length is one less than the number of tokens (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
377
+ # The corresponding KV length must therefore also be reduced by 1. The result is actual_kv_len, which is used for the later calculations and slicing.
378
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
248
379
  kv_indices = self.req_to_token_pool.req_to_token[
249
- req.req_pool_idx, : len(token_ids)
380
+ req.req_pool_idx, :all_token_len
250
381
  ]
251
382
 
252
383
  if self.page_size != 1:
253
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
384
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
254
385
  page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
255
386
  dtype=torch.int64, copy=True
256
387
  )
257
388
  else:
258
- page_aligned_len = len(kv_indices)
389
+ page_aligned_len = actual_kv_len
259
390
  page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
260
- page_aligned_token_ids = token_ids[:page_aligned_len]
391
+
392
+ # For EAGLE, page_aligned_len is the bigram key length; the plain token key is one element longer, hence the +1.
393
+ page_aligned_token_len = (
394
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
395
+ )
396
+ page_aligned_token_ids = token_ids[:page_aligned_token_len]
397
+
398
+ old_prefix_len = len(req.prefix_indices)
399
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
400
+ # In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
401
+ # Subtract 1 here so that the KV of the unmatched token is freed correctly, avoiding a memory leak.
402
+ old_prefix_len -= 1
261
403
 
262
404
  # Radix Cache takes one ref in memory pool
263
405
  new_prefix_len = self.insert(
264
- page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
265
- )
266
- self.token_to_kv_pool_allocator.free(
267
- kv_indices[len(req.prefix_indices) : new_prefix_len]
406
+ RadixKey(page_aligned_token_ids, req.extra_key),
407
+ page_aligned_kv_indices,
408
+ chunked=chunked,
268
409
  )
410
+ self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
269
411
 
270
412
  # The prefix indices could be updated, reuse it
271
- new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
413
+ new_indices, new_last_node, _, _ = self.match_prefix(
414
+ RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
415
+ )
272
416
  self.req_to_token_pool.write(
273
- (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
274
- new_indices[len(req.prefix_indices) :],
417
+ (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
418
+ new_indices[old_prefix_len:],
275
419
  )
276
420
 
421
+ # The last_matched_prefix_len is not always equal to len(req.prefix_indices)
422
+ # since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
423
+ # It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
424
+ # So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
425
+ req.last_matched_prefix_len = len(new_indices)
426
+
277
427
  self.dec_lock_ref(req.last_node)
278
428
  self.inc_lock_ref(new_last_node)
279
429
 
280
430
  # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
281
431
  if self.page_size != 1:
432
+ # Handle the partial page: the partial part should be freed in the next cache_unfinished_req and the final cache_finished_req.
282
433
  req.prefix_indices = torch.cat(
283
434
  [new_indices, kv_indices[len(new_indices) :]]
284
435
  )
285
436
  else:
286
- req.prefix_indices = new_indices
437
+ if self.is_eagle:
438
+ # Attach the kv index of the last token for EAGLE; it can be used in chunked prefill
439
+ req.prefix_indices = torch.cat(
440
+ [new_indices, kv_indices[actual_kv_len:]]
441
+ )
442
+ else:
443
+ req.prefix_indices = new_indices
287
444
  req.last_node = new_last_node
288
445
 
289
446
  def pretty_print(self):
@@ -298,11 +455,14 @@ class RadixCache(BasePrefixCache):
298
455
  return
299
456
 
300
457
  leaves = self._collect_leaves()
301
- heapq.heapify(leaves)
458
+ eviction_heap = [
459
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
460
+ ]
461
+ heapq.heapify(eviction_heap)
302
462
 
303
463
  num_evicted = 0
304
- while num_evicted < num_tokens and len(leaves):
305
- x = heapq.heappop(leaves)
464
+ while num_evicted < num_tokens and len(eviction_heap):
465
+ _priority, x = heapq.heappop(eviction_heap)
306
466
 
307
467
  if x == self.root_node:
308
468
  break
@@ -314,7 +474,8 @@ class RadixCache(BasePrefixCache):
314
474
  self._delete_leaf(x)
315
475
 
316
476
  if len(x.parent.children) == 0:
317
- heapq.heappush(leaves, x.parent)
477
+ new_priority = self.eviction_strategy.get_priority(x.parent)
478
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
318
479
 
319
480
  self._record_remove_event(x)
320
481
 
@@ -325,9 +486,9 @@ class RadixCache(BasePrefixCache):
325
486
  delta = 0
326
487
  while node != self.root_node:
327
488
  if node.lock_ref == 0:
328
- self.evictable_size_ -= len(node.value)
329
- self.protected_size_ += len(node.value)
330
- delta -= len(node.value)
489
+ self.evictable_size_ -= len(node.key)
490
+ self.protected_size_ += len(node.key)
491
+ delta -= len(node.key)
331
492
  node.lock_ref += 1
332
493
  node = node.parent
333
494
  return delta
@@ -339,9 +500,9 @@ class RadixCache(BasePrefixCache):
339
500
  delta = 0
340
501
  while node != self.root_node:
341
502
  if node.lock_ref == 1:
342
- self.evictable_size_ += len(node.value)
343
- self.protected_size_ -= len(node.value)
344
- delta += len(node.value)
503
+ self.evictable_size_ += len(node.key)
504
+ self.protected_size_ -= len(node.key)
505
+ delta += len(node.key)
345
506
  node.lock_ref -= 1
346
507
  node = node.parent
347
508
  return delta
@@ -366,7 +527,7 @@ class RadixCache(BasePrefixCache):
366
527
 
367
528
  ##### Internal Helper Functions #####
368
529
 
369
- def _match_prefix_helper(self, node: TreeNode, key: List):
530
+ def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
370
531
  node.last_access_time = time.monotonic()
371
532
 
372
533
  child_key = self.get_child_key_fn(key)
@@ -391,7 +552,7 @@ class RadixCache(BasePrefixCache):
391
552
 
392
553
  return value, node
393
554
 
394
- def _split_node(self, key, child: TreeNode, split_len: int):
555
+ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
395
556
  # new_node -> child
396
557
  self._record_remove_event(child)
397
558
  new_node = TreeNode()
@@ -410,7 +571,7 @@ class RadixCache(BasePrefixCache):
410
571
 
411
572
  return new_node
412
573
 
413
- def _insert_helper(self, node: TreeNode, key: List, value):
574
+ def _insert_helper(self, node: TreeNode, key: RadixKey, value):
414
575
  node.last_access_time = time.monotonic()
415
576
  if len(key) == 0:
416
577
  return 0
@@ -439,7 +600,7 @@ class RadixCache(BasePrefixCache):
439
600
  new_node.key = key
440
601
  new_node.value = value
441
602
  node.children[child_key] = new_node
442
- self.evictable_size_ += len(value)
603
+ self.evictable_size_ += len(key)
443
604
  self._record_store_event(new_node)
444
605
  return total_prefix_length
445
606
 
@@ -451,7 +612,7 @@ class RadixCache(BasePrefixCache):
451
612
  print(
452
613
  " " * current_indent,
453
614
  len(current_node.key),
454
- current_node.key[:10],
615
+ current_node.key.token_ids[:10],
455
616
  f"r={current_node.lock_ref}",
456
617
  )
457
618
  for key, child in current_node.children.items():
@@ -503,11 +664,11 @@ class RadixCache(BasePrefixCache):
503
664
  last_page_start = (
504
665
  (len(node.parent.key) - 1) // self.page_size
505
666
  ) * self.page_size
506
- parent_parent_tokens = node.parent.key[last_page_start:]
667
+ parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
507
668
  parent_block_hash = hash(tuple(parent_parent_tokens))
508
669
 
509
670
  for start in range(0, len(node.key), self.page_size):
510
- page_tokens = node.key[start : start + self.page_size]
671
+ page_tokens = node.key.token_ids[start : start + self.page_size]
511
672
  if not page_tokens:
512
673
  continue
513
674
 
@@ -530,7 +691,7 @@ class RadixCache(BasePrefixCache):
530
691
  # One BlockRemoved per chunk.
531
692
  if self.enable_kv_cache_events:
532
693
  for start in range(0, len(node.key), self.page_size):
533
- page_tokens = node.key[start : start + self.page_size]
694
+ page_tokens = node.key.token_ids[start : start + self.page_size]
534
695
  if not page_tokens:
535
696
  continue
536
697
  block_hash = hash(tuple(page_tokens))
@@ -556,19 +717,12 @@ class RadixCache(BasePrefixCache):
556
717
  if __name__ == "__main__":
557
718
  tree = RadixCache(None, None, page_size=1, disable=False)
558
719
 
559
- tree.insert("Hello")
560
- tree.insert("Hello")
561
- tree.insert("Hello_L.A.!")
562
- # tree.insert("Hello_world! Happy")
563
- # tree.insert("I love you!")
720
+ # Example token id sequences (as lists of ints)
721
+ tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
722
+ tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
723
+ tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None))
724
+ tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None))
725
+ tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
564
726
  tree.pretty_print()
565
727
 
566
- # print(tree.match_prefix("I love you! aha"))
567
-
568
- # def evict_callback(x):
569
- # print("evict", x)
570
- # return len(x)
571
-
572
- # tree.evict(5, evict_callback)
573
- # tree.evict(10, evict_callback)
574
- # tree.pretty_print()
728
+ print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
13
13
  TreeNodeCpp,
14
14
  )
15
15
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
16
+ from sglang.srt.mem_cache.radix_cache import RadixKey
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from sglang.srt.managers.schedule_batch import Req
@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
93
94
  raise NotImplementedError("Host cache is not supported yet")
94
95
  self.tree.reset()
95
96
 
96
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
97
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
97
98
  device_indices_vec, host_indices_length, node_gpu, node_cpu = (
98
- self.tree.match_prefix(key)
99
+ self.tree.match_prefix(key.token_ids)
99
100
  )
100
101
  return MatchResult(
101
102
  device_indices=self._merge_tensor(device_indices_vec),
@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
104
105
  host_hit_length=host_indices_length,
105
106
  )
106
107
 
107
- def _insert(self, key: List[int], value: torch.Tensor) -> int:
108
+ def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
108
109
  """
109
110
  Insert a key-value pair into the radix tree.
110
111
  Args:
111
- key (List[int]): The key to insert, represented as a list of integers.
112
+ key (RadixKey): The key to insert, represented as a RadixKey.
112
113
  value (torch.Tensor): The value to associate with the key.
113
114
  Returns:
114
115
  int: Number of device indices that were already present in the tree before the insertion.
115
116
  """
116
- ongoing_write, length = self.tree.writing_through(key, value)
117
+ ongoing_write, length = self.tree.writing_through(key.token_ids, value)
117
118
  if self.cache_controller is None:
118
119
  assert len(ongoing_write) == 0, "Implementation error"
119
120
  return length
@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache):
160
161
  # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
161
162
  # it will automatically align them, but their lengths must be equal
162
163
  old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
163
- new_prefix_len = self._insert(token_ids, kv_indices)
164
+ new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
164
165
 
165
166
  # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
166
167
  assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache):
191
192
  # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
192
193
  # it will automatically align them, but their lengths must be equal
193
194
  old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
194
- new_prefix_len = self._insert(token_ids, kv_indices)
195
+ new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
195
196
 
196
197
  # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
197
198
  assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
198
199
 
199
200
  # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
200
201
  # The prefix indices need to be updated to reuse the kv indices in the pool
201
- new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
202
+ new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
203
+ RadixKey(token_ids, req.extra_key).token_ids
204
+ )
202
205
  new_indices = self._merge_tensor(new_indices_vec)
203
206
  assert new_prefix_len <= len(new_indices)
204
207
 
@@ -0,0 +1,10 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to SGLang project
3
+
4
+ """Storage backend module for SGLang HiCache."""
5
+
6
+ from .backend_factory import StorageBackendFactory
7
+
8
+ __all__ = [
9
+ "StorageBackendFactory",
10
+ ]