sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hiradix_cache.py
@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
     MLATokenToKVPoolHost,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.metrics.collector import StorageMetricsCollector
 
 logger = logging.getLogger(__name__)
 
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_write_policy: str,
         hicache_io_backend: str,
         hicache_mem_layout: str,
+        enable_metrics: bool,
         hicache_storage_backend: Optional[str] = None,
         hicache_storage_prefetch_policy: Optional[str] = "best_effort",
         model_name: Optional[str] = None,
@@ -73,6 +75,8 @@ class HiRadixCache(RadixCache):
         self.tp_group = tp_cache_group
         self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
+        self.enable_storage_metrics = self.enable_storage and enable_metrics
+
         # todo: customizable storage prefetch threshold and timeout
         self.prefetch_threshold = 256
         self.prefetch_timeout = 3  # seconds
@@ -92,6 +96,14 @@ class HiRadixCache(RadixCache):
                 model_name=model_name,
                 storage_backend_extra_config=storage_backend_extra_config,
             )
+            if self.enable_storage_metrics:
+                # TODO: support pp
+                labels = {
+                    "storage_backend": hicache_storage_backend,
+                    "tp_rank": self.cache_controller.tp_rank,
+                    "dp_rank": self.cache_controller.dp_rank,
+                }
+                self.metrics_collector = StorageMetricsCollector(labels=labels)
 
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
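A note on the metrics wiring above: the collector is scoped by a fixed label set (storage backend name, TP rank, DP rank), so every rank exports its own time series. The sketch below shows that pattern written directly against prometheus_client; the metric name and label values are illustrative assumptions, not the actual StorageMetricsCollector API.

# Sketch only: mirrors the label wiring above, not sglang's collector.
from prometheus_client import Counter

prefetched_tokens = Counter(
    "hicache_storage_prefetched_tokens",  # hypothetical metric name
    "Tokens prefetched from the storage backend into host memory",
    ["storage_backend", "tp_rank", "dp_rank"],
)

labels = {"storage_backend": "mooncake", "tp_rank": "0", "dp_rank": "0"}
prefetched_tokens.labels(**labels).inc(4096)  # one series per (backend, tp, dp)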
@@ -102,10 +114,7 @@ class HiRadixCache(RadixCache):
         self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
-            1 if hicache_write_policy == "write_through" else 3
-        )
-        self.write_through_threshold_storage = (
-            1 if hicache_write_policy == "write_through" else 3
+            1 if hicache_write_policy == "write_through" else 2
         )
         self.load_back_threshold = 10
         super().__init__(
@@ -125,6 +134,28 @@ class HiRadixCache(RadixCache):
             height += 1
         return height
 
+    def clear_storage_backend(self) -> bool:
+        if self.enable_storage:
+            try:
+                # Check if the storage backend has a clear method (for nixl backends)
+                if hasattr(self.cache_controller.storage_backend, "clear"):
+                    self.cache_controller.storage_backend.clear()
+                    logger.info(
+                        "Hierarchical cache storage backend cleared successfully!"
+                    )
+                    return True
+                else:
+                    logger.warning(
+                        f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
+                    )
+                    return False
+            except Exception as e:
+                logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
+                return False
+        else:
+            logger.warning("Hierarchical cache storage backend is not enabled.")
+            return False
+
     def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
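The new clear_storage_backend above probes for clear() with hasattr because not every storage backend supports clearing. A minimal sketch of that duck-typing probe, using hypothetical backend classes, follows.

# Sketch only: hypothetical backends illustrating the hasattr probe above.
import logging

logger = logging.getLogger(__name__)

class ClearableBackend:
    def clear(self):
        pass  # drop all cached KV entries

class ReadOnlyBackend:
    pass

def try_clear(backend) -> bool:
    if hasattr(backend, "clear"):
        backend.clear()
        return True
    logger.warning("%s does not support clear", type(backend).__name__)
    return False

assert try_clear(ClearableBackend()) is True
assert try_clear(ReadOnlyBackend()) is False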
@@ -155,8 +186,9 @@ class HiRadixCache(RadixCache):
         self.ongoing_backup[operation_id] = node
         node.protect_host()
 
-    def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy == "write_back":
+    def _inc_hit_count(self, node: TreeNode, chunked=False):
+        # skip the hit count update for chunked requests
+        if self.cache_controller.write_policy == "write_back" or chunked:
             return
         node.hit_count += 1
 
@@ -164,51 +196,62 @@ class HiRadixCache(RadixCache):
         if node.hit_count >= self.write_through_threshold:
             # write to host if the node is not backuped
             self.write_backup(node)
-        else:
-            if (
-                self.enable_storage
-                and (not node.backuped_storage)
-                and node.hit_count >= self.write_through_threshold_storage
-            ):
-                # if the node is backuped on host memory but not on storage
-                self.write_backup_storage(node)
 
     def writing_check(self, write_back=False):
         if write_back:
             # blocking till all write back complete
             while len(self.ongoing_write_through) > 0:
-                ack_id = self.cache_controller.ack_write_queue.get()
-                del self.ongoing_write_through[ack_id]
+                for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+                    finish_event.synchronize()
+                    for ack_id in ack_list:
+                        del self.ongoing_write_through[ack_id]
+                self.cache_controller.ack_write_queue.clear()
+            assert len(self.ongoing_write_through) == 0
             return
-        queue_size = torch.tensor(
-            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
-        )
+
+        # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
+        if len(self.ongoing_write_through) == 0:
+            return
+
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+            if not finish_event.query():
+                break
+            finish_count += 1
+        queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
         if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to radix cache
+            # synchronize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-        for _ in range(queue_size.item()):
-            ack_id = self.cache_controller.ack_write_queue.get()
-            self.dec_lock_ref(self.ongoing_write_through[ack_id])
-            del self.ongoing_write_through[ack_id]
+
+        finish_count = int(queue_size.item())
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                backuped_node = self.ongoing_write_through.pop(ack_id)
+                self.dec_lock_ref(backuped_node)
+                if self.enable_storage:
+                    self.write_backup_storage(backuped_node)
+            finish_count -= 1
 
     def loading_check(self):
-        while not self.cache_controller.ack_load_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_load_queue.get_nowait()
-                start_node, end_node = self.ongoing_load_back[ack_id]
-                self.dec_lock_ref(end_node)
-                while end_node != start_node:
-                    assert end_node.loading
-                    end_node.loading = False
-                    end_node = end_node.parent
-                # clear the reference
-                del self.ongoing_load_back[ack_id]
-            except Exception:
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
+            if not finish_event.query():
+                # the KV cache loading is still ongoing
                 break
+            finish_count += 1
+            # no need to sync across TP workers as batch forwarding is synced
+            for ack_id in ack_list:
+                end_node = self.ongoing_load_back.pop(ack_id)
+                self.dec_lock_ref(end_node)
+
+        # ACK until all events are processed
+        del self.cache_controller.ack_load_queue[:finish_count]
 
     def evictable_size(self):
         return self.evictable_size_
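The reworked writing_check and loading_check replace per-item queue polling with an ordered list of (batch, finish_event, ack_list) entries: each rank counts the longest completed prefix of events, and a MIN all-reduce guarantees every TP rank pops the same prefix. A single-process sketch of that invariant, with a stand-in for the CUDA event and min() over made-up per-rank counts replacing the collective, follows.

# Sketch only: FakeEvent stands in for a CUDA event; min() stands in for
# torch.distributed.all_reduce(..., op=ReduceOp.MIN) across TP ranks.
class FakeEvent:
    def __init__(self, done: bool):
        self.done = done

    def query(self) -> bool:  # non-blocking completion check
        return self.done

def finished_prefix(ack_queue) -> int:
    # Acks are consumed in order, so only the completed prefix counts;
    # an event finished out of order behind an unfinished one must wait.
    count = 0
    for _, finish_event, _ in ack_queue:
        if not finish_event.query():
            break
        count += 1
    return count

# Two hypothetical TP ranks; rank 1's second write is still in flight.
rank0_queue = [(0, FakeEvent(True), [10]), (1, FakeEvent(True), [11])]
rank1_queue = [(0, FakeEvent(True), [10]), (1, FakeEvent(False), [11])]

pop_count = min(finished_prefix(rank0_queue), finished_prefix(rank1_queue))
assert pop_count == 1  # both ranks pop exactly one entry -> identical trees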
@@ -333,12 +376,11 @@ class HiRadixCache(RadixCache):
             # no sufficient GPU memory to load back KV caches
             return None
 
-        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+        self.ongoing_load_back[last_hit_node.id] = last_hit_node
         offset = 0
         for node in nodes_to_load:
             node.value = device_indices[offset : offset + len(node.host_value)]
             offset += len(node.host_value)
-            node.loading = True
         self.evictable_size_ += len(device_indices)
         self.inc_lock_ref(last_hit_node)
 
@@ -367,66 +409,72 @@ class HiRadixCache(RadixCache):
             last_node,
         )
 
-    def ready_to_load_host_cache(self):
-        producer_index = self.cache_controller.layer_done_counter.next_producer()
-        self.load_cache_event.set()
-        return producer_index
+    def ready_to_load_host_cache(self) -> int:
+        """
+        Notify the cache controller to start the KV cache loading.
+        Return the consumer index for the schedule batch manager to track.
+        """
+        return self.cache_controller.start_loading()
 
     def check_hicache_events(self):
         self.writing_check()
         self.loading_check()
         if self.enable_storage:
-            self.check_revoked_prefetch()
-            self.check_backup_progress()
-
-    def check_revoked_prefetch(self):
-        queue_size = torch.tensor(
-            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
-        )
-        if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
-            torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
+            self.drain_storage_control_queues()
+        if self.enable_storage_metrics:
+            self.metrics_collector.log_storage_metrics(
+                self.cache_controller.storage_backend.get_stats()
             )
-        for _ in range(queue_size.item()):
-            req_id = self.cache_controller.prefetch_revoke_queue.get()
-            if req_id in self.ongoing_prefetch:
-                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
-                last_host_node.release_host()
-                del self.ongoing_prefetch[req_id]
-                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
-            else:
-                # the revoked operation already got terminated
-                pass
 
-    def check_backup_progress(self):
-        queue_size = torch.tensor(
-            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
+    def drain_storage_control_queues(self):
+        """
+        Combine prefetch revoke, backup ack, and host mem release checks
+        to minimize TP synchronization and Python overhead.
+        """
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                cc.prefetch_revoke_queue.qsize(),
+                cc.ack_backup_queue.qsize(),
+                cc.host_mem_release_queue.qsize(),
+            ],
+            dtype=torch.int,
         )
         if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
             )
-        for _ in range(queue_size.item()):
-            ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
-            host_node = self.ongoing_backup[ack_id]
-
-            if completed_tokens > 0:
-                if completed_tokens < len(host_node.key):
-                    # backup is only partially successful, split the node
-                    new_node = self._split_node(
-                        host_node.key, host_node, completed_tokens
-                    )
-                    new_node.backuped_storage = True
-                else:
-                    host_node.backuped_storage = True
-            host_node.release_host()
-            del self.ongoing_backup[ack_id]
+
+        n_revoke, n_backup, n_release = map(int, qsizes.tolist())
+
+        # process prefetch revokes
+        for _ in range(n_revoke):
+            req_id = cc.prefetch_revoke_queue.get()
+            info = self.ongoing_prefetch.pop(req_id, None)
+            if info is not None:
+                last_host_node, token_ids, _, _ = info
+                last_host_node.release_host()
+                cc.prefetch_tokens_occupied -= len(token_ids)
+            # else: the revoked operation already got terminated, nothing to do
+
+        # process backup acks
+        for _ in range(n_backup):
+            operation = cc.ack_backup_queue.get()
+            ack_id = operation.id
+            entry = self.ongoing_backup.pop(ack_id, None)
+            if entry is not None:
+                entry.release_host()
+            if self.enable_storage_metrics:
+                self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
+
+        # release host memory
+        host_indices_list = []
+        for _ in range(n_release):
+            host_indices_list.append(cc.host_mem_release_queue.get())
+        if host_indices_list:
+            host_indices = torch.cat(host_indices_list, dim=0)
+            cc.mem_pool_host.free(host_indices)
 
     def can_terminate_prefetch(self, operation: PrefetchOperation):
         can_terminate = True
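drain_storage_control_queues folds what used to be two separate TP synchronizations (prefetch revokes and backup acks), plus host memory release, into one all-reduce over a 3-element tensor. A single-process sketch of the counting step, with torch.minimum over two made-up ranks standing in for the collective, follows.

# Sketch only: element-wise min over hypothetical per-rank counts stands in
# for torch.distributed.all_reduce(qsizes, op=ReduceOp.MIN, group=tp_group).
import torch

rank0_qsizes = torch.tensor([3, 1, 2], dtype=torch.int)  # revoke, backup, release
rank1_qsizes = torch.tensor([2, 1, 4], dtype=torch.int)

# One tensor, one sync: every rank then drains the same number of items
# from each of the three control queues.
qsizes = torch.minimum(rank0_qsizes, rank1_qsizes)
n_revoke, n_backup, n_release = map(int, qsizes.tolist())
assert (n_revoke, n_backup, n_release) == (2, 1, 2)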
@@ -451,15 +499,22 @@ class HiRadixCache(RadixCache):
             # unknown prefetch stop policy, just return True
             return True
 
+        operation_terminated = operation.is_terminated()
         if self.tp_world_size > 1:
-            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            states = torch.tensor(
+                [1 - int(can_terminate), int(operation_terminated)],
+                dtype=torch.int,
+            )
             torch.distributed.all_reduce(
-                can_terminate,
-                op=torch.distributed.ReduceOp.MIN,
+                states,
+                op=torch.distributed.ReduceOp.MAX,
                 group=self.tp_group,
             )
-            can_terminate = bool(can_terminate.item())
-
+            can_terminate = states[0].item() == 0
+            operation_terminated = states[1].item() == 1
+        # the operation should be terminated if it is already terminated on any TP worker
+        # or it meets the termination condition on all TP workers
+        can_terminate = can_terminate or operation_terminated
         return can_terminate
 
     def check_prefetch_progress(self, req_id: str) -> bool:
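The two-flag encoding above packs a logical AND and a logical OR into a single MAX all-reduce: reducing 1 - can_terminate turns "all ranks can terminate" into "MAX == 0" (no rank objects), while reducing the raw terminated flag keeps its "any rank" semantics (MAX == 1). A plain-Python sketch of the vote, with max() over made-up per-rank flags replacing the collective, follows.

# Sketch only: max() over hypothetical per-rank flags stands in for
# torch.distributed.all_reduce(states, op=ReduceOp.MAX, group=tp_group).
def prefetch_vote(can_terminate_per_rank, terminated_per_rank) -> bool:
    # MAX of (1 - can_terminate) is 0 only if every rank can terminate (AND).
    any_objection = max(1 - int(c) for c in can_terminate_per_rank)
    # MAX of terminated is 1 if the operation already ended on any rank (OR).
    any_terminated = max(int(t) for t in terminated_per_rank)
    return any_objection == 0 or any_terminated == 1

# Rank 1 has not met its termination condition, but rank 0 already
# terminated the operation, so every rank agrees to stop:
assert prefetch_vote([True, False], [True, False]) is True
# No rank terminated and rank 1 still wants to wait -> keep prefetching:
assert prefetch_vote([True, False], [False, False]) is False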
@@ -486,7 +541,7 @@ class HiRadixCache(RadixCache):
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
 
         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
@@ -509,13 +564,18 @@ class HiRadixCache(RadixCache):
             self.cache_controller.mem_pool_host.update_prefetch(written_indices)
 
         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
-        self.cache_controller.mem_pool_host.free(
+        self.cache_controller.append_host_mem_release(
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
         del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
 
+        if self.enable_storage_metrics:
+            self.metrics_collector.log_prefetched_tokens(
+                min_completed_tokens - matched_length
+            )
+
         return True
 
     def match_prefix(self, key: List[int], **kwargs):
  def match_prefix(self, key: List[int], **kwargs):
@@ -565,7 +625,11 @@ class HiRadixCache(RadixCache):
565
625
  len(new_input_tokens) % self.page_size
566
626
  )
567
627
  new_input_tokens = new_input_tokens[:prefetch_length]
568
- if not self.enable_storage or prefetch_length < self.prefetch_threshold:
628
+ if (
629
+ not self.enable_storage
630
+ or prefetch_length < self.prefetch_threshold
631
+ or self.cache_controller.prefetch_rate_limited()
632
+ ):
569
633
  return
570
634
 
571
635
  last_host_node.protect_host()
@@ -573,6 +637,10 @@ class HiRadixCache(RadixCache):
         if host_indices is None:
             self.evict_host(prefetch_length)
             host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+        if host_indices is None:
+            last_host_node.release_host()
+            # no sufficient host memory for prefetch
+            return
         operation = self.cache_controller.prefetch(
             req_id, host_indices, new_input_tokens, last_hash
         )
@@ -651,7 +719,6 @@ class HiRadixCache(RadixCache):
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
-        new_node.loading = child.loading
         new_node.hit_count = child.hit_count
 
         # split value and host value if exists
@@ -672,11 +739,11 @@ class HiRadixCache(RadixCache):
         new_node.parent.children[self.get_child_key_fn(key)] = new_node
         return new_node
 
-    def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.monotonic()
+    def insert(self, key: List, value, chunked=False):
         if len(key) == 0:
             return 0
 
+        node = self.root_node
         child_key = self.get_child_key_fn(key)
         total_prefix_length = 0
 
@@ -693,7 +760,7 @@ class HiRadixCache(RadixCache):
                     self.token_to_kv_pool_host.update_synced(node.host_value)
                     self.evictable_size_ += len(node.value)
                 else:
-                    self.inc_hit_count(node)
+                    self._inc_hit_count(node, chunked)
                 total_prefix_length += prefix_len
             else:
                 # partial match, split the node
  # partial match, split the node
@@ -703,7 +770,7 @@ class HiRadixCache(RadixCache):
703
770
  self.token_to_kv_pool_host.update_synced(new_node.host_value)
704
771
  self.evictable_size_ += len(new_node.value)
705
772
  else:
706
- self.inc_hit_count(new_node)
773
+ self._inc_hit_count(new_node, chunked)
707
774
  total_prefix_length += prefix_len
708
775
  node = new_node
709
776
 
@@ -737,7 +804,7 @@ class HiRadixCache(RadixCache):
737
804
  last_hash = new_node.hash_value[-1]
738
805
 
739
806
  if self.cache_controller.write_policy != "write_back":
740
- self.inc_hit_count(new_node)
807
+ self._inc_hit_count(new_node, chunked)
741
808
  return total_prefix_length
742
809
 
743
810
  def _collect_leaves_device(self):
@@ -764,3 +831,19 @@ class HiRadixCache(RadixCache):
             if not cur_child.evicted:
                 stack.append(cur_child)
         return ret_list
+
+    def release_aborted_request(self, rid: str):
+        if rid not in self.ongoing_prefetch:
+            return
+
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
+        if operation.host_indices is None:
+            return
+
+        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
+        if self.tp_world_size > 1:
+            torch.distributed.barrier(group=self.tp_group)
+        last_host_node.release_host()
+        del self.ongoing_prefetch[rid]
+        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
sglang/srt/mem_cache/lora_radix_cache.py
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return