sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0

sglang/srt/mem_cache/hiradix_cache.py

@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_storage_backend: Optional[str] = None,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -49,25 +50,36 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
 
         self.tp_group = tp_cache_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
+        self.enable_storage = hicache_storage_backend is not None
+        # todo: customizable storage prefetch threshold
+        self.prefetch_threshold = 256
 
         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
             page_size,
+            self.tp_group,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
+            storage_backend=hicache_storage_backend,
+            prefetch_threshold=self.prefetch_threshold,
         )
 
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
+        # record the ongoing prefetch requests
+        self.ongoing_prefetch = {}
+        self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
+        self.write_through_threshold_storage = 3
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -108,13 +120,30 @@ class HiRadixCache(RadixCache):
 
         return len(host_indices)
 
+    def write_backup_storage(self, node: TreeNode):
+        operation_id = self.cache_controller.write_storage(
+            node.host_value, node.key, node.parent.get_last_hash_value()
+        )
+        self.ongoing_backup[operation_id] = node
+        node.protect_host()
+
     def inc_hit_count(self, node: TreeNode):
-        if node.backuped or self.cache_controller.write_policy == "write_back":
+        if self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.hit_count >= self.write_through_threshold:
-            self.write_backup(node)
-            node.hit_count = 0
+
+        if not node.backuped:
+            if node.hit_count >= self.write_through_threshold:
+                # write to host if the node is not backed up yet
+                self.write_backup(node)
+        else:
+            if (
+                self.enable_storage
+                and (not node.backuped_storage)
+                and node.hit_count >= self.write_through_threshold_storage
+            ):
+                # the node is backed up on host memory but not on storage
+                self.write_backup_storage(node)
 
     def writing_check(self, write_back=False):
         if write_back:
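
Note: the inc_hit_count rewrite above replaces the single write-through threshold with a two-tier promotion policy: hot nodes are first backed up from device to host, and, once host-backed, promoted to the storage tier when one is enabled. A standalone sketch of that decision, using the default thresholds set in __init__ (the function and constant names here are ours, for illustration only):

# Two-tier write-through promotion, modeled after inc_hit_count above.
WRITE_THROUGH_THRESHOLD = 3          # device -> host; 1 under the "write_through" policy
WRITE_THROUGH_THRESHOLD_STORAGE = 3  # host -> storage

def next_backup(hit_count, backuped, backuped_storage, enable_storage):
    """Return the tier a node should be written through to next, if any."""
    if not backuped:  # not yet backed up to host memory
        return "host" if hit_count >= WRITE_THROUGH_THRESHOLD else "none"
    if enable_storage and not backuped_storage:  # on host, not yet on storage
        return "storage" if hit_count >= WRITE_THROUGH_THRESHOLD_STORAGE else "none"
    return "none"

assert next_backup(3, False, False, True) == "host"
assert next_backup(3, True, False, True) == "storage"
assert next_backup(5, True, True, True) == "none"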
@@ -126,7 +155,7 @@ class HiRadixCache(RadixCache):
             queue_size = torch.tensor(
                 self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
             )
-            if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            if self.tp_world_size > 1:
                 # synchronize TP workers to make the same update to radix cache
                 torch.distributed.all_reduce(
                     queue_size,
@@ -221,6 +250,10 @@ class HiRadixCache(RadixCache):
             if not x.evicted:
                 continue
 
+            # node is protected from eviction as it has ongoing prefetch or backup to storage
+            if x.host_ref_counter > 0:
+                continue
+
             num_evicted += self.cache_controller.evict_host(x.host_value)
 
             for k, v in x.parent.children.items():
@@ -314,6 +347,94 @@ class HiRadixCache(RadixCache):
     def check_hicache_events(self):
         self.writing_check()
         self.loading_check()
+        if self.enable_storage:
+            self.check_revoked_prefetch()
+            self.check_backup_progress()
+
+    def check_revoked_prefetch(self):
+        queue_size = torch.tensor(
+            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+        )
+        if self.tp_world_size > 1:
+            # synchronize TP workers to make the same update to the hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            req_id = self.cache_controller.prefetch_revoke_queue.get()
+            if req_id in self.ongoing_prefetch:
+                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node.release_host()
+                self.cache_controller.mem_pool_host.free(host_indices)
+                del self.ongoing_prefetch[req_id]
+
+    def check_backup_progress(self):
+        queue_size = torch.tensor(
+            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
+        )
+        if self.tp_world_size > 1:
+            # synchronize TP workers to make the same update to the hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id, hash_value, completed_tokens = (
+                self.cache_controller.ack_backup_queue.get()
+            )
+            host_node = self.ongoing_backup[ack_id]
+            if completed_tokens < len(host_node.key):
+                # the backup was only partially successful, split the node
+                new_node = self._split_node(host_node.key, host_node, completed_tokens)
+                new_node.hash_value = hash_value
+            host_node.release_host()
+            del self.ongoing_backup[ack_id]
+
+    def check_prefetch_progress(self, req_id: str):
+        if req_id not in self.ongoing_prefetch:
+            # there is no ongoing prefetch for this request, or it has been revoked
+            return
+
+        # todo: more policies for prefetch progress, such as timeout
+        # the current policy is to prefetch with best effort and terminate when queuing is over
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+            req_id
+        ]
+        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
+            operation
+        )
+        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
+
+        min_completed_tokens = completed_tokens
+        if self.tp_world_size > 1:
+            # synchronize TP workers to make the same update to the hiradix cache
+            completed_tokens_tensor = torch.tensor(
+                min_completed_tokens, dtype=torch.int
+            )
+            torch.distributed.all_reduce(
+                completed_tokens_tensor,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            min_completed_tokens = completed_tokens_tensor.item()
+        fetched_token_ids = token_ids[:min_completed_tokens]
+        written_indices = host_indices[:min_completed_tokens]
+        matched_length = self._insert_helper_host(
+            last_host_node,
+            fetched_token_ids,
+            written_indices,
+            hash_value[:min_completed_tokens],
+        )
+
+        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
+        self.cache_controller.mem_pool_host.free(
+            host_indices[min_completed_tokens:completed_tokens]
+        )
+        last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]
 
     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
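
A note on the recurring all_reduce(..., op=ReduceOp.MIN) pattern above: each TP rank can observe a different queue size or prefetch progress, yet every rank must apply an identical update to its radix tree. Reducing with MIN commits only the progress all ranks have reached. A single-process illustration with hypothetical numbers (torch.distributed performs this across ranks):

# What the MIN all-reduce computes for check_prefetch_progress.
per_rank_completed = [512, 384, 512, 448]  # tokens fetched by each TP rank
min_completed_tokens = min(per_rank_completed)
assert min_completed_tokens == 384
# Every rank then inserts token_ids[:384] into the tree and frees the host
# indices beyond its own completed range, keeping all ranks consistent.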
@@ -348,6 +469,74 @@ class HiRadixCache(RadixCache):
             host_hit_length=host_hit_length,
         )
 
+    def prefetch_from_storage(
+        self,
+        req_id: str,
+        last_host_node: TreeNode,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        # align the number of fetched tokens to the page size
+        prefetch_length = len(new_input_tokens) - (
+            len(new_input_tokens) % self.page_size
+        )
+        new_input_tokens = new_input_tokens[:prefetch_length]
+        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
+            return
+
+        last_host_node.protect_host()
+        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+        if host_indices is None:
+            self.evict_host(prefetch_length)
+            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+        if host_indices is None:
+            last_host_node.release_host()
+            # insufficient host memory to prefetch
+            return
+        operation = self.cache_controller.prefetch(
+            req_id, host_indices, new_input_tokens, last_hash
+        )
+        self.ongoing_prefetch[req_id] = (
+            last_host_node,
+            new_input_tokens,
+            host_indices,
+            operation,
+        )
+
+    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
+        node.last_access_time = time.monotonic()
+        if len(key) == 0:
+            return 0
+
+        child_key = self.get_child_key_fn(key)
+
+        matched_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.monotonic()
+            prefix_len = self.key_match_fn(node.key, key)
+            key = key[prefix_len:]
+            host_value = host_value[prefix_len:]
+            hash_value = hash_value[prefix_len:]
+            matched_length += prefix_len
+
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = None
+            new_node.host_value = host_value
+            new_node.hash_value = hash_value
+            node.children[child_key] = new_node
+        return matched_length
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)
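
prefetch_from_storage rounds the fetch length down to a whole number of pages and skips short requests entirely. A small worked example of that arithmetic (the page size is assumed for illustration; prefetch_threshold is the default from __init__):

page_size = 64            # assumed; configurable in practice
prefetch_threshold = 256  # default set in __init__ above

def aligned_prefetch_length(num_new_tokens):
    # round down to a multiple of the page size, as in prefetch_from_storage
    length = num_new_tokens - num_new_tokens % page_size
    # below the threshold, the prefetch is skipped as not worth the I/O
    return length if length >= prefetch_threshold else 0

assert aligned_prefetch_length(1000) == 960  # 15 full pages
assert aligned_prefetch_length(200) == 0     # 3 pages, below the threshold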

sglang/srt/mem_cache/memory_pool.py

@@ -520,8 +520,13 @@ class SWAKVPool(KVCache):
             self.layers_mapping[global_layer_id] = (swa_layer_id, True)
         self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
 
+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
     def get_kv_size_bytes(self):
-        raise NotImplementedError
+        k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
+        k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes()
+        return k_size + k_size_swa, v_size + v_size_swa
 
     def get_contiguous_buf_infos(self):
         full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
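
The new get_kv_size_bytes sums the full-attention and sliding-window sub-pools, and __init__ records the total in GB. A toy calculation under assumed sizes (GB is presumed to be the module's 1024**3 constant; all figures hypothetical):

GB = 1024**3  # assumed definition of the module-level constant

full_k, full_v = 8 * GB, 8 * GB  # hypothetical full-attention pool
swa_k, swa_v = 2 * GB, 2 * GB    # hypothetical sliding-window pool

k_size, v_size = full_k + swa_k, full_v + swa_v  # get_kv_size_bytes result
mem_usage = (k_size + v_size) / GB               # recorded as self.mem_usage
assert mem_usage == 20.0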
@@ -597,6 +602,16 @@ class SWAKVPool(KVCache):
             layer_id_override=layer_id_pool,
         )
 
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError("HiCache not supported for SWAKVPool.")
+
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError("HiCache not supported for SWAKVPool.")
+
 
 class AscendTokenToKVPool(MHATokenToKVPool):
 

sglang/srt/mem_cache/memory_pool_host.py

@@ -71,11 +71,12 @@ class HostKVCache(abc.ABC):
             requested_bytes = self.size * self.size_per_token
             # preserve at least 10GB for other usage
             ten_gb = 10 * (1024**3)
-            if requested_bytes > host_mem.available - ten_gb:
+            available_bytes = host_mem.available - ten_gb
+            if requested_bytes > available_bytes:
                 raise ValueError(
                     f"Not enough host memory available. Requesting "
                     f"{requested_bytes / 1e9:.2f} GB but only have "
-                    f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                    f"{available_bytes / 1e9:.2f} GB free. Please reduce the "
                     f"size of the hierarchical cache."
                 )
         else:
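
The refactor above does not change behavior, but the contract is easier to see with numbers: the host pool may claim everything except a fixed 10 GB reserve. A quick worked example with hypothetical host figures:

ten_gb = 10 * (1024**3)
host_available = 70 * (1024**3)   # hypothetical psutil host_mem.available
requested_bytes = 64 * (1024**3)  # hypothetical cache size request

available_bytes = host_available - ten_gb  # 60 GB usable by the cache
assert requested_bytes > available_bytes   # so the ValueError above fires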
@@ -98,6 +99,20 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        """
+        Get a flat data page from the host memory pool.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        """
+        Set a flat data page to the host memory pool.
+        """
+        raise NotImplementedError()
+
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.
@@ -111,6 +126,9 @@ class HostKVCache(abc.ABC):
 
     @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
+        assert (
+            need_size % self.page_size == 0
+        ), "The requested size should be a multiple of the page size."
         if need_size > self.available_size():
             return None
 
@@ -226,6 +244,19 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )
 
+    # todo: page-first memory layout
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
+            2,
+            self.layer_num,
+            self.page_size,
+            self.head_num,
+            self.head_dim,
+        )
+
     @property
     def k_buffer(self):
         return self.kv_buffer[0]
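
get_flat_data_page and set_from_flat_data_page are exact inverses for a page-aligned index, because flatten and reshape traverse the [2, layer_num, size, head_num, head_dim] buffer in the same order. A self-contained round-trip check with toy dimensions (all values ours):

import torch

layer_num, size, head_num, head_dim, page_size = 2, 8, 4, 16, 4
kv_buffer = torch.randn(2, layer_num, size, head_num, head_dim)

index = 4  # page-aligned start, as HostKVCache.alloc now enforces
page = kv_buffer[:, :, index : index + page_size, :, :].flatten()

restored = torch.zeros_like(kv_buffer)
restored[:, :, index : index + page_size, :, :] = page.reshape(
    2, layer_num, page_size, head_num, head_dim
)
assert torch.equal(
    restored[:, :, index : index + page_size],
    kv_buffer[:, :, index : index + page_size],
)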
@@ -275,3 +306,14 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
+
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+            self.layer_num,
+            self.page_size,
+            1,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+        )

sglang/srt/mem_cache/radix_cache.py

@@ -55,8 +55,13 @@ class TreeNode:
         self.hit_count = 0
         # indicating the node is loading KV cache from host
         self.loading = False
+        # indicating the node is locked to protect it from eviction;
+        # incremented when the node is referenced by a storage operation
+        self.host_ref_counter = 0
         # store the host indices of KV cache
         self.host_value: Optional[torch.Tensor] = None
+        # store the hash values of each page
+        self.hash_value: Optional[List[str]] = None
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -69,6 +74,27 @@ class TreeNode:
     def backuped(self):
         return self.host_value is not None
 
+    @property
+    def backuped_storage(self):
+        return self.hash_value is not None and len(self.hash_value) > 0
+
+    def protect_host(self):
+        """Protect the host value from eviction."""
+        self.host_ref_counter += 1
+
+    def release_host(self):
+        """Release the host value, allowing it to be evicted."""
+        if self.host_ref_counter > 0:
+            self.host_ref_counter -= 1
+        else:
+            raise RuntimeError("Host reference counter is already zero.")
+
+    def get_last_hash_value(self) -> Optional[str]:
+        """Returns the hash value of the last page in this node."""
+        if self.hash_value is None or len(self.hash_value) == 0:
+            return None
+        return self.hash_value[-1]
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
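
The reference counter gives eviction a simple invariant: a node with host_ref_counter > 0 is skipped by evict_host (see the hiradix_cache.py hunk above). A minimal standalone model of the counter's contract (the class here is a toy stand-in, not TreeNode itself):

class Node:
    """Toy stand-in for TreeNode's host reference counting."""

    def __init__(self):
        self.host_ref_counter = 0

    def protect_host(self):
        self.host_ref_counter += 1

    def release_host(self):
        if self.host_ref_counter > 0:
            self.host_ref_counter -= 1
        else:
            raise RuntimeError("Host reference counter is already zero.")

n = Node()
n.protect_host()                # a prefetch or storage backup takes a reference
assert n.host_ref_counter > 0   # eviction must skip this node while it is held
n.release_host()                # the operation completed or was revoked
assert n.host_ref_counter == 0  # the node is evictable again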