sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py

@@ -18,53 +18,78 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple

 import torch

+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool

 logger = logging.getLogger(__name__)


+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.counters = [num_layers] * self.num_counters
-        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
-        self.producer_index = 0
-        self.consumer_index = 0
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1

     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[
+            self.producer_index
+        ].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index

-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index

-    def increment(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] += 1
-            self.conditions[self.producer_index].notify_all()
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)

     def reset(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] = 0
+        self.producer_index = -1
+        self.consumer_index = -1


 class CacheOperation:
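The key change in this hunk: LayerDoneCounter no longer counts finished layers behind threading.Condition locks; each producer slot owns a LayerLoadingEvent whose per-layer torch.cuda.Event lets the consumer stall only the GPU stream that actually needs layer i. A minimal stand-alone sketch of that pattern (illustrative names, not the sglang API; assumes a CUDA device is available):

    import torch

    num_layers = 4
    load_stream = torch.cuda.Stream()  # producer stream for host-to-device copies
    events = [torch.cuda.Event() for _ in range(num_layers)]

    src = [torch.randn(1024, pin_memory=True) for _ in range(num_layers)]
    dst = [torch.empty(1024, device="cuda") for _ in range(num_layers)]

    with torch.cuda.stream(load_stream):
        for i in range(num_layers):
            dst[i].copy_(src[i], non_blocking=True)  # async H2D copy of layer i
            events[i].record()                       # publish: layer i is loaded

    for i in range(num_layers):
        # The compute stream waits on the GPU timeline instead of blocking a
        # Python thread on a condition variable, as the old counter did.
        torch.cuda.current_stream().wait_event(events[i])
        dst[i].mul_(2.0)  # stand-in for compute that consumes layer i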
@@ -88,36 +113,30 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id

-    def merge(self, other: "CacheOperation") -> None:
-        # multiple operations can be merged into a single operation for batch processing
-        self.host_indices = torch.cat([self.host_indices, other.host_indices])
-        self.device_indices = torch.cat([self.device_indices, other.device_indices])
-        self.priority = min(self.priority, other.priority)
-        self.node_ids.extend(other.node_ids)
-
-    def split(self, factor) -> List["CacheOperation"]:
-        # split an operation into smaller operations to reduce the size of intermediate buffers
-        if factor <= 1:
-            return [self]
-
-        chunk_size = math.ceil(len(self.host_indices) / factor)
-        split_ops = []
-        for i in range(0, len(self.host_indices), chunk_size):
-            split_ops.append(
-                CacheOperation(
-                    host_indices=self.host_indices[i : i + chunk_size],
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
+        return self.priority < other.priority

-        return split_ops

-    def __lt__(self, other: "CacheOperation"):
-        return self.priority < other.priority
+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]


 class TransferBuffer:
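merge_ops replaces the pairwise merge()/split() methods: the whole queue is collapsed into one operation in a single shot, concatenating the index tensors in queue order, taking the numerically smallest (most urgent) priority, and unioning the node ids so every tree node is still acknowledged once the batch finishes. A toy check of that invariant (hypothetical Op class, not sglang code):

    import torch

    class Op:  # stand-in for CacheOperation
        def __init__(self, host, device, priority, node_ids):
            self.host_indices, self.device_indices = host, device
            self.priority, self.node_ids = priority, node_ids

    ops = [
        Op(torch.tensor([0, 1]), torch.tensor([10, 11]), priority=5, node_ids=[7]),
        Op(torch.tensor([2, 3]), torch.tensor([12, 13]), priority=2, node_ids=[8, 9]),
    ]
    host = torch.cat([op.host_indices for op in ops])
    device = torch.cat([op.device_indices for op in ops])
    priority = min(op.priority for op in ops)
    node_ids = [n for op in ops for n in op.node_ids]

    # Concatenation preserves the per-token pairing of host and device slots.
    assert host.tolist() == [0, 1, 2, 3] and device.tolist() == [10, 11, 12, 13]
    assert priority == 2 and node_ids == [7, 8, 9]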
@@ -196,26 +215,25 @@ class PrefetchOperation(StorageOperation):
     ):
         self.request_id = request_id

-        self._done_flag = False
         self._lock = threading.Lock()
-
+        self._terminated_flag = False
         self.start_time = time.monotonic()

         super().__init__(host_indices, token_ids, last_hash)

     def increment(self, num_tokens: int):
         with self._lock:
-            if self._done_flag:
+            if self._terminated_flag:
                 return False
             self.completed_tokens += num_tokens
             return True

-    def mark_done(self):
+    def mark_terminate(self):
         with self._lock:
-            self._done_flag = True
+            self._terminated_flag = True

-    def is_done(self) -> bool:
-        return self._done_flag
+    def is_terminated(self) -> bool:
+        return self._terminated_flag


 class HiCacheController:
@@ -226,11 +244,13 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event = None,
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,
         prefetch_threshold: int = 256,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -238,30 +258,37 @@ class HiCacheController:
         self.write_policy = write_policy
         self.page_size = page_size
         self.io_backend = io_backend
-
         self.enable_storage = False
-        self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
-        # todo: move backend initialization to storage backend module
+
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
-            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+            from sglang.srt.mem_cache.hicache_storage import get_hash_str
+
+            self.get_hash_str = get_hash_str
+            self.storage_config = self._generate_storage_config(
+                model_name, storage_backend_extra_config
+            )
+            # for MLA models, only one rank needs to backup the KV cache
+            self.backup_skip = (
+                self.storage_config.is_mla_model
+                # todo: load balancing
+                and self.storage_config.tp_rank != 0
+            )

             if storage_backend == "file":
-                self.storage_backend = HiCacheFile(is_mla=self.is_mla)
-                self.get_hash_str = get_hash_str
+                from sglang.srt.mem_cache.hicache_storage import HiCacheFile
+
+                self.storage_backend = HiCacheFile(self.storage_config)
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

                 self.storage_backend = HiCacheNixl()
-                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
                 from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                     MooncakeStore,
-                    get_hash_str_mooncake,
                 )

-                self.storage_backend = MooncakeStore(is_mla=self.is_mla)
-                self.get_hash_str = get_hash_str_mooncake
+                self.storage_backend = MooncakeStore(self.storage_config)
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
@@ -279,19 +306,21 @@
                 )
                 dtype = mem_pool_host.dtype
                 self.storage_backend = HiCacheHF3FS.from_env_config(
-                    bytes_per_page, dtype
+                    bytes_per_page, dtype, self.storage_config
                 )
-                self.get_hash_str = get_hash_str
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
                 )
+
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             self.prefetch_capacity_limit = int(
                 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
             )
+            # granularity of batch storage IO operations, in number of pages
+            self.storage_batch_size = 128
             # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
             self.prefetch_tokens_occupied = 0
@@ -302,15 +331,26 @@
             self.prefetch_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
-            self.prefetch_io_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
-            self.backup_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )

-        self.load_cache_event = load_cache_event
-        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        # Select the get and set functions
+        self.page_get_func = self._generic_page_get
+        self.page_set_func = self._generic_page_set
+        self.batch_exists_func = self.storage_backend.batch_exists
+        self.is_3fs_zerocopy = (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        )
+        if self.storage_backend_type == "mooncake":
+            self.page_get_func = self._mooncake_page_get
+            self.page_set_func = self._mooncake_page_set
+        elif self.is_3fs_zerocopy:
+            self.page_get_func = self._3fs_zero_copy_page_get
+            self.page_set_func = self._3fs_zero_copy_page_set
+            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

         if write_policy not in [
@@ -320,11 +360,11 @@
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")

-        self.write_queue = PriorityQueue()
-        self.load_queue = PriorityQueue()
-
-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []

         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
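With the writer/loader threads gone (next hunk), write_queue and load_queue become plain lists that are flushed synchronously, and the ack queues now carry HiCacheAck event pairs instead of bare node ids: an ack is final only once its finish_event has executed on the GPU. A hedged sketch of how a consumer might drain such a queue without blocking (Ack and drain_acks are illustrative, not sglang code):

    from typing import List, NamedTuple

    import torch

    class Ack(NamedTuple):  # stand-in for HiCacheAck
        finish_event: torch.cuda.Event
        node_ids: List[int]

    def drain_acks(ack_queue: List[Ack]) -> List[int]:
        # query() polls the GPU timeline and never blocks; acks complete in
        # submission order, so stop at the first unfinished one.
        done: List[int] = []
        while ack_queue and ack_queue[0].finish_event.query():
            done.extend(ack_queue.pop(0).node_ids)
        return done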
@@ -335,16 +375,6 @@
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -357,21 +387,57 @@

             self.prefetch_revoke_queue = Queue()
             self.ack_backup_queue = Queue()
+            self.host_mem_release_queue = Queue()

             self.prefetch_thread.start()
             self.backup_thread.start()

+    def _generate_storage_config(
+        self,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
+    ):
+
+        if is_dp_attention_enabled():
+            self.tp_rank = get_attention_tp_rank()
+            self.tp_size = get_attention_tp_size()
+            self.dp_rank = get_attention_dp_rank()
+        else:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.dp_rank = 0
+
+        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
+        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+
+        # Parse extra config JSON if provided
+        extra_config = None
+        if storage_backend_extra_config:
+            try:
+                import json
+
+                extra_config = json.loads(storage_backend_extra_config)
+            except Exception as e:
+                logger.error(f"Invalid backend extra config JSON: {e}")
+
+        return HiCacheStorageConfig(
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            is_mla_model=is_mla_backend,
+            is_page_first_layout=self.mem_pool_host.layout == "page_first",
+            model_name=model_name,
+            extra_config=extra_config,
+        )
+
     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()

-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
@@ -380,15 +446,7 @@
             self.prefetch_revoke_queue.queue.clear()
             self.ack_backup_queue.queue.clear()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()

         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -400,20 +458,11 @@
             self.prefetch_thread.start()
             self.backup_thread.start()

-    @property
-    def backup_skip(self):
-        return (
-            self.is_mla
-            and get_tensor_model_parallel_rank() != 0
-            # todo: only support file and mooncake
-            and self.storage_backend_type in ["file", "mooncake"]
-        )
-
     def write(
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -422,17 +471,46 @@
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
-        torch.cuda.current_stream().synchronize()
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices

+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the write stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.write_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
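start_writing is the synchronous replacement for the old write thread: the copy is enqueued on write_stream between a start event (which orders it after work already submitted on the current stream, avoiding the old torch.cuda.current_stream().synchronize()) and a finish event that later backs the HiCacheAck, while record_stream tells the caching allocator the index tensors are still in use by the side stream. A minimal sketch of the bracketing pattern (illustrative tensors, assumes CUDA):

    import torch

    side_stream = torch.cuda.Stream()
    start_event, finish_event = torch.cuda.Event(), torch.cuda.Event()

    data = torch.randn(1 << 20, device="cuda")
    dst = torch.empty_like(data)

    start_event.record()                 # capture the current stream's position
    with torch.cuda.stream(side_stream):
        start_event.wait(side_stream)    # side stream starts after that position
        dst.copy_(data)                  # the actual transfer
        finish_event.record()
        # Keep the tensors alive for the side stream even if Python drops them:
        data.record_stream(side_stream)
        dst.record_stream(side_stream)

    finish_event.synchronize()  # sglang instead polls finish_event via the acks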
@@ -441,17 +519,18 @@
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
-        # to ensure the device indices are ready before accessed by another CUDA stream
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices

-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            return host_indices.to(self.mem_pool_device.device), device_indices
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
             device_indices = device_indices.cpu()
             host_indices, idx = host_indices.sort()
@@ -459,58 +538,20 @@
         else:
             raise ValueError(f"Unsupported io backend")

-    def write_thread_func_direct(self):
-        """
-        Directly write through KV caches to host memory without buffering.
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1

-    def load_thread_func_layer_by_layer(self):
-        """
-        Load KV caches from host memory to device memory layer by layer.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()

-            # start layer-wise KV cache transfer from CPU to GPU
-            self.layer_done_counter.reset()
-            host_indices, device_indices = self.move_indices(
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -518,13 +559,24 @@
                     i,
                     self.io_backend,
                 )
-                self.load_stream.synchronize()
-                self.layer_done_counter.increment()
-
-            self.mem_pool_host.complete_io(batch_operation.host_indices)
-            for node_id in batch_operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the load stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.load_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.load_stream)
+
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id

     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
@@ -567,63 +619,93 @@
         return operation

     def terminate_prefetch(self, operation):
-        operation.mark_done()
+        operation.mark_terminate()
        return operation.completed_tokens, operation.hash_value

-    def zerocopy_page_transfer(self, operation, batch_size=8):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
-            operation.hash_value, operation.host_indices
+    def append_host_mem_release(self, host_indices: torch.Tensor):
+        chunks = host_indices.split(self.mem_pool_host.page_size)
+        for chunk in chunks:
+            self.host_mem_release_queue.put(chunk)
+
+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
+    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
         )
-        for i in range(0, len(hashes), batch_size):
-            page_hashes = hashes[i : i + batch_size]
-            page_dsts = dsts[i : i + batch_size]
-            page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
-            if page_data is None:
-                logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
-                )
-                break
-            completed_tokens = operation.completed_tokens
-            if operation.increment(self.page_size * len(page_hashes)):
-                for i in range(len(page_hashes)):
-                    completed_tokens += self.page_size
-            else:
-                break
+        page_data = self.storage_backend.batch_get(hashes, dsts)
+        if page_data:
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
+        else:
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
+            )

-    def generic_page_transfer(self, operation, batch_size=8):
-        for i in range(0, len(operation.hash_value), batch_size):
-            page_hashes = operation.hash_value[i : i + batch_size]
-            # todo: zero copy
-            dummy_page_dst = [
-                self.mem_pool_host.get_dummy_flat_data_page()
-                for _ in range(len(page_hashes))
-            ]
-            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
-            if page_data is None:
+    def _mooncake_page_get(self, operation, hash_values, host_indices):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
+        )
+        get_result = self.storage_backend.batch_get(
+            key_strs,
+            target_locations=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        if get_result != len(hash_values):
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed or partially failed."
+            )
+        if get_result != 0:
+            operation.increment(get_result * self.page_size)
+
+    def _generic_page_get(self, operation, hash_values, host_indices):
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
+        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
+        if page_data is None:
+            return
+        for i in range(len(hash_values)):
+            if page_data[i] is None:
                 logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-            completed_tokens = operation.completed_tokens
-            if operation.increment(self.page_size * len(page_hashes)):
-                for i in range(len(page_hashes)):
-                    self.mem_pool_host.set_from_flat_data_page(
-                        operation.host_indices[completed_tokens],
-                        page_data[i],
-                    )
-                    completed_tokens += self.page_size
-            else:
-                break
-
-    def mooncake_page_transfer(self, operation):
-        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
-            operation.hash_value, operation.host_indices
+            # Must set the data before increasing the completed tokens.
+            # Otherwise this page may be read before being set.
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller
+
+    def _page_transfer(self, operation):
+        # Transfer batch by batch
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
+            ]
+            prev_completed_tokens = operation.completed_tokens
+            # Get one batch token, and update the completed_tokens if succeed
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
+            # Check termination
+            if (
+                operation.completed_tokens
+                != prev_completed_tokens + len(batch_hashes) * self.page_size
+            ):
+                operation.mark_terminate()
+                break  # Some operations fail or operation terminated by controller
+        # release pre-allocated memory
+        self.append_host_mem_release(
+            operation.host_indices[operation.completed_tokens :]
         )
-        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
-        operation.increment(len(operation.hash_value) * self.page_size)
-
-    def is_mooncake_backend(self):
-        return self.storage_backend_type == "mooncake"

     def prefetch_io_aux_func(self):
         """
@@ -632,35 +714,50 @@
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if self.is_mooncake_backend():
-                    self.mooncake_page_transfer(operation)
-                elif self.storage_backend_type == "hf3fs":
-                    if self.mem_pool_host.layout == "page_first":
-                        self.zerocopy_page_transfer(operation, batch_size=128)
-                    elif self.mem_pool_host.layout == "layer_first":
-                        self.generic_page_transfer(operation, batch_size=128)
-                else:
-                    self.generic_page_transfer(operation)
-
-                if self.tp_world_size > 1:
-                    # to ensure all TP workers release the host memory at the same time
-                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
+                self._page_transfer(operation)
                 # operation terminated by controller, release pre-allocated memory
-                self.mem_pool_host.free(
+                self.append_host_mem_release(
                     operation.host_indices[operation.completed_tokens :]
                 )
             except Empty:
                 continue

-    def prefetch_rate_limit_check(self) -> bool:
+    def prefetch_rate_limited(self) -> bool:
         """
         Rate limit the prefetching operations to avoid overwhelming the storage backend.
         """
         # cancel prefetch if too much memory is occupied
         if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
-            return False
+            return True
         # todo: more sophisticated rate limiting based on storage backend performance
-        return True
+        return False
+
+    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
+        last_hash = operation.last_hash
+        tokens_to_fetch = operation.token_ids
+
+        storage_query_count = 0
+        hash_value = []
+
+        for start in range(
+            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
+        ):
+            end = min(
+                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
+            )
+            batch_tokens = tokens_to_fetch[start:end]
+            batch_hashes = []
+            for i in range(0, len(batch_tokens), self.page_size):
+                last_hash = self.get_hash_str(
+                    batch_tokens[i : i + self.page_size], last_hash
+                )
+                batch_hashes.append(last_hash)
+            hit_page_num = self.batch_exists_func(batch_hashes)
+            hash_value.extend(batch_hashes[:hit_page_num])
+            storage_query_count += hit_page_num * self.page_size
+            if hit_page_num < len(batch_hashes):
+                break
+        return hash_value, storage_query_count

     def prefetch_thread_func(self):
         """
@@ -675,39 +772,7 @@
                 if operation is None:
                     continue

-                storage_hit_count = 0
-                if (
-                    operation.host_indices is not None
-                ) and self.prefetch_rate_limit_check():
-                    last_hash = operation.last_hash
-                    tokens_to_fetch = operation.token_ids
-
-                    remaining_tokens = len(tokens_to_fetch)
-                    hash_value = []
-                    while remaining_tokens >= self.page_size:
-                        last_hash = self.get_hash_str(
-                            tokens_to_fetch[
-                                storage_hit_count : storage_hit_count + self.page_size
-                            ],
-                            last_hash,
-                        )
-
-                        # todo, more unified interface
-                        if not self.is_mooncake_backend():
-                            if not self.storage_backend.exists(last_hash):
-                                break
-                        hash_value.append(last_hash)
-                        storage_hit_count += self.page_size
-                        remaining_tokens -= self.page_size
-
-                    if self.is_mooncake_backend():
-                        # deferring to batch exists for mooncake store
-                        exist_result = self.storage_backend.exists(hash_value)
-                        storage_hit_count = (
-                            sum(1 for v in exist_result.values() if v != 0)
-                            * self.page_size
-                        )
-
+                hash_value, storage_hit_count = self._storage_hit_query(operation)
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
                         storage_hit_count, dtype=torch.int
@@ -722,8 +787,7 @@
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-                    if operation.host_indices is not None:
-                        self.mem_pool_host.free(operation.host_indices)
+                    self.append_host_mem_release(operation.host_indices)
                     logger.debug(
                         f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                     )
@@ -732,7 +796,9 @@
                         : (storage_hit_count // self.page_size)
                     ]
                     # free the pre-allocated memory for pages that are not hit
-                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    self.append_host_mem_release(
+                        operation.host_indices[storage_hit_count:]
+                    )
                     operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
                         f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
@@ -755,59 +821,52 @@
             self.backup_queue.put(operation)
         return operation.id

-    def zerocopy_page_backup(self, operation, batch_size=8):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
-            operation.hash_value, operation.host_indices
+    # non-zero copy
+    def _generic_page_set(self, hash_values, host_indices) -> bool:
+        data = [
+            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
+            for i in range(len(hash_values))
+        ]
+        return self.storage_backend.batch_set(hash_values, data)
+
+    # zero copy
+    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
         )
-        for i in range(0, len(hashes), batch_size):
-            page_hashes = hashes[i : i + batch_size]
-            page_data = dsts[i : i + batch_size]
-            success = self.storage_backend.batch_set(page_hashes, page_data)
-            if not success:
-                logger.warning(f"Failed to write page {page_hashes} to storage.")
-                break
-            operation.completed_tokens += self.page_size * len(page_hashes)
-
-    def generic_page_backup(self, operation, batch_size=8):
-        for i in range(0, len(operation.hash_value), batch_size):
-            page_hashes = operation.hash_value[i : i + batch_size]
-            page_data = [
-                self.mem_pool_host.get_flat_data_page(
-                    operation.host_indices[j * self.page_size]
-                )
-                for j in range(i, i + len(page_hashes))
+        success = self.storage_backend.batch_set(
+            key_strs,
+            target_locations=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        return success
+
+    # zero copy
+    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
+        )
+        return self.storage_backend.batch_set(hashes, dsts)
+
+    # Backup batch by batch
+    def _page_backup(self, operation):
+        # Backup batch by batch
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
-            success = self.storage_backend.batch_set(page_hashes, page_data)
+            # Set one batch token, and record if success.
+            # todo: allow partial success
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
-                logger.warning(f"Failed to write page {page_hashes} to storage.")
-                break
-            operation.completed_tokens += self.page_size * len(page_hashes)
-
-    def mooncake_page_backup(self, operation):
-        if len(operation.hash_value):
-            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
-            indices = operation.host_indices.tolist()
-            non_exist_keys = []
-            non_exist_indices = []
-            for i in range(len(operation.hash_value)):
-                if not exist_hashvalues[operation.hash_value[i]]:
-                    non_exist_keys.append(operation.hash_value[i])
-                    non_exist_indices.extend(
-                        indices[i * self.page_size : (i + 1) * self.page_size]
-                    )
-            if len(non_exist_keys) > 0:
-                key_strs, buffer_ptrs, buffer_sizes = (
-                    self.mem_pool_host.get_buffer_meta(
-                        non_exist_keys, non_exist_indices
-                    )
-                )
-                # TODO: check the return value of batch set to see how many tokens are set successfully
-                self.storage_backend.batch_set(
-                    key_strs,
-                    target_location=buffer_ptrs,
-                    target_sizes=buffer_sizes,
+                logger.warning(
+                    f"Write page to storage: {len(batch_hashes)} pages failed."
                 )
-            operation.completed_tokens += len(operation.hash_value) * self.page_size
+                break
+            operation.completed_tokens += self.page_size * len(batch_hashes)

     def backup_thread_func(self):
         """
@@ -820,36 +879,8 @@
                     continue

                 if not self.backup_skip:
-                    if self.is_mooncake_backend():
-                        self.mooncake_page_backup(operation)
-                    elif self.storage_backend_type == "hf3fs":
-                        if self.mem_pool_host.layout == "page_first":
-                            self.zerocopy_page_backup(operation, batch_size=128)
-                        elif self.mem_pool_host.layout == "layer_first":
-                            self.generic_page_backup(operation, batch_size=128)
-                    else:
-                        self.generic_page_backup(operation)
-                    min_completed_tokens = operation.completed_tokens
-                else:
-                    min_completed_tokens = len(operation.token_ids)
-
-                if self.tp_world_size > 1:
-                    completed_tokens_tensor = torch.tensor(
-                        min_completed_tokens, dtype=torch.int
-                    )
-                    torch.distributed.all_reduce(
-                        completed_tokens_tensor,
-                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.backup_tp_group,
-                    )
-                    min_completed_tokens = completed_tokens_tensor.item()
-
-                self.ack_backup_queue.put(
-                    (
-                        operation.id,
-                        min_completed_tokens,
-                    )
-                )
+                    self._page_backup(operation)
+                self.ack_backup_queue.put(operation)

             except Empty:
                 continue