sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ import heapq
23
23
  import time
24
24
  from collections import defaultdict
25
25
  from functools import partial
26
- from typing import TYPE_CHECKING, List, Optional
26
+ from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
27
27
 
28
28
  import torch
29
29
 
@@ -34,12 +34,37 @@ from sglang.srt.disaggregation.kv_events import (
34
34
  )
35
35
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
36
36
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
37
+ from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
37
38
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
38
39
 
39
40
  if TYPE_CHECKING:
40
41
  from sglang.srt.managers.schedule_batch import Req
41
42
 
42
43
 
44
class RadixKey:
    """Radix-tree lookup key: a token id sequence plus an optional extra key.

    Length and iteration delegate to ``token_ids``. Indexing always yields
    another ``RadixKey`` carrying the same ``extra_key``.
    """

    def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
        # The token id sequence; the list is stored as given, not copied.
        self.token_ids = token_ids
        # Optional extra key (e.g. lora_id, cache_salt).
        self.extra_key = extra_key

    def __len__(self) -> int:
        """Return the number of token ids."""
        return len(self.token_ids)

    def __iter__(self) -> Iterator[int]:
        """Iterate over the raw token ids."""
        return iter(self.token_ids)

    def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
        """Return a sub-key for ``idx``; a plain index gives a one-token key."""
        if isinstance(idx, slice):
            selected = self.token_ids[idx]
        else:
            # Wrap the single token so the result is still a RadixKey.
            selected = [self.token_ids[idx]]
        return RadixKey(selected, self.extra_key)

    def __repr__(self) -> str:
        """Show the extra key and at most the first ten token ids."""
        head = self.token_ids[:10]
        ellipsis = "..." if len(self.token_ids) > 10 else ""
        return f"RadixKey(extra_key={self.extra_key!r}, token_ids={head}{ellipsis})"
66
+
67
+
43
68
  class TreeNode:
44
69
 
45
70
  counter = 0
@@ -47,14 +72,12 @@ class TreeNode:
47
72
  def __init__(self, id: Optional[int] = None):
48
73
  self.children = defaultdict(TreeNode)
49
74
  self.parent: TreeNode = None
50
- self.key: List[int] = None
75
+ self.key: RadixKey = None
51
76
  self.value: Optional[torch.Tensor] = None
52
77
  self.lock_ref = 0
53
78
  self.last_access_time = time.monotonic()
54
79
 
55
80
  self.hit_count = 0
56
- # indicating the node is loading KV cache from host
57
- self.loading = False
58
81
  # indicating the node is locked to protect from eviction
59
82
  # incremented when the node is referenced by a storage operation
60
83
  self.host_ref_counter = 0
@@ -95,27 +118,57 @@ class TreeNode:
95
118
  return self.last_access_time < other.last_access_time
96
119
 
97
120
 
98
- def _key_match_page_size1(key0: List, key1: List):
121
+ def _check_extra_key(key0: RadixKey, key1: RadixKey):
122
+ if key0.extra_key != key1.extra_key:
123
+ raise ValueError(
124
+ f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
125
+ )
126
+
127
+
128
def _key_match_page_size1(key0: RadixKey, key1: RadixKey):
    """Return how many leading token ids of ``key0`` and ``key1`` are equal.

    Both keys must carry the same ``extra_key``; ``_check_extra_key`` raises
    otherwise. Comparison stops at the first mismatch or at the end of the
    shorter key.
    """
    _check_extra_key(key0, key1)
    matched = 0
    for left, right in zip(key0.token_ids, key1.token_ids):
        if left != right:
            return matched
        matched += 1
    return matched
105
136
 
106
137
 
107
- def _key_match_paged(key0: List, key1: List, page_size: int):
138
def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int):
    """Return the matched prefix length of two keys, compared page by page.

    Both keys must carry the same ``extra_key``; ``_check_extra_key`` raises
    otherwise. The result advances in steps of ``page_size`` and stops at the
    first page whose token ids differ.
    """
    _check_extra_key(key0, key1)
    limit = min(len(key0), len(key1))
    lhs = key0.token_ids
    rhs = key1.token_ids

    offset = 0
    # NOTE(review): the last step can move `offset` past `limit` when `limit`
    # is not a multiple of `page_size` and the trailing partial pages are
    # equal; presumably callers page-align keys first — confirm.
    while (
        offset < limit
        and lhs[offset : offset + page_size] == rhs[offset : offset + page_size]
    ):
        offset += page_size

    return offset
117
149
 
118
150
 
151
def get_child_key(key: RadixKey, page_size: int = 1):
    """Build the dictionary key used to look up a child node for ``key``.

    With ``page_size == 1`` the plain key is the first token id; otherwise it
    is a tuple of the first ``page_size`` token ids. When ``key.extra_key`` is
    set, the result is the pair ``(extra_key, plain_key)``.
    """
    tokens = key.token_ids
    plain_key = tokens[0] if page_size == 1 else tuple(tokens[:page_size])
    namespace = key.extra_key
    return plain_key if namespace is None else (namespace, plain_key)
160
+
161
+
162
+ def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
163
+ # EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
164
+ # [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
165
+ if len(tokens) < 2:
166
+ return []
167
+ if isinstance(tokens[0], tuple):
168
+ return tokens
169
+ return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
170
+
171
+
119
172
  class RadixCache(BasePrefixCache):
120
173
  def __init__(
121
174
  self,
@@ -124,6 +177,8 @@ class RadixCache(BasePrefixCache):
124
177
  page_size: int,
125
178
  disable: bool = False,
126
179
  enable_kv_cache_events: bool = False,
180
+ eviction_policy: str = "lru",
181
+ is_eagle: bool = False,
127
182
  ):
128
183
  self.req_to_token_pool = req_to_token_pool
129
184
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -131,6 +186,7 @@ class RadixCache(BasePrefixCache):
131
186
  self.disable = disable
132
187
  self.enable_kv_cache_events = enable_kv_cache_events
133
188
  self.kv_event_queue = []
189
+ self.is_eagle = is_eagle
134
190
 
135
191
  if self.token_to_kv_pool_allocator:
136
192
  self.device = self.token_to_kv_pool_allocator.device
@@ -139,17 +195,31 @@ class RadixCache(BasePrefixCache):
139
195
 
140
196
  if self.page_size == 1:
141
197
  self.key_match_fn = _key_match_page_size1
142
- self.get_child_key_fn = lambda key: key[0]
198
+ self.get_child_key_fn = get_child_key
143
199
  else:
144
200
  self.key_match_fn = partial(_key_match_paged, page_size=page_size)
145
- self.get_child_key_fn = lambda key: tuple(key[:page_size])
201
+ self.get_child_key_fn = partial(get_child_key, page_size=page_size)
202
+
203
+ if is_eagle:
204
+ self.key_convert_fn = _convert_to_bigram_key
205
+ else:
206
+ self.key_convert_fn = lambda key: key
207
+
208
+ if eviction_policy.lower() == "lru":
209
+ self.eviction_strategy: EvictionStrategy = LRUStrategy()
210
+ elif eviction_policy.lower() == "lfu":
211
+ self.eviction_strategy: EvictionStrategy = LFUStrategy()
212
+ else:
213
+ raise ValueError(
214
+ f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
215
+ )
146
216
  self.reset()
147
217
 
148
218
  ##### Public API #####
149
219
 
150
220
  def reset(self):
151
221
  self.root_node = TreeNode()
152
- self.root_node.key = []
222
+ self.root_node.key = RadixKey(token_ids=[], extra_key=None)
153
223
  self.root_node.value = []
154
224
  self.root_node.host_value = []
155
225
  self.root_node.lock_ref = 1
@@ -157,18 +227,47 @@ class RadixCache(BasePrefixCache):
157
227
  self.protected_size_ = 0
158
228
  self._record_all_cleared_event()
159
229
 
160
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
161
- """Find the matching prefix from the radix tree.
230
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
231
+ """Find the longest cached prefix of ``key`` in the radix tree.
232
+
233
+ The logical namespace for prefix matching is determined by both the
234
+ token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
235
+ Entries that share identical leading token ids but have *different*
236
+ ``extra_key`` values are intentionally kept disjoint and never share
237
+ prefix nodes. This is useful to:
238
+
239
+ * Isolate KV cache lines for different LoRA / adapter IDs.
240
+ * Separate requests that intentionally should not share state (e.g.,
241
+ different sampling salt, cache version, or retrieval augmentation
242
+ context) by supplying a distinct ``extra_key``.
243
+
162
244
  Args:
163
- key: A list of token IDs to find a matching prefix.
245
+ key (RadixKey): The lookup key containing a list of token ids and an
246
+ optional ``extra_key`` namespace tag. If ``page_size > 1`` the
247
+ length is internally truncated to a multiple of ``page_size``
248
+ before matching. Passing an empty key returns an empty result
249
+ with the root as the last node.
250
+ **kwargs: Reserved for future extensions (ignored currently).
251
+
164
252
  Returns:
165
- A tuple of a tensor of matching prefix token IDs and
166
- the last node that contains the prefix values. Note that
167
- this API can modify the internal state of the Radix tree.
168
- The last node create a new child if the prefix is shorter
169
- than the last node's value.
253
+ MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
254
+ the concatenated KV cache indices corresponding to the longest
255
+ cached prefix (may be length 0). ``last_device_node`` and
256
+ ``last_host_node`` (currently the same) are the tree node objects
257
+ representing the terminal node of the matched prefix. This method
258
+ may mutate internal structure by splitting an existing node if the
259
+ match ends inside a stored segment.
260
+
261
+ Internal updates:
262
+ * Refreshes access metadata (timestamps) used by the
263
+ configured eviction strategy.
264
+ * If the lookup ends inside a stored segment the node is split once
265
+ to expose a precise boundary; this structural refinement improves
266
+ subsequent match efficiency and does not duplicate data.
170
267
  """
171
- if self.disable or len(key) == 0:
268
+ key.token_ids = self.key_convert_fn(key.token_ids)
269
+
270
+ def empty_match_result():
172
271
  return MatchResult(
173
272
  device_indices=torch.empty(
174
273
  (0,),
@@ -179,10 +278,16 @@ class RadixCache(BasePrefixCache):
179
278
  last_host_node=self.root_node,
180
279
  )
181
280
 
281
+ if self.disable or len(key) == 0:
282
+ return empty_match_result()
283
+
182
284
  if self.page_size != 1:
183
285
  page_aligned_len = len(key) // self.page_size * self.page_size
184
286
  key = key[:page_aligned_len]
185
287
 
288
+ if len(key) == 0:
289
+ return empty_match_result()
290
+
186
291
  value, last_node = self._match_prefix_helper(self.root_node, key)
187
292
  if value:
188
293
  value = torch.cat(value)
@@ -194,12 +299,19 @@ class RadixCache(BasePrefixCache):
194
299
  last_host_node=last_node,
195
300
  )
196
301
 
197
- def insert(self, key: List, value=None, chunked=False):
302
+ def insert(self, key: RadixKey, value=None, chunked=False):
198
303
  if self.disable:
199
304
  return 0
200
305
 
306
+ key.token_ids = self.key_convert_fn(key.token_ids)
307
+
201
308
  if value is None:
202
- value = [x for x in key]
309
+ value = torch.tensor(key.token_ids, dtype=torch.int64)
310
+
311
+ if self.is_eagle:
312
+ # Make sure the value length equals the EAGLE bigram key length
313
+ value = value[: len(key)]
314
+
203
315
  return self._insert_helper(self.root_node, key, value)
204
316
 
205
317
  def cache_finished_req(self, req: Req):
@@ -213,27 +325,39 @@ class RadixCache(BasePrefixCache):
213
325
  return
214
326
 
215
327
  token_ids = (req.origin_input_ids + req.output_ids)[:-1]
328
+ all_token_len = len(token_ids)
329
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
216
330
  kv_indices = self.req_to_token_pool.req_to_token[
217
- req.req_pool_idx, : len(token_ids)
331
+ req.req_pool_idx, :all_token_len
218
332
  ]
219
333
 
220
334
  if self.page_size != 1:
221
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
335
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
222
336
  page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
223
337
  dtype=torch.int64, copy=True
224
338
  )
225
339
  self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
226
340
  else:
227
- page_aligned_len = len(kv_indices)
341
+ page_aligned_len = actual_kv_len
228
342
  page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
343
+ if self.is_eagle:
344
+ self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
345
+
346
+ page_aligned_token_len = (
347
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
348
+ )
349
+
350
+ old_prefix_len = len(req.prefix_indices)
351
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
352
+ # prefix_indices has the partial page (for page_size > 1) and one unmatched token (for EAGLE) attached
353
+ old_prefix_len -= 1
229
354
 
230
355
  # Radix Cache takes one ref in memory pool
231
356
  new_prefix_len = self.insert(
232
- token_ids[:page_aligned_len], page_aligned_kv_indices
233
- )
234
- self.token_to_kv_pool_allocator.free(
235
- kv_indices[len(req.prefix_indices) : new_prefix_len]
357
+ RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
358
+ page_aligned_kv_indices,
236
359
  )
360
+ self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
237
361
 
238
362
  # Remove the req slot and release the cache lock
239
363
  self.req_to_token_pool.free(req.req_pool_idx)
@@ -245,45 +369,73 @@ class RadixCache(BasePrefixCache):
245
369
  return
246
370
 
247
371
  token_ids = req.fill_ids
372
+ all_token_len = len(token_ids)
373
+ # The actual kv len for EAGLE is len(token_ids) - 1, since EAGLE uses a bigram key
374
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
248
375
  kv_indices = self.req_to_token_pool.req_to_token[
249
- req.req_pool_idx, : len(token_ids)
376
+ req.req_pool_idx, :all_token_len
250
377
  ]
251
378
 
252
379
  if self.page_size != 1:
253
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
380
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
254
381
  page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
255
382
  dtype=torch.int64, copy=True
256
383
  )
257
384
  else:
258
- page_aligned_len = len(kv_indices)
385
+ page_aligned_len = actual_kv_len
259
386
  page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
260
- page_aligned_token_ids = token_ids[:page_aligned_len]
387
+
388
+ # For EAGLE, page_aligned_len is the bigram key length; the normal token key is one element longer
389
+ page_aligned_token_len = (
390
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
391
+ )
392
+ page_aligned_token_ids = token_ids[:page_aligned_token_len]
393
+
394
+ old_prefix_len = len(req.prefix_indices)
395
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
396
+ # prefix_indices has the partial page (for page_size > 1) and one unmatched token (for EAGLE) attached
397
+ old_prefix_len -= 1
261
398
 
262
399
  # Radix Cache takes one ref in memory pool
263
400
  new_prefix_len = self.insert(
264
- page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
265
- )
266
- self.token_to_kv_pool_allocator.free(
267
- kv_indices[len(req.prefix_indices) : new_prefix_len]
401
+ RadixKey(page_aligned_token_ids, req.extra_key),
402
+ page_aligned_kv_indices,
403
+ chunked=chunked,
268
404
  )
405
+ self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
269
406
 
270
407
  # The prefix indices could be updated, reuse it
271
- new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
408
+ new_indices, new_last_node, _, _ = self.match_prefix(
409
+ RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
410
+ )
272
411
  self.req_to_token_pool.write(
273
- (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
274
- new_indices[len(req.prefix_indices) :],
412
+ (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
413
+ new_indices[old_prefix_len:],
275
414
  )
276
415
 
416
+ # The last_matched_prefix_len is not always equal to len(req.prefix_indices)
417
+ # since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
418
+ # It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
419
+ # So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
420
+ req.last_matched_prefix_len = len(new_indices)
421
+
277
422
  self.dec_lock_ref(req.last_node)
278
423
  self.inc_lock_ref(new_last_node)
279
424
 
280
425
  # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
281
426
  if self.page_size != 1:
427
+ # Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
282
428
  req.prefix_indices = torch.cat(
283
429
  [new_indices, kv_indices[len(new_indices) :]]
284
430
  )
285
431
  else:
286
- req.prefix_indices = new_indices
432
+ if self.is_eagle:
433
+ # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
434
+ req.prefix_indices = torch.cat(
435
+ [new_indices, kv_indices[actual_kv_len:]]
436
+ )
437
+ else:
438
+ req.prefix_indices = new_indices
287
439
  req.last_node = new_last_node
288
440
 
289
441
  def pretty_print(self):
@@ -298,11 +450,14 @@ class RadixCache(BasePrefixCache):
298
450
  return
299
451
 
300
452
  leaves = self._collect_leaves()
301
- heapq.heapify(leaves)
453
+ eviction_heap = [
454
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
455
+ ]
456
+ heapq.heapify(eviction_heap)
302
457
 
303
458
  num_evicted = 0
304
- while num_evicted < num_tokens and len(leaves):
305
- x = heapq.heappop(leaves)
459
+ while num_evicted < num_tokens and len(eviction_heap):
460
+ _priority, x = heapq.heappop(eviction_heap)
306
461
 
307
462
  if x == self.root_node:
308
463
  break
@@ -314,7 +469,8 @@ class RadixCache(BasePrefixCache):
314
469
  self._delete_leaf(x)
315
470
 
316
471
  if len(x.parent.children) == 0:
317
- heapq.heappush(leaves, x.parent)
472
+ new_priority = self.eviction_strategy.get_priority(x.parent)
473
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
318
474
 
319
475
  self._record_remove_event(x)
320
476
 
@@ -325,9 +481,9 @@ class RadixCache(BasePrefixCache):
325
481
  delta = 0
326
482
  while node != self.root_node:
327
483
  if node.lock_ref == 0:
328
- self.evictable_size_ -= len(node.value)
329
- self.protected_size_ += len(node.value)
330
- delta -= len(node.value)
484
+ self.evictable_size_ -= len(node.key)
485
+ self.protected_size_ += len(node.key)
486
+ delta -= len(node.key)
331
487
  node.lock_ref += 1
332
488
  node = node.parent
333
489
  return delta
@@ -339,9 +495,9 @@ class RadixCache(BasePrefixCache):
339
495
  delta = 0
340
496
  while node != self.root_node:
341
497
  if node.lock_ref == 1:
342
- self.evictable_size_ += len(node.value)
343
- self.protected_size_ -= len(node.value)
344
- delta += len(node.value)
498
+ self.evictable_size_ += len(node.key)
499
+ self.protected_size_ -= len(node.key)
500
+ delta += len(node.key)
345
501
  node.lock_ref -= 1
346
502
  node = node.parent
347
503
  return delta
@@ -366,7 +522,7 @@ class RadixCache(BasePrefixCache):
366
522
 
367
523
  ##### Internal Helper Functions #####
368
524
 
369
- def _match_prefix_helper(self, node: TreeNode, key: List):
525
+ def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
370
526
  node.last_access_time = time.monotonic()
371
527
 
372
528
  child_key = self.get_child_key_fn(key)
@@ -391,7 +547,7 @@ class RadixCache(BasePrefixCache):
391
547
 
392
548
  return value, node
393
549
 
394
- def _split_node(self, key, child: TreeNode, split_len: int):
550
+ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
395
551
  # new_node -> child
396
552
  self._record_remove_event(child)
397
553
  new_node = TreeNode()
@@ -410,7 +566,7 @@ class RadixCache(BasePrefixCache):
410
566
 
411
567
  return new_node
412
568
 
413
- def _insert_helper(self, node: TreeNode, key: List, value):
569
+ def _insert_helper(self, node: TreeNode, key: RadixKey, value):
414
570
  node.last_access_time = time.monotonic()
415
571
  if len(key) == 0:
416
572
  return 0
@@ -439,7 +595,7 @@ class RadixCache(BasePrefixCache):
439
595
  new_node.key = key
440
596
  new_node.value = value
441
597
  node.children[child_key] = new_node
442
- self.evictable_size_ += len(value)
598
+ self.evictable_size_ += len(key)
443
599
  self._record_store_event(new_node)
444
600
  return total_prefix_length
445
601
 
@@ -451,7 +607,7 @@ class RadixCache(BasePrefixCache):
451
607
  print(
452
608
  " " * current_indent,
453
609
  len(current_node.key),
454
- current_node.key[:10],
610
+ current_node.key.token_ids[:10],
455
611
  f"r={current_node.lock_ref}",
456
612
  )
457
613
  for key, child in current_node.children.items():
@@ -503,11 +659,11 @@ class RadixCache(BasePrefixCache):
503
659
  last_page_start = (
504
660
  (len(node.parent.key) - 1) // self.page_size
505
661
  ) * self.page_size
506
- parent_parent_tokens = node.parent.key[last_page_start:]
662
+ parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
507
663
  parent_block_hash = hash(tuple(parent_parent_tokens))
508
664
 
509
665
  for start in range(0, len(node.key), self.page_size):
510
- page_tokens = node.key[start : start + self.page_size]
666
+ page_tokens = node.key.token_ids[start : start + self.page_size]
511
667
  if not page_tokens:
512
668
  continue
513
669
 
@@ -530,7 +686,7 @@ class RadixCache(BasePrefixCache):
530
686
  # One BlockRemoved per chunk.
531
687
  if self.enable_kv_cache_events:
532
688
  for start in range(0, len(node.key), self.page_size):
533
- page_tokens = node.key[start : start + self.page_size]
689
+ page_tokens = node.key.token_ids[start : start + self.page_size]
534
690
  if not page_tokens:
535
691
  continue
536
692
  block_hash = hash(tuple(page_tokens))
@@ -556,19 +712,12 @@ class RadixCache(BasePrefixCache):
556
712
  if __name__ == "__main__":
557
713
  tree = RadixCache(None, None, page_size=1, disable=False)
558
714
 
559
- tree.insert("Hello")
560
- tree.insert("Hello")
561
- tree.insert("Hello_L.A.!")
562
- # tree.insert("Hello_world! Happy")
563
- # tree.insert("I love you!")
715
+ # Example token id sequences (as lists of ints)
716
+ tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
717
+ tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
718
+ tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None))
719
+ tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None))
720
+ tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
564
721
  tree.pretty_print()
565
722
 
566
- # print(tree.match_prefix("I love you! aha"))
567
-
568
- # def evict_callback(x):
569
- # print("evict", x)
570
- # return len(x)
571
-
572
- # tree.evict(5, evict_callback)
573
- # tree.evict(10, evict_callback)
574
- # tree.pretty_print()
723
+ print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
13
13
  TreeNodeCpp,
14
14
  )
15
15
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
16
+ from sglang.srt.mem_cache.radix_cache import RadixKey
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from sglang.srt.managers.schedule_batch import Req
@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
93
94
  raise NotImplementedError("Host cache is not supported yet")
94
95
  self.tree.reset()
95
96
 
96
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
97
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
97
98
  device_indices_vec, host_indices_length, node_gpu, node_cpu = (
98
- self.tree.match_prefix(key)
99
+ self.tree.match_prefix(key.token_ids)
99
100
  )
100
101
  return MatchResult(
101
102
  device_indices=self._merge_tensor(device_indices_vec),
@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
104
105
  host_hit_length=host_indices_length,
105
106
  )
106
107
 
107
- def _insert(self, key: List[int], value: torch.Tensor) -> int:
108
+ def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
108
109
  """
109
110
  Insert a key-value pair into the radix tree.
110
111
  Args:
111
- key (List[int]): The key to insert, represented as a list of integers.
112
+ key (RadixKey): The key to insert, represented as a RadixKey.
112
113
  value (torch.Tensor): The value to associate with the key.
113
114
  Returns:
114
115
  int: Number of device indices that were already present in the tree before the insertion.
115
116
  """
116
- ongoing_write, length = self.tree.writing_through(key, value)
117
+ ongoing_write, length = self.tree.writing_through(key.token_ids, value)
117
118
  if self.cache_controller is None:
118
119
  assert len(ongoing_write) == 0, "Implementation error"
119
120
  return length
@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache):
160
161
  # NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned
161
162
  # it will automatically align them, but length of them should be equal
162
163
  old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
163
- new_prefix_len = self._insert(token_ids, kv_indices)
164
+ new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
164
165
 
165
166
  # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
166
167
  assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache):
191
192
  # NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned
192
193
  # it will automatically align them, but length of them should be equal
193
194
  old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
194
- new_prefix_len = self._insert(token_ids, kv_indices)
195
+ new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
195
196
 
196
197
  # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
197
198
  assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
198
199
 
199
200
  # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
200
201
  # The prefix indices need to be updated to reuse the kv indices in the pool
201
- new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
202
+ new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
203
+ RadixKey(token_ids, req.extra_key).token_ids
204
+ )
202
205
  new_indices = self._merge_tensor(new_indices_vec)
203
206
  assert new_prefix_len <= len(new_indices)
204
207
 
@@ -0,0 +1,10 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to SGLang project
3
+
4
+ """Storage backend module for SGLang HiCache."""
5
+
6
+ from .backend_factory import StorageBackendFactory
7
+
8
+ __all__ = [
9
+ "StorageBackendFactory",
10
+ ]