sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py

@@ -18,7 +18,7 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple

 import torch

@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
     get_attention_tp_rank,
     get_attention_tp_size,
     is_dp_attention_enabled,
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 logger = logging.getLogger(__name__)


+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.counters = [num_layers] * self.num_counters
-        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
-        self.producer_index = 0
-        self.consumer_index = 0
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1

     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[
+            self.producer_index
+        ].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index

-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index

-    def increment(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] += 1
-            self.conditions[self.producer_index].notify_all()
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)

     def reset(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] = 0
+        self.producer_index = -1
+        self.consumer_index = -1


 class CacheOperation:
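
Note: the LayerLoadingEvent/LayerDoneCounter rework above replaces the old threading.Condition counters with per-layer CUDA events. A minimal standalone sketch of that pattern (not sglang code; assumes a CUDA device, and tensor sizes are illustrative):

    import torch

    num_layers = 4
    load_stream = torch.cuda.Stream()
    events = [torch.cuda.Event() for _ in range(num_layers)]
    layers = [torch.zeros(1024, device="cuda") for _ in range(num_layers)]

    # Producer: enqueue per-layer "loads" on a side stream, recording one
    # event per layer (the role of LayerLoadingEvent.complete).
    with torch.cuda.stream(load_stream):
        for i in range(num_layers):
            layers[i] += 1.0  # stand-in for load_to_device_per_layer
            events[i].record()

    # Consumer: the compute stream blocks only on the layer it needs next
    # (the role of LayerLoadingEvent.wait), so early layers can run while
    # later ones are still being loaded.
    for i in range(num_layers):
        torch.cuda.current_stream().wait_event(events[i])
        out = layers[i] * 2.0  # stand-in for the layer's forward pass
    torch.cuda.synchronize()
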
@@ -98,36 +113,30 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id

-    def merge(self, other: "CacheOperation") -> None:
-        # multiple operations can be merged into a single operation for batch processing
-        self.host_indices = torch.cat([self.host_indices, other.host_indices])
-        self.device_indices = torch.cat([self.device_indices, other.device_indices])
-        self.priority = min(self.priority, other.priority)
-        self.node_ids.extend(other.node_ids)
-
-    def split(self, factor) -> List["CacheOperation"]:
-        # split an operation into smaller operations to reduce the size of intermediate buffers
-        if factor <= 1:
-            return [self]
-
-        chunk_size = math.ceil(len(self.host_indices) / factor)
-        split_ops = []
-        for i in range(0, len(self.host_indices), chunk_size):
-            split_ops.append(
-                CacheOperation(
-                    host_indices=self.host_indices[i : i + chunk_size],
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
+        return self.priority < other.priority

-        return split_ops

-    def __lt__(self, other: "CacheOperation"):
-        return self.priority < other.priority
+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]


 class TransferBuffer:
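
Note: merge_ops coalesces a whole queue of pending operations into one batched transfer, replacing the old pairwise merge()/split() pair. A hypothetical usage sketch (index values are made up; CacheOperation is the class defined above, with the positional constructor used by merge_ops):

    import torch
    from sglang.srt.managers.cache_controller import CacheOperation

    ops = [
        CacheOperation(torch.arange(0, 64), torch.arange(0, 64), node_id=1),
        CacheOperation(torch.arange(64, 128), torch.arange(64, 128), node_id=2),
    ]
    merged = CacheOperation.merge_ops(ops)
    # One concatenated host/device index pair; acks for every merged node
    # ride on the single transfer via merged.node_ids.
    print(len(merged.host_indices), merged.node_ids)
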
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
     ):
         self.request_id = request_id

-        self._done_flag = False
         self._lock = threading.Lock()
-
+        self._terminated_flag = False
         self.start_time = time.monotonic()

         super().__init__(host_indices, token_ids, last_hash)

     def increment(self, num_tokens: int):
         with self._lock:
-            if self._done_flag:
+            if self._terminated_flag:
                 return False
             self.completed_tokens += num_tokens
             return True

-    def mark_done(self):
+    def mark_terminate(self):
         with self._lock:
-            self._done_flag = True
+            self._terminated_flag = True

-    def is_done(self) -> bool:
-        return self._done_flag
+    def is_terminated(self) -> bool:
+        return self._terminated_flag


 class HiCacheController:
@@ -236,7 +244,7 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event = None,
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,
@@ -250,26 +258,21 @@ class HiCacheController:
         self.write_policy = write_policy
         self.page_size = page_size
         self.io_backend = io_backend
-
         self.enable_storage = False

-        # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import get_hash_str

             self.get_hash_str = get_hash_str
-
             self.storage_config = self._generate_storage_config(
                 model_name, storage_backend_extra_config
             )
-            # In MLA backend, only one rank needs to backup the KV cache
+            # for MLA models, only one rank needs to backup the KV cache
             self.backup_skip = (
                 self.storage_config.is_mla_model
-                # todo: for load balancing, decide which rank to backup the KV cache by hash value
+                # todo: load balancing
                 and self.storage_config.tp_rank != 0
-                # todo: support other storage backends
-                and self.storage_backend_type in ["file", "mooncake"]
             )

             if storage_backend == "file":
@@ -309,12 +312,15 @@ class HiCacheController:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
                 )
+
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             self.prefetch_capacity_limit = int(
                 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
             )
+            # granularity of batch storage IO operations, in number of pages
+            self.storage_batch_size = 128
             # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
             self.prefetch_tokens_occupied = 0

@@ -325,15 +331,26 @@
                 self.prefetch_tp_group = torch.distributed.new_group(
                     group_ranks, backend="gloo"
                 )
-                self.prefetch_io_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
-                self.backup_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )

-        self.load_cache_event = load_cache_event
-        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+            # Select the get and set functions
+            self.page_get_func = self._generic_page_get
+            self.page_set_func = self._generic_page_set
+            self.batch_exists_func = self.storage_backend.batch_exists
+            self.is_3fs_zerocopy = (
+                self.storage_backend_type == "hf3fs"
+                and self.mem_pool_host.layout == "page_first"
+            )
+            if self.storage_backend_type == "mooncake":
+                self.page_get_func = self._mooncake_page_get
+                self.page_set_func = self._mooncake_page_set
+            elif self.is_3fs_zerocopy:
+                self.page_get_func = self._3fs_zero_copy_page_get
+                self.page_set_func = self._3fs_zero_copy_page_set
+                self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

         if write_policy not in [
@@ -343,11 +360,11 @@
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")

-        self.write_queue = PriorityQueue()
-        self.load_queue = PriorityQueue()
-
-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []

         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
@@ -358,16 +375,6 @@
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -380,6 +387,7 @@

             self.prefetch_revoke_queue = Queue()
             self.ack_backup_queue = Queue()
+            self.host_mem_release_queue = Queue()

             self.prefetch_thread.start()
             self.backup_thread.start()
@@ -393,9 +401,11 @@
         if is_dp_attention_enabled():
             self.tp_rank = get_attention_tp_rank()
             self.tp_size = get_attention_tp_size()
+            self.dp_rank = get_attention_dp_rank()
         else:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
+            self.dp_rank = 0

         # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
         is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
@@ -414,21 +424,20 @@
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             is_mla_model=is_mla_backend,
+            is_page_first_layout=self.mem_pool_host.layout == "page_first",
             model_name=model_name,
             extra_config=extra_config,
         )

     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()

-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
@@ -437,15 +446,7 @@
             self.prefetch_revoke_queue.queue.clear()
             self.ack_backup_queue.queue.clear()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()

         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -461,7 +462,7 @@
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -470,17 +471,46 @@
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
-        torch.cuda.current_stream().synchronize()
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices

+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the write stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.write_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
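
Note: start_writing() above is the synchronous replacement for the old write thread. Instead of torch.cuda.current_stream().synchronize(), a start event orders the side stream after the current stream, and record_stream() keeps the index tensors alive while the copy is in flight. A standalone sketch of that event pattern (not sglang code; assumes a CUDA device):

    import torch

    side_stream = torch.cuda.Stream()
    indices = torch.randperm(1024, device="cuda")

    start_event = torch.cuda.Event()
    finish_event = torch.cuda.Event()

    start_event.record()  # marks a point on the current stream
    with torch.cuda.stream(side_stream):
        start_event.wait(side_stream)  # side stream waits, no host-side sync
        copied = indices * 1  # stand-in for backup_from_device_all_layer
        finish_event.record()
    # Without this, the caching allocator could reuse `indices` while the
    # side stream is still reading it.
    indices.record_stream(side_stream)

    finish_event.synchronize()  # host blocks only when the result is needed
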
@@ -489,17 +519,18 @@
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
-        # to ensure the device indices are ready before accessed by another CUDA stream
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices

-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            return host_indices.to(self.mem_pool_device.device), device_indices
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
             device_indices = device_indices.cpu()
             host_indices, idx = host_indices.sort()
@@ -507,58 +538,20 @@
         else:
             raise ValueError(f"Unsupported io backend")

-    def write_thread_func_direct(self):
-        """
-        Directly write through KV caches to host memory without buffering.
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1

-    def load_thread_func_layer_by_layer(self):
-        """
-        Load KV caches from host memory to device memory layer by layer.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()

-            # start layer-wise KV cache transfer from CPU to GPU
-            self.layer_done_counter.reset()
-            host_indices, device_indices = self.move_indices(
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -566,13 +559,24 @@
                     device_indices,
                     i,
                     self.io_backend,
                 )
-                self.load_stream.synchronize()
-                self.layer_done_counter.increment()
-
-            self.mem_pool_host.complete_io(batch_operation.host_indices)
-            for node_id in batch_operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the load stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.load_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.load_stream)
+
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id

     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
615
619
  return operation
616
620
 
617
621
  def terminate_prefetch(self, operation):
618
- operation.mark_done()
622
+ operation.mark_terminate()
619
623
  return operation.completed_tokens, operation.hash_value
620
624
 
621
- # zero copy
625
+ def append_host_mem_release(self, host_indices: torch.Tensor):
626
+ chunks = host_indices.split(self.mem_pool_host.page_size)
627
+ for chunk in chunks:
628
+ self.host_mem_release_queue.put(chunk)
629
+
630
+ def _3fs_zero_copy_batch_exists(self, batch_hashes):
631
+ _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
632
+ hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
633
+ return hit_page_num
634
+
622
635
  def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
623
- hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
636
+ hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
624
637
  hash_values, host_indices
625
638
  )
626
639
  page_data = self.storage_backend.batch_get(hashes, dsts)
627
640
  if page_data:
628
- operation.increment(self.page_size * len(hashes))
641
+ inc = self.page_size * len(hashes) // factor
642
+ operation.increment(inc)
629
643
  else:
630
644
  logger.warning(
631
645
  f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
632
646
  )
633
647
 
634
- # zero copy
635
648
  def _mooncake_page_get(self, operation, hash_values, host_indices):
636
649
  key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
637
650
  hash_values,
638
651
  host_indices,
652
+ self.storage_config.tp_rank,
639
653
  )
640
654
  get_result = self.storage_backend.batch_get(
641
655
  key_strs,
642
- target_location=buffer_ptrs,
656
+ target_locations=buffer_ptrs,
643
657
  target_sizes=buffer_sizes,
644
658
  )
645
659
  if get_result != len(hash_values):
@@ -649,12 +663,10 @@
         if get_result != 0:
             operation.increment(get_result * self.page_size)

-    # non-zero copy
     def _generic_page_get(self, operation, hash_values, host_indices):
-        # todo: zero copy
-        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
-            hash_values
-        )
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
         page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
         if page_data is None:
             return
@@ -664,49 +676,36 @@
                     f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-            if operation.increment(self.page_size):
-                self.mem_pool_host.set_from_flat_data_page(
-                    host_indices[i * self.page_size],
-                    page_data[i],
-                )
-            else:
-                break
+            # Must set the data before increasing the completed tokens.
+            # Otherwise this page may be read before being set.
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller

     def _page_transfer(self, operation):
-        # Select the get function and batch size
-        if self.is_mooncake_backend():
-            get_func = self._mooncake_page_get
-            batch_size = 128
-        elif self.storage_backend_type == "hf3fs":
-            if self.mem_pool_host.layout == "page_first":
-                get_func = self._3fs_zero_copy_page_get
-            elif self.mem_pool_host.layout == "layer_first":
-                get_func = self._generic_page_get
-            batch_size = 128
-        else:
-            get_func = self._generic_page_get
-            batch_size = 8
-
         # Transfer batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
             prev_completed_tokens = operation.completed_tokens
             # Get one batch token, and update the completed_tokens if succeed
-            get_func(operation, batch_hashes, batch_host_indices)
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
             # Check termination
             if (
                 operation.completed_tokens
                 != prev_completed_tokens + len(batch_hashes) * self.page_size
             ):
+                operation.mark_terminate()
                 break  # Some operations fail or operation terminated by controller
         # release pre-allocated memory
-        self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])
-
-    def is_mooncake_backend(self):
-        return self.storage_backend_type == "mooncake"
+        self.append_host_mem_release(
+            operation.host_indices[operation.completed_tokens :]
+        )

     def prefetch_io_aux_func(self):
         """
@@ -716,47 +715,49 @@
         try:
             operation = self.prefetch_buffer.get(block=True, timeout=1)
             self._page_transfer(operation)
-
-            if self.tp_world_size > 1:
-                # to ensure all TP workers release the host memory at the same time
-                torch.distributed.barrier(group=self.prefetch_io_tp_group)
             # operation terminated by controller, release pre-allocated memory
-            self.mem_pool_host.free(
+            self.append_host_mem_release(
                 operation.host_indices[operation.completed_tokens :]
             )
         except Empty:
             continue

-    def prefetch_rate_limit_check(self) -> bool:
+    def prefetch_rate_limited(self) -> bool:
         """
         Rate limit the prefetching operations to avoid overwhelming the storage backend.
         """
         # cancel prefetch if too much memory is occupied
         if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
-            return False
+            return True
         # todo: more sophisticated rate limiting based on storage backend performance
-        return True
+        return False

-    def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
+    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
         last_hash = operation.last_hash
         tokens_to_fetch = operation.token_ids

         storage_query_count = 0
-        remaining_tokens = len(tokens_to_fetch)
         hash_value = []
-        while remaining_tokens >= self.page_size:
-            last_hash = self.get_hash_str(
-                tokens_to_fetch[
-                    storage_query_count : storage_query_count + self.page_size
-                ],
-                last_hash,
+
+        for start in range(
+            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
+        ):
+            end = min(
+                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
             )
-            hash_value.append(last_hash)
-            storage_query_count += self.page_size
-            remaining_tokens -= self.page_size
-        # deferring to batch exists
-        hit_page_num = self.storage_backend.batch_exists(hash_value)
-        return hash_value[:hit_page_num], hit_page_num * self.page_size
+            batch_tokens = tokens_to_fetch[start:end]
+            batch_hashes = []
+            for i in range(0, len(batch_tokens), self.page_size):
+                last_hash = self.get_hash_str(
+                    batch_tokens[i : i + self.page_size], last_hash
+                )
+                batch_hashes.append(last_hash)
+            hit_page_num = self.batch_exists_func(batch_hashes)
+            hash_value.extend(batch_hashes[:hit_page_num])
+            storage_query_count += hit_page_num * self.page_size
+            if hit_page_num < len(batch_hashes):
+                break
+        return hash_value, storage_query_count

     def prefetch_thread_func(self):
         """
@@ -771,13 +772,7 @@
                 if operation is None:
                     continue

-                if (
-                    operation.host_indices is not None
-                ) and self.prefetch_rate_limit_check():
-                    hash_value, storage_hit_count = self._generic_storage_hit_query(
-                        operation
-                    )
-
+                hash_value, storage_hit_count = self._storage_hit_query(operation)
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
                         storage_hit_count, dtype=torch.int
@@ -792,8 +787,7 @@
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-                    if operation.host_indices is not None:
-                        self.mem_pool_host.free(operation.host_indices)
+                    self.append_host_mem_release(operation.host_indices)
                     logger.debug(
                         f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                     )
@@ -802,7 +796,9 @@
                         : (storage_hit_count // self.page_size)
                     ]
                     # free the pre-allocated memory for pages that are not hit
-                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    self.append_host_mem_release(
+                        operation.host_indices[storage_hit_count:]
+                    )
                     operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
                         f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
@@ -838,45 +834,33 @@
         key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
             hash_values,
             host_indices,
+            self.storage_config.tp_rank,
         )
         success = self.storage_backend.batch_set(
             key_strs,
-            target_location=buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         return success

     # zero copy
     def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         return self.storage_backend.batch_set(hashes, dsts)

     # Backup batch by batch
     def _page_backup(self, operation):
-        # Select the set function and batch size
-        if self.is_mooncake_backend():
-            backup_set_func = self._mooncake_page_set
-            batch_size = 128
-        elif self.storage_backend_type == "hf3fs":
-            if self.mem_pool_host.layout == "page_first":
-                backup_set_func = self._3fs_zero_copy_page_set
-            elif self.mem_pool_host.layout == "layer_first":
-                backup_set_func = self._generic_page_set
-            batch_size = 128
-        else:
-            backup_set_func = self._generic_page_set
-            batch_size = 8
         # Backup batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
             # Set one batch token, and record if success.
             # todo: allow partial success
-            success = backup_set_func(batch_hashes, batch_host_indices)
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
                 logger.warning(
                     f"Write page to storage: {len(batch_hashes)} pages failed."
@@ -896,27 +880,7 @@

             if not self.backup_skip:
                 self._page_backup(operation)
-                min_completed_tokens = operation.completed_tokens
-            else:
-                min_completed_tokens = len(operation.token_ids)
-
-            if self.tp_world_size > 1:
-                completed_tokens_tensor = torch.tensor(
-                    min_completed_tokens, dtype=torch.int
-                )
-                torch.distributed.all_reduce(
-                    completed_tokens_tensor,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.backup_tp_group,
-                )
-                min_completed_tokens = completed_tokens_tensor.item()
-
-            self.ack_backup_queue.put(
-                (
-                    operation.id,
-                    min_completed_tokens,
-                )
-            )
+            self.ack_backup_queue.put(operation)

         except Empty:
             continue
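
Note: host memory is no longer freed inline by the IO threads; freed pages go through host_mem_release_queue (see append_host_mem_release above). A hypothetical drain on the consumer side, assuming the actual call site, which is outside this file, looks roughly like this:

    from queue import Empty

    def drain_host_mem_release(controller, mem_pool_host):
        # release deferred host pages in page_size chunks until the queue is empty
        while True:
            try:
                chunk = controller.host_mem_release_queue.get_nowait()
            except Empty:
                break
            mem_pool_host.free(chunk)
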