sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ import heapq
23
23
  import time
24
24
  from collections import defaultdict
25
25
  from functools import partial
26
- from typing import TYPE_CHECKING, List, Optional
26
+ from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
27
27
 
28
28
  import torch
29
29
 
@@ -34,12 +34,37 @@ from sglang.srt.disaggregation.kv_events import (
34
34
  )
35
35
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
36
36
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
37
+ from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
37
38
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
38
39
 
39
40
  if TYPE_CHECKING:
40
41
  from sglang.srt.managers.schedule_batch import Req
41
42
 
42
43
 
44
+ class RadixKey:
45
+
46
+ def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
47
+ # token ids sequence
48
+ self.token_ids = token_ids
49
+ # extra key (e.g. lora_id, cache_salt)
50
+ self.extra_key = extra_key
51
+
52
+ def __len__(self) -> int:
53
+ return len(self.token_ids)
54
+
55
+ def __iter__(self) -> Iterator[int]:
56
+ return iter(self.token_ids)
57
+
58
+ def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
59
+ if isinstance(idx, slice):
60
+ return RadixKey(self.token_ids[idx], self.extra_key)
61
+ return RadixKey([self.token_ids[idx]], self.extra_key)
62
+
63
+ def __repr__(self) -> str:
64
+ preview = self.token_ids[:10]
65
+ return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})"
66
+
67
+
43
68
  class TreeNode:
44
69
 
45
70
  counter = 0
@@ -47,7 +72,7 @@ class TreeNode:
47
72
  def __init__(self, id: Optional[int] = None):
48
73
  self.children = defaultdict(TreeNode)
49
74
  self.parent: TreeNode = None
50
- self.key: List[int] = None
75
+ self.key: RadixKey = None
51
76
  self.value: Optional[torch.Tensor] = None
52
77
  self.lock_ref = 0
53
78
  self.last_access_time = time.monotonic()
@@ -93,27 +118,57 @@ class TreeNode:
93
118
  return self.last_access_time < other.last_access_time
94
119
 
95
120
 
96
- def _key_match_page_size1(key0: List, key1: List):
121
+ def _check_extra_key(key0: RadixKey, key1: RadixKey):
122
+ if key0.extra_key != key1.extra_key:
123
+ raise ValueError(
124
+ f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
125
+ )
126
+
127
+
128
+ def _key_match_page_size1(key0: RadixKey, key1: RadixKey):
129
+ _check_extra_key(key0, key1)
97
130
  i = 0
98
- for k0, k1 in zip(key0, key1):
131
+ for k0, k1 in zip(key0.token_ids, key1.token_ids):
99
132
  if k0 != k1:
100
133
  break
101
134
  i += 1
102
135
  return i
103
136
 
104
137
 
105
- def _key_match_paged(key0: List, key1: List, page_size: int):
138
+ def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int):
139
+ _check_extra_key(key0, key1)
106
140
  min_len = min(len(key0), len(key1))
107
141
 
108
142
  i = 0
109
143
  while i < min_len:
110
- if key0[i : i + page_size] != key1[i : i + page_size]:
144
+ if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]:
111
145
  break
112
146
  i += page_size
113
147
 
114
148
  return i
115
149
 
116
150
 
151
+ def get_child_key(key: RadixKey, page_size: int = 1):
152
+ if page_size == 1:
153
+ plain_key = key.token_ids[0]
154
+ else:
155
+ plain_key = tuple(key.token_ids[:page_size])
156
+ if key.extra_key is None:
157
+ return plain_key
158
+ else:
159
+ return (key.extra_key, plain_key)
160
+
161
+
162
+ def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
163
+ # EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
164
+ # [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
165
+ if len(tokens) < 2:
166
+ return []
167
+ if isinstance(tokens[0], tuple):
168
+ return tokens
169
+ return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
170
+
171
+
117
172
  class RadixCache(BasePrefixCache):
118
173
  def __init__(
119
174
  self,
@@ -122,6 +177,8 @@ class RadixCache(BasePrefixCache):
122
177
  page_size: int,
123
178
  disable: bool = False,
124
179
  enable_kv_cache_events: bool = False,
180
+ eviction_policy: str = "lru",
181
+ is_eagle: bool = False,
125
182
  ):
126
183
  self.req_to_token_pool = req_to_token_pool
127
184
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -129,6 +186,7 @@ class RadixCache(BasePrefixCache):
129
186
  self.disable = disable
130
187
  self.enable_kv_cache_events = enable_kv_cache_events
131
188
  self.kv_event_queue = []
189
+ self.is_eagle = is_eagle
132
190
 
133
191
  if self.token_to_kv_pool_allocator:
134
192
  self.device = self.token_to_kv_pool_allocator.device
@@ -137,17 +195,31 @@ class RadixCache(BasePrefixCache):
137
195
 
138
196
  if self.page_size == 1:
139
197
  self.key_match_fn = _key_match_page_size1
140
- self.get_child_key_fn = lambda key: key[0]
198
+ self.get_child_key_fn = get_child_key
141
199
  else:
142
200
  self.key_match_fn = partial(_key_match_paged, page_size=page_size)
143
- self.get_child_key_fn = lambda key: tuple(key[:page_size])
201
+ self.get_child_key_fn = partial(get_child_key, page_size=page_size)
202
+
203
+ if is_eagle:
204
+ self.key_convert_fn = _convert_to_bigram_key
205
+ else:
206
+ self.key_convert_fn = lambda key: key
207
+
208
+ if eviction_policy.lower() == "lru":
209
+ self.eviction_strategy: EvictionStrategy = LRUStrategy()
210
+ elif eviction_policy.lower() == "lfu":
211
+ self.eviction_strategy: EvictionStrategy = LFUStrategy()
212
+ else:
213
+ raise ValueError(
214
+ f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
215
+ )
144
216
  self.reset()
145
217
 
146
218
  ##### Public API #####
147
219
 
148
220
  def reset(self):
149
221
  self.root_node = TreeNode()
150
- self.root_node.key = []
222
+ self.root_node.key = RadixKey(token_ids=[], extra_key=None)
151
223
  self.root_node.value = []
152
224
  self.root_node.host_value = []
153
225
  self.root_node.lock_ref = 1
@@ -155,18 +227,47 @@ class RadixCache(BasePrefixCache):
155
227
  self.protected_size_ = 0
156
228
  self._record_all_cleared_event()
157
229
 
158
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
159
- """Find the matching prefix from the radix tree.
230
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
231
+ """Find the longest cached prefix of ``key`` in the radix tree.
232
+
233
+ The logical namespace for prefix matching is determined by both the
234
+ token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
235
+ Entries that share identical leading token ids but have *different*
236
+ ``extra_key`` values are intentionally kept disjoint and never share
237
+ prefix nodes. This is useful to:
238
+
239
+ * Isolate KV cache lines for different LoRA / adapter IDs.
240
+ * Separate requests that intentionally should not share state (e.g.,
241
+ different sampling salt, cache version, or retrieval augmentation
242
+ context) by supplying a distinct ``extra_key``.
243
+
160
244
  Args:
161
- key: A list of token IDs to find a matching prefix.
245
+ key (RadixKey): The lookup key containing a list of token ids and an
246
+ optional ``extra_key`` namespace tag. If ``page_size > 1`` the
247
+ length is internally truncated to a multiple of ``page_size``
248
+ before matching. Passing an empty key returns an empty result
249
+ with the root as the last node.
250
+ **kwargs: Reserved for future extensions (ignored currently).
251
+
162
252
  Returns:
163
- A tuple of a tensor of matching prefix token IDs and
164
- the last node that contains the prefix values. Note that
165
- this API can modify the internal state of the Radix tree.
166
- The last node create a new child if the prefix is shorter
167
- than the last node's value.
253
+ MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
254
+ the concatenated KV cache indices corresponding to the longest
255
+ cached prefix (may be length 0). ``last_device_node`` and
256
+ ``last_host_node`` (currently the same) are the tree node objects
257
+ representing the terminal node of the matched prefix. This method
258
+ may mutate internal structure by splitting an existing node if the
259
+ match ends inside a stored segment.
260
+
261
+ Internal updates:
262
+ * Refreshes access metadata (timestamps) used by the
263
+ configured eviction strategy.
264
+ * If the lookup ends inside a stored segment the node is split once
265
+ to expose a precise boundary; this structural refinement improves
266
+ subsequent match efficiency and does not duplicate data.
168
267
  """
169
- if self.disable or len(key) == 0:
268
+ key.token_ids = self.key_convert_fn(key.token_ids)
269
+
270
+ def empty_match_result():
170
271
  return MatchResult(
171
272
  device_indices=torch.empty(
172
273
  (0,),
@@ -177,10 +278,16 @@ class RadixCache(BasePrefixCache):
177
278
  last_host_node=self.root_node,
178
279
  )
179
280
 
281
+ if self.disable or len(key) == 0:
282
+ return empty_match_result()
283
+
180
284
  if self.page_size != 1:
181
285
  page_aligned_len = len(key) // self.page_size * self.page_size
182
286
  key = key[:page_aligned_len]
183
287
 
288
+ if len(key) == 0:
289
+ return empty_match_result()
290
+
184
291
  value, last_node = self._match_prefix_helper(self.root_node, key)
185
292
  if value:
186
293
  value = torch.cat(value)
@@ -192,12 +299,19 @@ class RadixCache(BasePrefixCache):
192
299
  last_host_node=last_node,
193
300
  )
194
301
 
195
- def insert(self, key: List, value=None, chunked=False):
302
+ def insert(self, key: RadixKey, value=None, chunked=False):
196
303
  if self.disable:
197
304
  return 0
198
305
 
306
+ key.token_ids = self.key_convert_fn(key.token_ids)
307
+
199
308
  if value is None:
200
- value = [x for x in key]
309
+ value = torch.tensor(key.token_ids, dtype=torch.int64)
310
+
311
+ if self.is_eagle:
312
+ # Make sure the value len equal to the EAGLE bigram key len
313
+ value = value[: len(key)]
314
+
201
315
  return self._insert_helper(self.root_node, key, value)
202
316
 
203
317
  def cache_finished_req(self, req: Req):
@@ -211,27 +325,39 @@ class RadixCache(BasePrefixCache):
211
325
  return
212
326
 
213
327
  token_ids = (req.origin_input_ids + req.output_ids)[:-1]
328
+ all_token_len = len(token_ids)
329
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
214
330
  kv_indices = self.req_to_token_pool.req_to_token[
215
- req.req_pool_idx, : len(token_ids)
331
+ req.req_pool_idx, :all_token_len
216
332
  ]
217
333
 
218
334
  if self.page_size != 1:
219
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
335
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
220
336
  page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
221
337
  dtype=torch.int64, copy=True
222
338
  )
223
339
  self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
224
340
  else:
225
- page_aligned_len = len(kv_indices)
341
+ page_aligned_len = actual_kv_len
226
342
  page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
343
+ if self.is_eagle:
344
+ self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
345
+
346
+ page_aligned_token_len = (
347
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
348
+ )
349
+
350
+ old_prefix_len = len(req.prefix_indices)
351
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
352
+ # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
353
+ old_prefix_len -= 1
227
354
 
228
355
  # Radix Cache takes one ref in memory pool
229
356
  new_prefix_len = self.insert(
230
- token_ids[:page_aligned_len], page_aligned_kv_indices
231
- )
232
- self.token_to_kv_pool_allocator.free(
233
- kv_indices[len(req.prefix_indices) : new_prefix_len]
357
+ RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
358
+ page_aligned_kv_indices,
234
359
  )
360
+ self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
235
361
 
236
362
  # Remove req slot release the cache lock
237
363
  self.req_to_token_pool.free(req.req_pool_idx)
@@ -243,45 +369,73 @@ class RadixCache(BasePrefixCache):
243
369
  return
244
370
 
245
371
  token_ids = req.fill_ids
372
+ all_token_len = len(token_ids)
373
+ # The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
374
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
246
375
  kv_indices = self.req_to_token_pool.req_to_token[
247
- req.req_pool_idx, : len(token_ids)
376
+ req.req_pool_idx, :all_token_len
248
377
  ]
249
378
 
250
379
  if self.page_size != 1:
251
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
380
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
252
381
  page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
253
382
  dtype=torch.int64, copy=True
254
383
  )
255
384
  else:
256
- page_aligned_len = len(kv_indices)
385
+ page_aligned_len = actual_kv_len
257
386
  page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
258
- page_aligned_token_ids = token_ids[:page_aligned_len]
387
+
388
+ # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
389
+ page_aligned_token_len = (
390
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
391
+ )
392
+ page_aligned_token_ids = token_ids[:page_aligned_token_len]
393
+
394
+ old_prefix_len = len(req.prefix_indices)
395
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
396
+ # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
397
+ old_prefix_len -= 1
259
398
 
260
399
  # Radix Cache takes one ref in memory pool
261
400
  new_prefix_len = self.insert(
262
- page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
263
- )
264
- self.token_to_kv_pool_allocator.free(
265
- kv_indices[len(req.prefix_indices) : new_prefix_len]
401
+ RadixKey(page_aligned_token_ids, req.extra_key),
402
+ page_aligned_kv_indices,
403
+ chunked=chunked,
266
404
  )
405
+ self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
267
406
 
268
407
  # The prefix indices could be updated, reuse it
269
- new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
408
+ new_indices, new_last_node, _, _ = self.match_prefix(
409
+ RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
410
+ )
270
411
  self.req_to_token_pool.write(
271
- (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
272
- new_indices[len(req.prefix_indices) :],
412
+ (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
413
+ new_indices[old_prefix_len:],
273
414
  )
274
415
 
416
+ # The last_matched_prefix_len is not always equal to len(req.prefix_indices)
417
+ # since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
418
+ # It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
419
+ # So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
420
+ req.last_matched_prefix_len = len(new_indices)
421
+
275
422
  self.dec_lock_ref(req.last_node)
276
423
  self.inc_lock_ref(new_last_node)
277
424
 
278
425
  # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
279
426
  if self.page_size != 1:
427
+ # Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
280
428
  req.prefix_indices = torch.cat(
281
429
  [new_indices, kv_indices[len(new_indices) :]]
282
430
  )
283
431
  else:
284
- req.prefix_indices = new_indices
432
+ if self.is_eagle:
433
+ # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
434
+ req.prefix_indices = torch.cat(
435
+ [new_indices, kv_indices[actual_kv_len:]]
436
+ )
437
+ else:
438
+ req.prefix_indices = new_indices
285
439
  req.last_node = new_last_node
286
440
 
287
441
  def pretty_print(self):
@@ -296,11 +450,14 @@ class RadixCache(BasePrefixCache):
296
450
  return
297
451
 
298
452
  leaves = self._collect_leaves()
299
- heapq.heapify(leaves)
453
+ eviction_heap = [
454
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
455
+ ]
456
+ heapq.heapify(eviction_heap)
300
457
 
301
458
  num_evicted = 0
302
- while num_evicted < num_tokens and len(leaves):
303
- x = heapq.heappop(leaves)
459
+ while num_evicted < num_tokens and len(eviction_heap):
460
+ _priority, x = heapq.heappop(eviction_heap)
304
461
 
305
462
  if x == self.root_node:
306
463
  break
@@ -312,7 +469,8 @@ class RadixCache(BasePrefixCache):
312
469
  self._delete_leaf(x)
313
470
 
314
471
  if len(x.parent.children) == 0:
315
- heapq.heappush(leaves, x.parent)
472
+ new_priority = self.eviction_strategy.get_priority(x.parent)
473
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
316
474
 
317
475
  self._record_remove_event(x)
318
476
 
@@ -323,9 +481,9 @@ class RadixCache(BasePrefixCache):
323
481
  delta = 0
324
482
  while node != self.root_node:
325
483
  if node.lock_ref == 0:
326
- self.evictable_size_ -= len(node.value)
327
- self.protected_size_ += len(node.value)
328
- delta -= len(node.value)
484
+ self.evictable_size_ -= len(node.key)
485
+ self.protected_size_ += len(node.key)
486
+ delta -= len(node.key)
329
487
  node.lock_ref += 1
330
488
  node = node.parent
331
489
  return delta
@@ -337,9 +495,9 @@ class RadixCache(BasePrefixCache):
337
495
  delta = 0
338
496
  while node != self.root_node:
339
497
  if node.lock_ref == 1:
340
- self.evictable_size_ += len(node.value)
341
- self.protected_size_ -= len(node.value)
342
- delta += len(node.value)
498
+ self.evictable_size_ += len(node.key)
499
+ self.protected_size_ -= len(node.key)
500
+ delta += len(node.key)
343
501
  node.lock_ref -= 1
344
502
  node = node.parent
345
503
  return delta
@@ -364,7 +522,7 @@ class RadixCache(BasePrefixCache):
364
522
 
365
523
  ##### Internal Helper Functions #####
366
524
 
367
- def _match_prefix_helper(self, node: TreeNode, key: List):
525
+ def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
368
526
  node.last_access_time = time.monotonic()
369
527
 
370
528
  child_key = self.get_child_key_fn(key)
@@ -389,7 +547,7 @@ class RadixCache(BasePrefixCache):
389
547
 
390
548
  return value, node
391
549
 
392
- def _split_node(self, key, child: TreeNode, split_len: int):
550
+ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
393
551
  # new_node -> child
394
552
  self._record_remove_event(child)
395
553
  new_node = TreeNode()
@@ -408,7 +566,7 @@ class RadixCache(BasePrefixCache):
408
566
 
409
567
  return new_node
410
568
 
411
- def _insert_helper(self, node: TreeNode, key: List, value):
569
+ def _insert_helper(self, node: TreeNode, key: RadixKey, value):
412
570
  node.last_access_time = time.monotonic()
413
571
  if len(key) == 0:
414
572
  return 0
@@ -437,7 +595,7 @@ class RadixCache(BasePrefixCache):
437
595
  new_node.key = key
438
596
  new_node.value = value
439
597
  node.children[child_key] = new_node
440
- self.evictable_size_ += len(value)
598
+ self.evictable_size_ += len(key)
441
599
  self._record_store_event(new_node)
442
600
  return total_prefix_length
443
601
 
@@ -449,7 +607,7 @@ class RadixCache(BasePrefixCache):
449
607
  print(
450
608
  " " * current_indent,
451
609
  len(current_node.key),
452
- current_node.key[:10],
610
+ current_node.key.token_ids[:10],
453
611
  f"r={current_node.lock_ref}",
454
612
  )
455
613
  for key, child in current_node.children.items():
@@ -501,11 +659,11 @@ class RadixCache(BasePrefixCache):
501
659
  last_page_start = (
502
660
  (len(node.parent.key) - 1) // self.page_size
503
661
  ) * self.page_size
504
- parent_parent_tokens = node.parent.key[last_page_start:]
662
+ parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
505
663
  parent_block_hash = hash(tuple(parent_parent_tokens))
506
664
 
507
665
  for start in range(0, len(node.key), self.page_size):
508
- page_tokens = node.key[start : start + self.page_size]
666
+ page_tokens = node.key.token_ids[start : start + self.page_size]
509
667
  if not page_tokens:
510
668
  continue
511
669
 
@@ -528,7 +686,7 @@ class RadixCache(BasePrefixCache):
528
686
  # One BlockRemoved per chunk.
529
687
  if self.enable_kv_cache_events:
530
688
  for start in range(0, len(node.key), self.page_size):
531
- page_tokens = node.key[start : start + self.page_size]
689
+ page_tokens = node.key.token_ids[start : start + self.page_size]
532
690
  if not page_tokens:
533
691
  continue
534
692
  block_hash = hash(tuple(page_tokens))
@@ -554,19 +712,12 @@ class RadixCache(BasePrefixCache):
554
712
  if __name__ == "__main__":
555
713
  tree = RadixCache(None, None, page_size=1, disable=False)
556
714
 
557
- tree.insert("Hello")
558
- tree.insert("Hello")
559
- tree.insert("Hello_L.A.!")
560
- # tree.insert("Hello_world! Happy")
561
- # tree.insert("I love you!")
715
+ # Example token id sequences (as lists of ints)
716
+ tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
717
+ tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
718
+ tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None))
719
+ tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None))
720
+ tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
562
721
  tree.pretty_print()
563
722
 
564
- # print(tree.match_prefix("I love you! aha"))
565
-
566
- # def evict_callback(x):
567
- # print("evict", x)
568
- # return len(x)
569
-
570
- # tree.evict(5, evict_callback)
571
- # tree.evict(10, evict_callback)
572
- # tree.pretty_print()
723
+ print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
13
13
  TreeNodeCpp,
14
14
  )
15
15
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
16
+ from sglang.srt.mem_cache.radix_cache import RadixKey
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from sglang.srt.managers.schedule_batch import Req
@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
93
94
  raise NotImplementedError("Host cache is not supported yet")
94
95
  self.tree.reset()
95
96
 
96
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
97
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
97
98
  device_indices_vec, host_indices_length, node_gpu, node_cpu = (
98
- self.tree.match_prefix(key)
99
+ self.tree.match_prefix(key.token_ids)
99
100
  )
100
101
  return MatchResult(
101
102
  device_indices=self._merge_tensor(device_indices_vec),
@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
104
105
  host_hit_length=host_indices_length,
105
106
  )
106
107
 
107
- def _insert(self, key: List[int], value: torch.Tensor) -> int:
108
+ def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
108
109
  """
109
110
  Insert a key-value pair into the radix tree.
110
111
  Args:
111
- key (List[int]): The key to insert, represented as a list of integers.
112
+ key (RadixKey): The key to insert, represented as a RadixKey.
112
113
  value (torch.Tensor): The value to associate with the key.
113
114
  Returns:
114
115
  int: Number of device indices that were already present in the tree before the insertion.
115
116
  """
116
- ongoing_write, length = self.tree.writing_through(key, value)
117
+ ongoing_write, length = self.tree.writing_through(key.token_ids, value)
117
118
  if self.cache_controller is None:
118
119
  assert len(ongoing_write) == 0, "Implementation error"
119
120
  return length
@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache):
160
161
  # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
161
162
  # it will automatically align them, but length of them should be equal
162
163
  old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
163
- new_prefix_len = self._insert(token_ids, kv_indices)
164
+ new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
164
165
 
165
166
  # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
166
167
  assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache):
191
192
  # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
192
193
  # it will automatically align them, but length of them should be equal
193
194
  old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
194
- new_prefix_len = self._insert(token_ids, kv_indices)
195
+ new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
195
196
 
196
197
  # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
197
198
  assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
198
199
 
199
200
  # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
200
201
  # The prefix indices need to updated to reuse the kv indices in the pool
201
- new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
202
+ new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
203
+ RadixKey(token_ids, req.extra_key).token_ids
204
+ )
202
205
  new_indices = self._merge_tensor(new_indices_vec)
203
206
  assert new_prefix_len <= len(new_indices)
204
207
 
@@ -0,0 +1,10 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to SGLang project
3
+
4
+ """Storage backend module for SGLang HiCache."""
5
+
6
+ from .backend_factory import StorageBackendFactory
7
+
8
+ __all__ = [
9
+ "StorageBackendFactory",
10
+ ]