sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256) hide show
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
20
20
  MLATokenToKVPoolHost,
21
21
  )
22
22
  from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
23
+ from sglang.srt.metrics.collector import StorageMetricsCollector
23
24
 
24
25
  logger = logging.getLogger(__name__)
25
26
 
@@ -37,8 +38,11 @@ class HiRadixCache(RadixCache):
37
38
  hicache_write_policy: str,
38
39
  hicache_io_backend: str,
39
40
  hicache_mem_layout: str,
41
+ enable_metrics: bool,
40
42
  hicache_storage_backend: Optional[str] = None,
41
43
  hicache_storage_prefetch_policy: Optional[str] = "best_effort",
44
+ model_name: Optional[str] = None,
45
+ storage_backend_extra_config: Optional[str] = None,
42
46
  ):
43
47
 
44
48
  if hicache_io_backend == "direct":
@@ -71,6 +75,8 @@ class HiRadixCache(RadixCache):
71
75
  self.tp_group = tp_cache_group
72
76
  self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
73
77
  self.enable_storage = hicache_storage_backend is not None
78
+ self.enable_storage_metrics = self.enable_storage and enable_metrics
79
+
74
80
  # todo: customizable storage prefetch threshold and timeout
75
81
  self.prefetch_threshold = 256
76
82
  self.prefetch_timeout = 3 # seconds
@@ -87,7 +93,17 @@ class HiRadixCache(RadixCache):
87
93
  io_backend=hicache_io_backend,
88
94
  storage_backend=hicache_storage_backend,
89
95
  prefetch_threshold=self.prefetch_threshold,
96
+ model_name=model_name,
97
+ storage_backend_extra_config=storage_backend_extra_config,
90
98
  )
99
+ if self.enable_storage_metrics:
100
+ # TODO: support pp
101
+ labels = {
102
+ "storage_backend": hicache_storage_backend,
103
+ "tp_rank": self.cache_controller.tp_rank,
104
+ "dp_rank": self.cache_controller.dp_rank,
105
+ }
106
+ self.metrics_collector = StorageMetricsCollector(labels=labels)
91
107
 
92
108
  # record the nodes with ongoing write through
93
109
  self.ongoing_write_through = {}
@@ -98,10 +114,7 @@ class HiRadixCache(RadixCache):
98
114
  self.ongoing_backup = {}
99
115
  # todo: dynamically adjust the threshold
100
116
  self.write_through_threshold = (
101
- 1 if hicache_write_policy == "write_through" else 3
102
- )
103
- self.write_through_threshold_storage = (
104
- 1 if hicache_write_policy == "write_through" else 3
117
+ 1 if hicache_write_policy == "write_through" else 2
105
118
  )
106
119
  self.load_back_threshold = 10
107
120
  super().__init__(
@@ -121,6 +134,28 @@ class HiRadixCache(RadixCache):
121
134
  height += 1
122
135
  return height
123
136
 
137
+ def clear_storage_backend(self) -> bool:
138
+ if self.enable_storage:
139
+ try:
140
+ # Check if the storage backend has a clear method (for nixl backends)
141
+ if hasattr(self.cache_controller.storage_backend, "clear"):
142
+ self.cache_controller.storage_backend.clear()
143
+ logger.info(
144
+ "Hierarchical cache storage backend cleared successfully!"
145
+ )
146
+ return True
147
+ else:
148
+ logger.warning(
149
+ f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
150
+ )
151
+ return False
152
+ except Exception as e:
153
+ logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
154
+ return False
155
+ else:
156
+ logger.warning("Hierarchical cache storage backend is not enabled.")
157
+ return False
158
+
124
159
  def write_backup(self, node: TreeNode, write_back=False):
125
160
  host_indices = self.cache_controller.write(
126
161
  device_indices=node.value,
@@ -151,8 +186,9 @@ class HiRadixCache(RadixCache):
151
186
  self.ongoing_backup[operation_id] = node
152
187
  node.protect_host()
153
188
 
154
- def inc_hit_count(self, node: TreeNode):
155
- if self.cache_controller.write_policy == "write_back":
189
+ def _inc_hit_count(self, node: TreeNode, chunked=False):
190
+ # skip the hit count update for chunked requests
191
+ if self.cache_controller.write_policy == "write_back" or chunked:
156
192
  return
157
193
  node.hit_count += 1
158
194
 
@@ -160,51 +196,62 @@ class HiRadixCache(RadixCache):
160
196
  if node.hit_count >= self.write_through_threshold:
161
197
  # write to host if the node is not backuped
162
198
  self.write_backup(node)
163
- else:
164
- if (
165
- self.enable_storage
166
- and (not node.backuped_storage)
167
- and node.hit_count >= self.write_through_threshold_storage
168
- ):
169
- # if the node is backuped on host memory but not on storage
170
- self.write_backup_storage(node)
171
199
 
172
200
  def writing_check(self, write_back=False):
173
201
  if write_back:
174
202
  # blocking till all write back complete
175
203
  while len(self.ongoing_write_through) > 0:
176
- ack_id = self.cache_controller.ack_write_queue.get()
177
- del self.ongoing_write_through[ack_id]
204
+ for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
205
+ finish_event.synchronize()
206
+ for ack_id in ack_list:
207
+ del self.ongoing_write_through[ack_id]
208
+ self.cache_controller.ack_write_queue.clear()
209
+ assert len(self.ongoing_write_through) == 0
178
210
  return
179
- queue_size = torch.tensor(
180
- self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
181
- )
211
+
212
+ # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
213
+ if len(self.ongoing_write_through) == 0:
214
+ return
215
+
216
+ finish_count = 0
217
+ for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
218
+ if not finish_event.query():
219
+ break
220
+ finish_count += 1
221
+ queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
182
222
  if self.tp_world_size > 1:
183
- # synchrnoize TP workers to make the same update to radix cache
223
+ # synchronize TP workers to make the same update to radix cache
184
224
  torch.distributed.all_reduce(
185
225
  queue_size,
186
226
  op=torch.distributed.ReduceOp.MIN,
187
227
  group=self.tp_group,
188
228
  )
189
- for _ in range(queue_size.item()):
190
- ack_id = self.cache_controller.ack_write_queue.get()
191
- self.dec_lock_ref(self.ongoing_write_through[ack_id])
192
- del self.ongoing_write_through[ack_id]
229
+
230
+ finish_count = int(queue_size.item())
231
+ while finish_count > 0:
232
+ _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
233
+ finish_event.synchronize()
234
+ for ack_id in ack_list:
235
+ backuped_node = self.ongoing_write_through.pop(ack_id)
236
+ self.dec_lock_ref(backuped_node)
237
+ if self.enable_storage:
238
+ self.write_backup_storage(backuped_node)
239
+ finish_count -= 1
193
240
 
194
241
  def loading_check(self):
195
- while not self.cache_controller.ack_load_queue.empty():
196
- try:
197
- ack_id = self.cache_controller.ack_load_queue.get_nowait()
198
- start_node, end_node = self.ongoing_load_back[ack_id]
199
- self.dec_lock_ref(end_node)
200
- while end_node != start_node:
201
- assert end_node.loading
202
- end_node.loading = False
203
- end_node = end_node.parent
204
- # clear the reference
205
- del self.ongoing_load_back[ack_id]
206
- except Exception:
242
+ finish_count = 0
243
+ for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
244
+ if not finish_event.query():
245
+ # the KV cache loading is still ongoing
207
246
  break
247
+ finish_count += 1
248
+ # no need to sync across TP workers as batch forwarding is synced
249
+ for ack_id in ack_list:
250
+ end_node = self.ongoing_load_back.pop(ack_id)
251
+ self.dec_lock_ref(end_node)
252
+
253
+ # ACK until all events are processed
254
+ del self.cache_controller.ack_load_queue[:finish_count]
208
255
 
209
256
  def evictable_size(self):
210
257
  return self.evictable_size_
@@ -329,12 +376,11 @@ class HiRadixCache(RadixCache):
329
376
  # no sufficient GPU memory to load back KV caches
330
377
  return None
331
378
 
332
- self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
379
+ self.ongoing_load_back[last_hit_node.id] = last_hit_node
333
380
  offset = 0
334
381
  for node in nodes_to_load:
335
382
  node.value = device_indices[offset : offset + len(node.host_value)]
336
383
  offset += len(node.host_value)
337
- node.loading = True
338
384
  self.evictable_size_ += len(device_indices)
339
385
  self.inc_lock_ref(last_hit_node)
340
386
 
@@ -363,66 +409,72 @@ class HiRadixCache(RadixCache):
363
409
  last_node,
364
410
  )
365
411
 
366
- def ready_to_load_host_cache(self):
367
- producer_index = self.cache_controller.layer_done_counter.next_producer()
368
- self.load_cache_event.set()
369
- return producer_index
412
+ def ready_to_load_host_cache(self) -> int:
413
+ """
414
+ Notify the cache controller to start the KV cache loading.
415
+ Return the consumer index for the schedule batch manager to track.
416
+ """
417
+ return self.cache_controller.start_loading()
370
418
 
371
419
  def check_hicache_events(self):
372
420
  self.writing_check()
373
421
  self.loading_check()
374
422
  if self.enable_storage:
375
- self.check_revoked_prefetch()
376
- self.check_backup_progress()
377
-
378
- def check_revoked_prefetch(self):
379
- queue_size = torch.tensor(
380
- self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
381
- )
382
- if self.tp_world_size > 1:
383
- # synchrnoize TP workers to make the same update to hiradix cache
384
- torch.distributed.all_reduce(
385
- queue_size,
386
- op=torch.distributed.ReduceOp.MIN,
387
- group=self.tp_group,
423
+ self.drain_storage_control_queues()
424
+ if self.enable_storage_metrics:
425
+ self.metrics_collector.log_storage_metrics(
426
+ self.cache_controller.storage_backend.get_stats()
388
427
  )
389
- for _ in range(queue_size.item()):
390
- req_id = self.cache_controller.prefetch_revoke_queue.get()
391
- if req_id in self.ongoing_prefetch:
392
- last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
393
- last_host_node.release_host()
394
- del self.ongoing_prefetch[req_id]
395
- self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
396
- else:
397
- # the revoked operation already got terminated
398
- pass
399
428
 
400
- def check_backup_progress(self):
401
- queue_size = torch.tensor(
402
- self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
429
+ def drain_storage_control_queues(self):
430
+ """
431
+ Combine prefetch revoke, backup ack, and host mem release checks
432
+ to minimize TP synchronization and Python overhead.
433
+ """
434
+ cc = self.cache_controller
435
+
436
+ qsizes = torch.tensor(
437
+ [
438
+ cc.prefetch_revoke_queue.qsize(),
439
+ cc.ack_backup_queue.qsize(),
440
+ cc.host_mem_release_queue.qsize(),
441
+ ],
442
+ dtype=torch.int,
403
443
  )
404
444
  if self.tp_world_size > 1:
405
- # synchrnoize TP workers to make the same update to hiradix cache
406
445
  torch.distributed.all_reduce(
407
- queue_size,
408
- op=torch.distributed.ReduceOp.MIN,
409
- group=self.tp_group,
446
+ qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
410
447
  )
411
- for _ in range(queue_size.item()):
412
- ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
413
- host_node = self.ongoing_backup[ack_id]
414
-
415
- if completed_tokens > 0:
416
- if completed_tokens < len(host_node.key):
417
- # backup is only partially successful, split the node
418
- new_node = self._split_node(
419
- host_node.key, host_node, completed_tokens
420
- )
421
- new_node.backuped_storage = True
422
- else:
423
- host_node.backuped_storage = True
424
- host_node.release_host()
425
- del self.ongoing_backup[ack_id]
448
+
449
+ n_revoke, n_backup, n_release = map(int, qsizes.tolist())
450
+
451
+ # process prefetch revokes
452
+ for _ in range(n_revoke):
453
+ req_id = cc.prefetch_revoke_queue.get()
454
+ info = self.ongoing_prefetch.pop(req_id, None)
455
+ if info is not None:
456
+ last_host_node, token_ids, _, _ = info
457
+ last_host_node.release_host()
458
+ cc.prefetch_tokens_occupied -= len(token_ids)
459
+ # else: the revoked operation already got terminated, nothing to do
460
+
461
+ # process backup acks
462
+ for _ in range(n_backup):
463
+ operation = cc.ack_backup_queue.get()
464
+ ack_id = operation.id
465
+ entry = self.ongoing_backup.pop(ack_id, None)
466
+ if entry is not None:
467
+ entry.release_host()
468
+ if self.enable_storage_metrics:
469
+ self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
470
+
471
+ # release host memory
472
+ host_indices_list = []
473
+ for _ in range(n_release):
474
+ host_indices_list.append(cc.host_mem_release_queue.get())
475
+ if host_indices_list:
476
+ host_indices = torch.cat(host_indices_list, dim=0)
477
+ cc.mem_pool_host.free(host_indices)
426
478
 
427
479
  def can_terminate_prefetch(self, operation: PrefetchOperation):
428
480
  can_terminate = True
@@ -430,9 +482,12 @@ class HiRadixCache(RadixCache):
430
482
  if self.prefetch_stop_policy == "best_effort":
431
483
  return can_terminate
432
484
 
433
- completed = (
434
- operation.completed_tokens == len(operation.hash_value) * self.page_size
435
- )
485
+ if len(operation.hash_value) == 0:
486
+ completed = False
487
+ else:
488
+ completed = (
489
+ operation.completed_tokens == len(operation.hash_value) * self.page_size
490
+ )
436
491
 
437
492
  if self.prefetch_stop_policy == "wait_complete":
438
493
  can_terminate = completed
@@ -444,15 +499,22 @@ class HiRadixCache(RadixCache):
444
499
  # unknown prefetch stop policy, just return True
445
500
  return True
446
501
 
502
+ operation_terminated = operation.is_terminated()
447
503
  if self.tp_world_size > 1:
448
- can_terminate = torch.tensor(can_terminate, dtype=torch.int)
504
+ states = torch.tensor(
505
+ [1 - int(can_terminate), int(operation_terminated)],
506
+ dtype=torch.int,
507
+ )
449
508
  torch.distributed.all_reduce(
450
- can_terminate,
451
- op=torch.distributed.ReduceOp.MIN,
509
+ states,
510
+ op=torch.distributed.ReduceOp.MAX,
452
511
  group=self.tp_group,
453
512
  )
454
- can_terminate = bool(can_terminate.item())
455
-
513
+ can_terminate = states[0].item() == 0
514
+ operation_terminated = states[1].item() == 1
515
+ # the operation should be terminated if it is already terminated on any TP worker
516
+ # or it meets the termination condition on all TP workers
517
+ can_terminate = can_terminate or operation_terminated
456
518
  return can_terminate
457
519
 
458
520
  def check_prefetch_progress(self, req_id: str) -> bool:
@@ -479,7 +541,7 @@ class HiRadixCache(RadixCache):
479
541
  logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
480
542
 
481
543
  min_completed_tokens = completed_tokens
482
- if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
544
+ if self.tp_world_size > 1:
483
545
  # synchrnoize TP workers to make the same update to hiradix cache
484
546
  completed_tokens_tensor = torch.tensor(
485
547
  min_completed_tokens, dtype=torch.int
@@ -502,13 +564,18 @@ class HiRadixCache(RadixCache):
502
564
  self.cache_controller.mem_pool_host.update_prefetch(written_indices)
503
565
 
504
566
  self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
505
- self.cache_controller.mem_pool_host.free(
567
+ self.cache_controller.append_host_mem_release(
506
568
  host_indices[min_completed_tokens:completed_tokens]
507
569
  )
508
570
  last_host_node.release_host()
509
571
  del self.ongoing_prefetch[req_id]
510
572
  self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
511
573
 
574
+ if self.enable_storage_metrics:
575
+ self.metrics_collector.log_prefetched_tokens(
576
+ min_completed_tokens - matched_length
577
+ )
578
+
512
579
  return True
513
580
 
514
581
  def match_prefix(self, key: List[int], **kwargs):
@@ -536,6 +603,8 @@ class HiRadixCache(RadixCache):
536
603
  while last_node.evicted:
537
604
  host_hit_length += len(last_node.host_value)
538
605
  last_node = last_node.parent
606
+ while not last_host_node.backuped:
607
+ last_host_node = last_host_node.parent
539
608
 
540
609
  return MatchResult(
541
610
  device_indices=value,
@@ -556,7 +625,11 @@ class HiRadixCache(RadixCache):
556
625
  len(new_input_tokens) % self.page_size
557
626
  )
558
627
  new_input_tokens = new_input_tokens[:prefetch_length]
559
- if not self.enable_storage or prefetch_length < self.prefetch_threshold:
628
+ if (
629
+ not self.enable_storage
630
+ or prefetch_length < self.prefetch_threshold
631
+ or self.cache_controller.prefetch_rate_limited()
632
+ ):
560
633
  return
561
634
 
562
635
  last_host_node.protect_host()
@@ -564,6 +637,10 @@ class HiRadixCache(RadixCache):
564
637
  if host_indices is None:
565
638
  self.evict_host(prefetch_length)
566
639
  host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
640
+ if host_indices is None:
641
+ last_host_node.release_host()
642
+ # no sufficient host memory for prefetch
643
+ return
567
644
  operation = self.cache_controller.prefetch(
568
645
  req_id, host_indices, new_input_tokens, last_hash
569
646
  )
@@ -642,7 +719,6 @@ class HiRadixCache(RadixCache):
642
719
  new_node.parent = child.parent
643
720
  new_node.lock_ref = child.lock_ref
644
721
  new_node.key = child.key[:split_len]
645
- new_node.loading = child.loading
646
722
  new_node.hit_count = child.hit_count
647
723
 
648
724
  # split value and host value if exists
@@ -663,11 +739,11 @@ class HiRadixCache(RadixCache):
663
739
  new_node.parent.children[self.get_child_key_fn(key)] = new_node
664
740
  return new_node
665
741
 
666
- def _insert_helper(self, node: TreeNode, key: List, value):
667
- node.last_access_time = time.monotonic()
742
+ def insert(self, key: List, value, chunked=False):
668
743
  if len(key) == 0:
669
744
  return 0
670
745
 
746
+ node = self.root_node
671
747
  child_key = self.get_child_key_fn(key)
672
748
  total_prefix_length = 0
673
749
 
@@ -684,7 +760,7 @@ class HiRadixCache(RadixCache):
684
760
  self.token_to_kv_pool_host.update_synced(node.host_value)
685
761
  self.evictable_size_ += len(node.value)
686
762
  else:
687
- self.inc_hit_count(node)
763
+ self._inc_hit_count(node, chunked)
688
764
  total_prefix_length += prefix_len
689
765
  else:
690
766
  # partial match, split the node
@@ -694,7 +770,7 @@ class HiRadixCache(RadixCache):
694
770
  self.token_to_kv_pool_host.update_synced(new_node.host_value)
695
771
  self.evictable_size_ += len(new_node.value)
696
772
  else:
697
- self.inc_hit_count(new_node)
773
+ self._inc_hit_count(new_node, chunked)
698
774
  total_prefix_length += prefix_len
699
775
  node = new_node
700
776
 
@@ -728,7 +804,7 @@ class HiRadixCache(RadixCache):
728
804
  last_hash = new_node.hash_value[-1]
729
805
 
730
806
  if self.cache_controller.write_policy != "write_back":
731
- self.inc_hit_count(new_node)
807
+ self._inc_hit_count(new_node, chunked)
732
808
  return total_prefix_length
733
809
 
734
810
  def _collect_leaves_device(self):
@@ -755,3 +831,19 @@ class HiRadixCache(RadixCache):
755
831
  if not cur_child.evicted:
756
832
  stack.append(cur_child)
757
833
  return ret_list
834
+
835
+ def release_aborted_request(self, rid: str):
836
+ if rid not in self.ongoing_prefetch:
837
+ return
838
+
839
+ last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
840
+ if operation.host_indices is None:
841
+ return
842
+
843
+ completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
844
+ if self.tp_world_size > 1:
845
+ torch.distributed.barrier(group=self.tp_group)
846
+ last_host_node.release_host()
847
+ del self.ongoing_prefetch[rid]
848
+ self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
849
+ self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
183
183
  self.req_to_token_pool.free(req.req_pool_idx)
184
184
  self.dec_lock_ref(req.last_node)
185
185
 
186
- def cache_unfinished_req(self, req: Req):
186
+ def cache_unfinished_req(self, req: Req, chunked=False):
187
187
  """Cache request when it is unfinished."""
188
188
  if self.disable:
189
189
  return