sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
1
1
  import heapq
2
+ import json
2
3
  import logging
3
4
  import threading
4
5
  import time
5
- from queue import Queue
6
6
  from typing import List, Optional
7
7
 
8
8
  import torch
@@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
19
19
  MHATokenToKVPoolHost,
20
20
  MLATokenToKVPoolHost,
21
21
  )
22
- from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
22
+ from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
23
23
  from sglang.srt.metrics.collector import StorageMetricsCollector
24
24
 
25
25
  logger = logging.getLogger(__name__)
@@ -39,17 +39,19 @@ class HiRadixCache(RadixCache):
39
39
  hicache_io_backend: str,
40
40
  hicache_mem_layout: str,
41
41
  enable_metrics: bool,
42
+ eviction_policy: str = "lru",
42
43
  hicache_storage_backend: Optional[str] = None,
43
44
  hicache_storage_prefetch_policy: Optional[str] = "best_effort",
44
45
  model_name: Optional[str] = None,
45
46
  storage_backend_extra_config: Optional[str] = None,
47
+ is_eagle: bool = False,
46
48
  ):
47
49
 
48
50
  if hicache_io_backend == "direct":
49
51
  if hicache_mem_layout == "page_first":
50
- hicache_mem_layout = "layer_first"
52
+ hicache_mem_layout = "page_first_direct"
51
53
  logger.warning(
52
- "Page first layout is not supported with direct IO backend, switching to layer first layout"
54
+ "Page first layout is not supported with direct IO backend, switching to page first direct layout"
53
55
  )
54
56
 
55
57
  self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
@@ -77,9 +79,21 @@ class HiRadixCache(RadixCache):
77
79
  self.enable_storage = hicache_storage_backend is not None
78
80
  self.enable_storage_metrics = self.enable_storage and enable_metrics
79
81
 
80
- # todo: customizable storage prefetch threshold and timeout
81
- self.prefetch_threshold = 256
82
- self.prefetch_timeout = 3 # seconds
82
+ (
83
+ extra_config,
84
+ prefetch_threshold,
85
+ prefetch_timeout_base,
86
+ prefetch_timeout_per_ki_token,
87
+ hicache_storage_pass_prefix_keys,
88
+ ) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
89
+ self.prefetch_threshold = prefetch_threshold
90
+ self.prefetch_timeout_base = prefetch_timeout_base
91
+ self.prefetch_timeout_per_page = (
92
+ page_size / 1024 * prefetch_timeout_per_ki_token
93
+ )
94
+ self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys
95
+ # TODO: support more timeout check functions
96
+ self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
83
97
  self.prefetch_stop_policy = hicache_storage_prefetch_policy
84
98
 
85
99
  self.load_cache_event = threading.Event()
@@ -94,7 +108,7 @@ class HiRadixCache(RadixCache):
94
108
  storage_backend=hicache_storage_backend,
95
109
  prefetch_threshold=self.prefetch_threshold,
96
110
  model_name=model_name,
97
- storage_backend_extra_config=storage_backend_extra_config,
111
+ storage_backend_extra_config=extra_config,
98
112
  )
99
113
  if self.enable_storage_metrics:
100
114
  # TODO: support pp
@@ -117,8 +131,65 @@ class HiRadixCache(RadixCache):
117
131
  1 if hicache_write_policy == "write_through" else 2
118
132
  )
119
133
  self.load_back_threshold = 10
134
+
120
135
  super().__init__(
121
- req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
136
+ req_to_token_pool,
137
+ token_to_kv_pool_allocator,
138
+ page_size,
139
+ disable=False,
140
+ eviction_policy=eviction_policy,
141
+ is_eagle=is_eagle,
142
+ )
143
+
144
+ def _parse_storage_backend_extra_config(
145
+ self, storage_backend_extra_config: Optional[str]
146
+ ):
147
+ """
148
+ Parse storage backend extra config JSON and extract specific parameters.
149
+
150
+ Args:
151
+ storage_backend_extra_config: JSON string containing extra configuration
152
+
153
+ Returns:
154
+ tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys)
155
+ """
156
+ # Parse extra config JSON if provided
157
+ extra_config = {}
158
+ if storage_backend_extra_config:
159
+ try:
160
+ extra_config = json.loads(storage_backend_extra_config)
161
+ except Exception as e:
162
+ logger.error(f"Invalid backend extra config JSON: {e}")
163
+ raise e
164
+
165
+ prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens
166
+ prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds
167
+ prefetch_timeout_per_ki_token = extra_config.pop(
168
+ "prefetch_timeout_per_ki_token", 0.25
169
+ ) # seconds per 1024 tokens
170
+ hicache_storage_pass_prefix_keys = extra_config.pop(
171
+ "hicache_storage_pass_prefix_keys", False
172
+ )
173
+
174
+ if not isinstance(prefetch_threshold, int):
175
+ raise ValueError(
176
+ f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
177
+ )
178
+ if not isinstance(prefetch_timeout_base, (int, float)):
179
+ raise ValueError(
180
+ f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
181
+ )
182
+ if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
183
+ raise ValueError(
184
+ f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
185
+ )
186
+
187
+ return (
188
+ extra_config,
189
+ prefetch_threshold,
190
+ float(prefetch_timeout_base),
191
+ float(prefetch_timeout_per_ki_token),
192
+ hicache_storage_pass_prefix_keys,
122
193
  )
123
194
 
124
195
  def reset(self):
@@ -180,8 +251,14 @@ class HiRadixCache(RadixCache):
180
251
  return len(host_indices)
181
252
 
182
253
  def write_backup_storage(self, node: TreeNode):
254
+ prefix_keys = (
255
+ node.get_prefix_hash_values(node.parent)
256
+ if self.hicache_storage_pass_prefix_keys
257
+ else None
258
+ )
259
+
183
260
  operation_id = self.cache_controller.write_storage(
184
- node.host_value, node.key, node.hash_value
261
+ node.host_value, node.key, node.hash_value, prefix_keys
185
262
  )
186
263
  self.ongoing_backup[operation_id] = node
187
264
  node.protect_host()
@@ -258,12 +335,15 @@ class HiRadixCache(RadixCache):
258
335
 
259
336
  def evict(self, num_tokens: int):
260
337
  leaves = self._collect_leaves_device()
261
- heapq.heapify(leaves)
338
+ eviction_heap = [
339
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
340
+ ]
341
+ heapq.heapify(eviction_heap)
262
342
 
263
343
  num_evicted = 0
264
344
  write_back_nodes = []
265
- while num_evicted < num_tokens and len(leaves):
266
- x = heapq.heappop(leaves)
345
+ while num_evicted < num_tokens and len(eviction_heap):
346
+ _priority, x = heapq.heappop(eviction_heap)
267
347
 
268
348
  if x.lock_ref > 0:
269
349
  continue
@@ -285,7 +365,8 @@ class HiRadixCache(RadixCache):
285
365
  break
286
366
  else:
287
367
  # all children are evicted or no children
288
- heapq.heappush(leaves, x.parent)
368
+ new_priority = self.eviction_strategy.get_priority(x.parent)
369
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
289
370
 
290
371
  if self.cache_controller.write_policy == "write_back":
291
372
  self.writing_check(write_back=True)
@@ -295,7 +376,7 @@ class HiRadixCache(RadixCache):
295
376
 
296
377
  def _evict_backuped(self, node: TreeNode):
297
378
  # evict a node already written to host
298
- num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
379
+ num_evicted = self.cache_controller.evict_device(node.value)
299
380
  assert num_evicted > 0
300
381
  self.evictable_size_ -= num_evicted
301
382
  node.value = None
@@ -310,11 +391,14 @@ class HiRadixCache(RadixCache):
310
391
 
311
392
  def evict_host(self, num_tokens: int):
312
393
  leaves = self._collect_leaves()
313
- heapq.heapify(leaves)
394
+ eviction_heap = [
395
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
396
+ ]
397
+ heapq.heapify(eviction_heap)
314
398
 
315
399
  num_evicted = 0
316
- while num_evicted < num_tokens and len(leaves):
317
- x = heapq.heappop(leaves)
400
+ while num_evicted < num_tokens and len(eviction_heap):
401
+ _priority, x = heapq.heappop(eviction_heap)
318
402
  if x == self.root_node:
319
403
  break
320
404
  # only evict the host value of evicted nodes
@@ -333,7 +417,8 @@ class HiRadixCache(RadixCache):
333
417
  del x.parent.children[k]
334
418
 
335
419
  if len(x.parent.children) == 0 and x.parent.evicted:
336
- heapq.heappush(leaves, x.parent)
420
+ new_priority = self.eviction_strategy.get_priority(x.parent)
421
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
337
422
 
338
423
  def load_back(
339
424
  self, node: TreeNode, mem_quota: Optional[int] = None
@@ -476,6 +561,15 @@ class HiRadixCache(RadixCache):
476
561
  host_indices = torch.cat(host_indices_list, dim=0)
477
562
  cc.mem_pool_host.free(host_indices)
478
563
 
564
+ # Timeout is linearly increasing with the number of pages
565
+ def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
566
+ # If hash_value has not been computed in timeout_base seconds, terminate it.
567
+ return (
568
+ time.monotonic() - operation.start_time
569
+ > self.prefetch_timeout_base
570
+ + len(operation.hash_value) * self.prefetch_timeout_per_page
571
+ )
572
+
479
573
  def can_terminate_prefetch(self, operation: PrefetchOperation):
480
574
  can_terminate = True
481
575
 
@@ -492,9 +586,7 @@ class HiRadixCache(RadixCache):
492
586
  if self.prefetch_stop_policy == "wait_complete":
493
587
  can_terminate = completed
494
588
  elif self.prefetch_stop_policy == "timeout":
495
- can_terminate = completed or (
496
- time.monotonic() - operation.start_time > self.prefetch_timeout
497
- )
589
+ can_terminate = completed or self.is_prefetch_timeout(operation)
498
590
  else:
499
591
  # unknown prefetch stop policy, just return True
500
592
  return True
@@ -556,12 +648,12 @@ class HiRadixCache(RadixCache):
556
648
  written_indices = host_indices[:min_completed_tokens]
557
649
  matched_length = self._insert_helper_host(
558
650
  last_host_node,
559
- fetched_token_ids,
651
+ RadixKey(
652
+ token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
653
+ ),
560
654
  written_indices,
561
655
  hash_value[: min_completed_tokens // self.page_size],
562
656
  )
563
- if len(written_indices):
564
- self.cache_controller.mem_pool_host.update_prefetch(written_indices)
565
657
 
566
658
  self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
567
659
  self.cache_controller.append_host_mem_release(
@@ -578,8 +670,9 @@ class HiRadixCache(RadixCache):
578
670
 
579
671
  return True
580
672
 
581
- def match_prefix(self, key: List[int], **kwargs):
673
+ def match_prefix(self, key: RadixKey, **kwargs):
582
674
  empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
675
+ key.token_ids = self.key_convert_fn(key.token_ids)
583
676
  if self.disable or len(key) == 0:
584
677
  return MatchResult(
585
678
  device_indices=empty_value,
@@ -619,6 +712,7 @@ class HiRadixCache(RadixCache):
619
712
  last_host_node: TreeNode,
620
713
  new_input_tokens: List[int],
621
714
  last_hash: Optional[str] = None,
715
+ prefix_keys: Optional[List[str]] = None,
622
716
  ):
623
717
  # align the number of fetching tokens to the page size
624
718
  prefetch_length = len(new_input_tokens) - (
@@ -642,7 +736,7 @@ class HiRadixCache(RadixCache):
642
736
  # no sufficient host memory for prefetch
643
737
  return
644
738
  operation = self.cache_controller.prefetch(
645
- req_id, host_indices, new_input_tokens, last_hash
739
+ req_id, host_indices, new_input_tokens, last_hash, prefix_keys
646
740
  )
647
741
  self.ongoing_prefetch[req_id] = (
648
742
  last_host_node,
@@ -652,7 +746,9 @@ class HiRadixCache(RadixCache):
652
746
  )
653
747
  self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
654
748
 
655
- def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
749
+ def _insert_helper_host(
750
+ self, node: TreeNode, key: RadixKey, host_value, hash_value
751
+ ):
656
752
  node.last_access_time = time.monotonic()
657
753
  if len(key) == 0:
658
754
  return 0
@@ -686,7 +782,7 @@ class HiRadixCache(RadixCache):
686
782
  node.children[child_key] = new_node
687
783
  return matched_length
688
784
 
689
- def _match_prefix_helper(self, node: TreeNode, key: List):
785
+ def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
690
786
  node.last_access_time = time.monotonic()
691
787
  child_key = self.get_child_key_fn(key)
692
788
  value = []
@@ -712,7 +808,7 @@ class HiRadixCache(RadixCache):
712
808
 
713
809
  return value, node
714
810
 
715
- def _split_node(self, key, child: TreeNode, split_len: int):
811
+ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
716
812
  # child node split into new_node -> child
717
813
  new_node = TreeNode()
718
814
  new_node.children = {self.get_child_key_fn(key[split_len:]): child}
@@ -739,10 +835,16 @@ class HiRadixCache(RadixCache):
739
835
  new_node.parent.children[self.get_child_key_fn(key)] = new_node
740
836
  return new_node
741
837
 
742
- def insert(self, key: List, value, chunked=False):
838
+ def insert(self, key: RadixKey, value=None, chunked=False):
839
+ key.token_ids = self.key_convert_fn(key.token_ids)
840
+
743
841
  if len(key) == 0:
744
842
  return 0
745
843
 
844
+ if self.is_eagle and value is not None:
845
+ # Make sure the value len equal to the EAGLE bigram key len
846
+ value = value[: len(key)]
847
+
746
848
  node = self.root_node
747
849
  child_key = self.get_child_key_fn(key)
748
850
  total_prefix_length = 0
@@ -757,7 +859,6 @@ class HiRadixCache(RadixCache):
757
859
  # change the reference if the node is evicted
758
860
  # this often happens in the case of KV cache recomputation
759
861
  node.value = value[:prefix_len]
760
- self.token_to_kv_pool_host.update_synced(node.host_value)
761
862
  self.evictable_size_ += len(node.value)
762
863
  else:
763
864
  self._inc_hit_count(node, chunked)
@@ -767,7 +868,6 @@ class HiRadixCache(RadixCache):
767
868
  new_node = self._split_node(node.key, node, prefix_len)
768
869
  if new_node.evicted:
769
870
  new_node.value = value[:prefix_len]
770
- self.token_to_kv_pool_host.update_synced(new_node.host_value)
771
871
  self.evictable_size_ += len(new_node.value)
772
872
  else:
773
873
  self._inc_hit_count(new_node, chunked)
@@ -797,7 +897,7 @@ class HiRadixCache(RadixCache):
797
897
  for idx in range(0, len(key), self.page_size):
798
898
  new_node.hash_value.append(
799
899
  self.cache_controller.get_hash_str(
800
- key[idx : idx + self.page_size],
900
+ key.token_ids[idx : idx + self.page_size],
801
901
  prior_hash=last_hash,
802
902
  )
803
903
  )