sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
1
1
  import heapq
2
+ import json
2
3
  import logging
3
4
  import threading
4
5
  import time
5
- from queue import Queue
6
6
  from typing import List, Optional
7
7
 
8
8
  import torch
@@ -19,7 +19,8 @@ from sglang.srt.mem_cache.memory_pool_host import (
19
19
  MHATokenToKVPoolHost,
20
20
  MLATokenToKVPoolHost,
21
21
  )
22
- from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
22
+ from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
23
+ from sglang.srt.metrics.collector import StorageMetricsCollector
23
24
 
24
25
  logger = logging.getLogger(__name__)
25
26
 
@@ -37,17 +38,20 @@ class HiRadixCache(RadixCache):
37
38
  hicache_write_policy: str,
38
39
  hicache_io_backend: str,
39
40
  hicache_mem_layout: str,
41
+ enable_metrics: bool,
42
+ eviction_policy: str = "lru",
40
43
  hicache_storage_backend: Optional[str] = None,
41
44
  hicache_storage_prefetch_policy: Optional[str] = "best_effort",
42
45
  model_name: Optional[str] = None,
43
46
  storage_backend_extra_config: Optional[str] = None,
47
+ is_eagle: bool = False,
44
48
  ):
45
49
 
46
50
  if hicache_io_backend == "direct":
47
51
  if hicache_mem_layout == "page_first":
48
- hicache_mem_layout = "layer_first"
52
+ hicache_mem_layout = "page_first_direct"
49
53
  logger.warning(
50
- "Page first layout is not supported with direct IO backend, switching to layer first layout"
54
+ "Page first layout is not supported with direct IO backend, switching to page first direct layout"
51
55
  )
52
56
 
53
57
  self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
@@ -73,9 +77,21 @@ class HiRadixCache(RadixCache):
73
77
  self.tp_group = tp_cache_group
74
78
  self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
75
79
  self.enable_storage = hicache_storage_backend is not None
76
- # todo: customizable storage prefetch threshold and timeout
77
- self.prefetch_threshold = 256
78
- self.prefetch_timeout = 3 # seconds
80
+ self.enable_storage_metrics = self.enable_storage and enable_metrics
81
+
82
+ (
83
+ extra_config,
84
+ prefetch_threshold,
85
+ prefetch_timeout_base,
86
+ prefetch_timeout_per_ki_token,
87
+ ) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
88
+ self.prefetch_threshold = prefetch_threshold
89
+ self.prefetch_timeout_base = prefetch_timeout_base
90
+ self.prefetch_timeout_per_page = (
91
+ page_size / 1024 * prefetch_timeout_per_ki_token
92
+ )
93
+ # TODO: support more timeout check functions
94
+ self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
79
95
  self.prefetch_stop_policy = hicache_storage_prefetch_policy
80
96
 
81
97
  self.load_cache_event = threading.Event()
@@ -90,8 +106,16 @@ class HiRadixCache(RadixCache):
90
106
  storage_backend=hicache_storage_backend,
91
107
  prefetch_threshold=self.prefetch_threshold,
92
108
  model_name=model_name,
93
- storage_backend_extra_config=storage_backend_extra_config,
109
+ storage_backend_extra_config=extra_config,
94
110
  )
111
+ if self.enable_storage_metrics:
112
+ # TODO: support pp
113
+ labels = {
114
+ "storage_backend": hicache_storage_backend,
115
+ "tp_rank": self.cache_controller.tp_rank,
116
+ "dp_rank": self.cache_controller.dp_rank,
117
+ }
118
+ self.metrics_collector = StorageMetricsCollector(labels=labels)
95
119
 
96
120
  # record the nodes with ongoing write through
97
121
  self.ongoing_write_through = {}
@@ -105,8 +129,61 @@ class HiRadixCache(RadixCache):
105
129
  1 if hicache_write_policy == "write_through" else 2
106
130
  )
107
131
  self.load_back_threshold = 10
132
+
108
133
  super().__init__(
109
- req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
134
+ req_to_token_pool,
135
+ token_to_kv_pool_allocator,
136
+ page_size,
137
+ disable=False,
138
+ eviction_policy=eviction_policy,
139
+ is_eagle=is_eagle,
140
+ )
141
+
142
+ def _parse_storage_backend_extra_config(
143
+ self, storage_backend_extra_config: Optional[str]
144
+ ):
145
+ """
146
+ Parse storage backend extra config JSON and extract specific parameters.
147
+
148
+ Args:
149
+ storage_backend_extra_config: JSON string containing extra configuration
150
+
151
+ Returns:
152
+ tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
153
+ """
154
+ # Parse extra config JSON if provided
155
+ extra_config = {}
156
+ if storage_backend_extra_config:
157
+ try:
158
+ extra_config = json.loads(storage_backend_extra_config)
159
+ except Exception as e:
160
+ logger.error(f"Invalid backend extra config JSON: {e}")
161
+ raise e
162
+
163
+ prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens
164
+ prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds
165
+ prefetch_timeout_per_ki_token = extra_config.pop(
166
+ "prefetch_timeout_per_ki_token", 0.25
167
+ ) # seconds per 1024 tokens
168
+
169
+ if not isinstance(prefetch_threshold, int):
170
+ raise ValueError(
171
+ f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
172
+ )
173
+ if not isinstance(prefetch_timeout_base, (int, float)):
174
+ raise ValueError(
175
+ f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
176
+ )
177
+ if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
178
+ raise ValueError(
179
+ f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
180
+ )
181
+
182
+ return (
183
+ extra_config,
184
+ prefetch_threshold,
185
+ float(prefetch_timeout_base),
186
+ float(prefetch_timeout_per_ki_token),
110
187
  )
111
188
 
112
189
  def reset(self):
@@ -122,11 +199,24 @@ class HiRadixCache(RadixCache):
122
199
  height += 1
123
200
  return height
124
201
 
125
- def clear_storage_backend(self):
202
+ def clear_storage_backend(self) -> bool:
126
203
  if self.enable_storage:
127
- self.cache_controller.storage_backend.clear()
128
- logger.info("Hierarchical cache storage backend cleared successfully!")
129
- return True
204
+ try:
205
+ # Check if the storage backend has a clear method (for nixl backends)
206
+ if hasattr(self.cache_controller.storage_backend, "clear"):
207
+ self.cache_controller.storage_backend.clear()
208
+ logger.info(
209
+ "Hierarchical cache storage backend cleared successfully!"
210
+ )
211
+ return True
212
+ else:
213
+ logger.warning(
214
+ f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
215
+ )
216
+ return False
217
+ except Exception as e:
218
+ logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
219
+ return False
130
220
  else:
131
221
  logger.warning("Hierarchical cache storage backend is not enabled.")
132
222
  return False
@@ -176,53 +266,72 @@ class HiRadixCache(RadixCache):
176
266
  if write_back:
177
267
  # blocking till all write back complete
178
268
  while len(self.ongoing_write_through) > 0:
179
- ack_id = self.cache_controller.ack_write_queue.get()
180
- del self.ongoing_write_through[ack_id]
269
+ for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
270
+ finish_event.synchronize()
271
+ for ack_id in ack_list:
272
+ del self.ongoing_write_through[ack_id]
273
+ self.cache_controller.ack_write_queue.clear()
274
+ assert len(self.ongoing_write_through) == 0
181
275
  return
182
- queue_size = torch.tensor(
183
- self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
184
- )
276
+
277
+ # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
278
+ if len(self.ongoing_write_through) == 0:
279
+ return
280
+
281
+ finish_count = 0
282
+ for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
283
+ if not finish_event.query():
284
+ break
285
+ finish_count += 1
286
+ queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
185
287
  if self.tp_world_size > 1:
186
- # synchrnoize TP workers to make the same update to radix cache
288
+ # synchronize TP workers to make the same update to radix cache
187
289
  torch.distributed.all_reduce(
188
290
  queue_size,
189
291
  op=torch.distributed.ReduceOp.MIN,
190
292
  group=self.tp_group,
191
293
  )
192
- for _ in range(queue_size.item()):
193
- ack_id = self.cache_controller.ack_write_queue.get()
194
- backuped_node = self.ongoing_write_through[ack_id]
195
- self.dec_lock_ref(backuped_node)
196
- del self.ongoing_write_through[ack_id]
197
- if self.enable_storage:
198
- self.write_backup_storage(backuped_node)
294
+
295
+ finish_count = int(queue_size.item())
296
+ while finish_count > 0:
297
+ _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
298
+ finish_event.synchronize()
299
+ for ack_id in ack_list:
300
+ backuped_node = self.ongoing_write_through.pop(ack_id)
301
+ self.dec_lock_ref(backuped_node)
302
+ if self.enable_storage:
303
+ self.write_backup_storage(backuped_node)
304
+ finish_count -= 1
199
305
 
200
306
  def loading_check(self):
201
- while not self.cache_controller.ack_load_queue.empty():
202
- try:
203
- ack_id = self.cache_controller.ack_load_queue.get_nowait()
204
- start_node, end_node = self.ongoing_load_back[ack_id]
205
- self.dec_lock_ref(end_node)
206
- while end_node != start_node:
207
- assert end_node.loading
208
- end_node.loading = False
209
- end_node = end_node.parent
210
- # clear the reference
211
- del self.ongoing_load_back[ack_id]
212
- except Exception:
307
+ finish_count = 0
308
+ for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
309
+ if not finish_event.query():
310
+ # the KV cache loading is still ongoing
213
311
  break
312
+ finish_count += 1
313
+ # no need to sync across TP workers as batch forwarding is synced
314
+ for ack_id in ack_list:
315
+ end_node = self.ongoing_load_back.pop(ack_id)
316
+ self.dec_lock_ref(end_node)
317
+
318
+ # ACK until all events are processed
319
+ del self.cache_controller.ack_load_queue[:finish_count]
214
320
 
215
321
  def evictable_size(self):
216
322
  return self.evictable_size_
217
323
 
218
324
  def evict(self, num_tokens: int):
219
325
  leaves = self._collect_leaves_device()
220
- heapq.heapify(leaves)
326
+ eviction_heap = [
327
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
328
+ ]
329
+ heapq.heapify(eviction_heap)
221
330
 
222
331
  num_evicted = 0
223
332
  write_back_nodes = []
224
- while num_evicted < num_tokens and len(leaves):
225
- x = heapq.heappop(leaves)
333
+ while num_evicted < num_tokens and len(eviction_heap):
334
+ _priority, x = heapq.heappop(eviction_heap)
226
335
 
227
336
  if x.lock_ref > 0:
228
337
  continue
@@ -244,7 +353,8 @@ class HiRadixCache(RadixCache):
244
353
  break
245
354
  else:
246
355
  # all children are evicted or no children
247
- heapq.heappush(leaves, x.parent)
356
+ new_priority = self.eviction_strategy.get_priority(x.parent)
357
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
248
358
 
249
359
  if self.cache_controller.write_policy == "write_back":
250
360
  self.writing_check(write_back=True)
@@ -254,7 +364,7 @@ class HiRadixCache(RadixCache):
254
364
 
255
365
  def _evict_backuped(self, node: TreeNode):
256
366
  # evict a node already written to host
257
- num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
367
+ num_evicted = self.cache_controller.evict_device(node.value)
258
368
  assert num_evicted > 0
259
369
  self.evictable_size_ -= num_evicted
260
370
  node.value = None
@@ -269,11 +379,14 @@ class HiRadixCache(RadixCache):
269
379
 
270
380
  def evict_host(self, num_tokens: int):
271
381
  leaves = self._collect_leaves()
272
- heapq.heapify(leaves)
382
+ eviction_heap = [
383
+ (self.eviction_strategy.get_priority(node), node) for node in leaves
384
+ ]
385
+ heapq.heapify(eviction_heap)
273
386
 
274
387
  num_evicted = 0
275
- while num_evicted < num_tokens and len(leaves):
276
- x = heapq.heappop(leaves)
388
+ while num_evicted < num_tokens and len(eviction_heap):
389
+ _priority, x = heapq.heappop(eviction_heap)
277
390
  if x == self.root_node:
278
391
  break
279
392
  # only evict the host value of evicted nodes
@@ -292,7 +405,8 @@ class HiRadixCache(RadixCache):
292
405
  del x.parent.children[k]
293
406
 
294
407
  if len(x.parent.children) == 0 and x.parent.evicted:
295
- heapq.heappush(leaves, x.parent)
408
+ new_priority = self.eviction_strategy.get_priority(x.parent)
409
+ heapq.heappush(eviction_heap, (new_priority, x.parent))
296
410
 
297
411
  def load_back(
298
412
  self, node: TreeNode, mem_quota: Optional[int] = None
@@ -335,12 +449,11 @@ class HiRadixCache(RadixCache):
335
449
  # no sufficient GPU memory to load back KV caches
336
450
  return None
337
451
 
338
- self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
452
+ self.ongoing_load_back[last_hit_node.id] = last_hit_node
339
453
  offset = 0
340
454
  for node in nodes_to_load:
341
455
  node.value = device_indices[offset : offset + len(node.host_value)]
342
456
  offset += len(node.host_value)
343
- node.loading = True
344
457
  self.evictable_size_ += len(device_indices)
345
458
  self.inc_lock_ref(last_hit_node)
346
459
 
@@ -369,16 +482,22 @@ class HiRadixCache(RadixCache):
369
482
  last_node,
370
483
  )
371
484
 
372
- def ready_to_load_host_cache(self):
373
- producer_index = self.cache_controller.layer_done_counter.next_producer()
374
- self.load_cache_event.set()
375
- return producer_index
485
+ def ready_to_load_host_cache(self) -> int:
486
+ """
487
+ Notify the cache controller to start the KV cache loading.
488
+ Return the consumer index for the schedule batch manager to track.
489
+ """
490
+ return self.cache_controller.start_loading()
376
491
 
377
492
  def check_hicache_events(self):
378
493
  self.writing_check()
379
494
  self.loading_check()
380
495
  if self.enable_storage:
381
496
  self.drain_storage_control_queues()
497
+ if self.enable_storage_metrics:
498
+ self.metrics_collector.log_storage_metrics(
499
+ self.cache_controller.storage_backend.get_stats()
500
+ )
382
501
 
383
502
  def drain_storage_control_queues(self):
384
503
  """
@@ -414,10 +533,13 @@ class HiRadixCache(RadixCache):
414
533
 
415
534
  # process backup acks
416
535
  for _ in range(n_backup):
417
- ack_id = cc.ack_backup_queue.get()
536
+ operation = cc.ack_backup_queue.get()
537
+ ack_id = operation.id
418
538
  entry = self.ongoing_backup.pop(ack_id, None)
419
539
  if entry is not None:
420
540
  entry.release_host()
541
+ if self.enable_storage_metrics:
542
+ self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
421
543
 
422
544
  # release host memory
423
545
  host_indices_list = []
@@ -427,6 +549,15 @@ class HiRadixCache(RadixCache):
427
549
  host_indices = torch.cat(host_indices_list, dim=0)
428
550
  cc.mem_pool_host.free(host_indices)
429
551
 
552
+ # Timeout is linearly increasing with the number of pages
553
+ def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
554
+ # If hash_value has not been computed in timeout_base seconds, terminate it.
555
+ return (
556
+ time.monotonic() - operation.start_time
557
+ > self.prefetch_timeout_base
558
+ + len(operation.hash_value) * self.prefetch_timeout_per_page
559
+ )
560
+
430
561
  def can_terminate_prefetch(self, operation: PrefetchOperation):
431
562
  can_terminate = True
432
563
 
@@ -443,22 +574,27 @@ class HiRadixCache(RadixCache):
443
574
  if self.prefetch_stop_policy == "wait_complete":
444
575
  can_terminate = completed
445
576
  elif self.prefetch_stop_policy == "timeout":
446
- can_terminate = completed or (
447
- time.monotonic() - operation.start_time > self.prefetch_timeout
448
- )
577
+ can_terminate = completed or self.is_prefetch_timeout(operation)
449
578
  else:
450
579
  # unknown prefetch stop policy, just return True
451
580
  return True
452
581
 
582
+ operation_terminated = operation.is_terminated()
453
583
  if self.tp_world_size > 1:
454
- can_terminate = torch.tensor(can_terminate, dtype=torch.int)
584
+ states = torch.tensor(
585
+ [1 - int(can_terminate), int(operation_terminated)],
586
+ dtype=torch.int,
587
+ )
455
588
  torch.distributed.all_reduce(
456
- can_terminate,
457
- op=torch.distributed.ReduceOp.MIN,
589
+ states,
590
+ op=torch.distributed.ReduceOp.MAX,
458
591
  group=self.tp_group,
459
592
  )
460
- can_terminate = bool(can_terminate.item())
461
-
593
+ can_terminate = states[0].item() == 0
594
+ operation_terminated = states[1].item() == 1
595
+ # the operation should be terminated if it is already terminated on any TP worker
596
+ # or it meets the termination condition on all TP workers
597
+ can_terminate = can_terminate or operation_terminated
462
598
  return can_terminate
463
599
 
464
600
  def check_prefetch_progress(self, req_id: str) -> bool:
@@ -485,7 +621,7 @@ class HiRadixCache(RadixCache):
485
621
  logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
486
622
 
487
623
  min_completed_tokens = completed_tokens
488
- if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
624
+ if self.tp_world_size > 1:
489
625
  # synchronize TP workers to make the same update to hiradix cache
490
626
  completed_tokens_tensor = torch.tensor(
491
627
  min_completed_tokens, dtype=torch.int
@@ -500,12 +636,12 @@ class HiRadixCache(RadixCache):
500
636
  written_indices = host_indices[:min_completed_tokens]
501
637
  matched_length = self._insert_helper_host(
502
638
  last_host_node,
503
- fetched_token_ids,
639
+ RadixKey(
640
+ token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
641
+ ),
504
642
  written_indices,
505
643
  hash_value[: min_completed_tokens // self.page_size],
506
644
  )
507
- if len(written_indices):
508
- self.cache_controller.mem_pool_host.update_prefetch(written_indices)
509
645
 
510
646
  self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
511
647
  self.cache_controller.append_host_mem_release(
@@ -515,10 +651,16 @@ class HiRadixCache(RadixCache):
515
651
  del self.ongoing_prefetch[req_id]
516
652
  self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
517
653
 
654
+ if self.enable_storage_metrics:
655
+ self.metrics_collector.log_prefetched_tokens(
656
+ min_completed_tokens - matched_length
657
+ )
658
+
518
659
  return True
519
660
 
520
- def match_prefix(self, key: List[int], **kwargs):
661
+ def match_prefix(self, key: RadixKey, **kwargs):
521
662
  empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
663
+ key.token_ids = self.key_convert_fn(key.token_ids)
522
664
  if self.disable or len(key) == 0:
523
665
  return MatchResult(
524
666
  device_indices=empty_value,
@@ -591,7 +733,9 @@ class HiRadixCache(RadixCache):
591
733
  )
592
734
  self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
593
735
 
594
- def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
736
+ def _insert_helper_host(
737
+ self, node: TreeNode, key: RadixKey, host_value, hash_value
738
+ ):
595
739
  node.last_access_time = time.monotonic()
596
740
  if len(key) == 0:
597
741
  return 0
@@ -625,7 +769,7 @@ class HiRadixCache(RadixCache):
625
769
  node.children[child_key] = new_node
626
770
  return matched_length
627
771
 
628
- def _match_prefix_helper(self, node: TreeNode, key: List):
772
+ def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
629
773
  node.last_access_time = time.monotonic()
630
774
  child_key = self.get_child_key_fn(key)
631
775
  value = []
@@ -651,14 +795,13 @@ class HiRadixCache(RadixCache):
651
795
 
652
796
  return value, node
653
797
 
654
- def _split_node(self, key, child: TreeNode, split_len: int):
798
+ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
655
799
  # child node split into new_node -> child
656
800
  new_node = TreeNode()
657
801
  new_node.children = {self.get_child_key_fn(key[split_len:]): child}
658
802
  new_node.parent = child.parent
659
803
  new_node.lock_ref = child.lock_ref
660
804
  new_node.key = child.key[:split_len]
661
- new_node.loading = child.loading
662
805
  new_node.hit_count = child.hit_count
663
806
 
664
807
  # split value and host value if exists
@@ -679,10 +822,16 @@ class HiRadixCache(RadixCache):
679
822
  new_node.parent.children[self.get_child_key_fn(key)] = new_node
680
823
  return new_node
681
824
 
682
- def insert(self, key: List, value, chunked=False):
825
+ def insert(self, key: RadixKey, value=None, chunked=False):
826
+ key.token_ids = self.key_convert_fn(key.token_ids)
827
+
683
828
  if len(key) == 0:
684
829
  return 0
685
830
 
831
+ if self.is_eagle and value is not None:
832
+ # Make sure the value len equal to the EAGLE bigram key len
833
+ value = value[: len(key)]
834
+
686
835
  node = self.root_node
687
836
  child_key = self.get_child_key_fn(key)
688
837
  total_prefix_length = 0
@@ -697,7 +846,6 @@ class HiRadixCache(RadixCache):
697
846
  # change the reference if the node is evicted
698
847
  # this often happens in the case of KV cache recomputation
699
848
  node.value = value[:prefix_len]
700
- self.token_to_kv_pool_host.update_synced(node.host_value)
701
849
  self.evictable_size_ += len(node.value)
702
850
  else:
703
851
  self._inc_hit_count(node, chunked)
@@ -707,7 +855,6 @@ class HiRadixCache(RadixCache):
707
855
  new_node = self._split_node(node.key, node, prefix_len)
708
856
  if new_node.evicted:
709
857
  new_node.value = value[:prefix_len]
710
- self.token_to_kv_pool_host.update_synced(new_node.host_value)
711
858
  self.evictable_size_ += len(new_node.value)
712
859
  else:
713
860
  self._inc_hit_count(new_node, chunked)
@@ -737,7 +884,7 @@ class HiRadixCache(RadixCache):
737
884
  for idx in range(0, len(key), self.page_size):
738
885
  new_node.hash_value.append(
739
886
  self.cache_controller.get_hash_str(
740
- key[idx : idx + self.page_size],
887
+ key.token_ids[idx : idx + self.page_size],
741
888
  prior_hash=last_hash,
742
889
  )
743
890
  )