sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -574,7 +574,7 @@ class TokenizerManager:
                 "The server is not configured to enable custom logit processor. "
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
-        if self.server_args.lora_paths and obj.lora_path:
+        if self.server_args.enable_lora and obj.lora_path:
             self._validate_lora_adapters(obj)

     def _validate_input_ids_in_vocab(
@@ -604,7 +604,7 @@ class TokenizerManager:
         sampling_kwargs = obj.sampling_params
         sampling_params = SamplingParams(**sampling_kwargs)
         sampling_params.normalize(self.tokenizer)
-        sampling_params.verify()
+        sampling_params.verify(self.model_config.vocab_size)

         # Build return object
         if isinstance(obj, GenerateReqInput):
@@ -1037,6 +1037,10 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
+        if not self.server_args.enable_lora:
+            raise ValueError(
+                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+            )

         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
@@ -1060,6 +1064,10 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
+        if not self.server_args.enable_lora:
+            raise ValueError(
+                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+            )

         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
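
Both adapter endpoints now fail fast when the server was launched without `--enable-lora`, instead of failing later in the pipeline. A minimal client-side sketch of what the guard means in practice; the endpoint path and JSON keys below are assumptions for illustration, not verified against this release:

```python
# Hypothetical client call against a locally running sglang server.
# The /load_lora_adapter path and payload keys are assumptions.
import requests

resp = requests.post(
    "http://localhost:30000/load_lora_adapter",
    json={"lora_name": "my-adapter", "lora_path": "/models/my-adapter"},
)
# If the server was started without --enable-lora, the request is rejected with:
# "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
print(resp.status_code, resp.text)
```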
@@ -1359,7 +1367,7 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
-            self.last_receive_tstamp = time.time()
+            self.last_receive_tstamp = time.perf_counter()

     def _handle_batch_output(
         self,
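
The timestamp switch is more than cosmetic: `time.time()` follows the wall clock and can jump backwards under NTP adjustments, while `time.perf_counter()` is monotonic with higher resolution, so intervals derived from `last_receive_tstamp` can never go negative. A quick illustration:

```python
import time

# perf_counter() is monotonic: successive readings never decrease,
# which makes it safe for measuring intervals between events.
t0 = time.perf_counter()
time.sleep(0.05)
elapsed = time.perf_counter() - t0
assert elapsed >= 0.05  # sleep() guarantees at least the requested duration
print(f"elapsed: {elapsed * 1000:.1f} ms")
```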
sglang/srt/managers/tp_worker.py

@@ -174,6 +174,20 @@ class TpModelWorker:
             self.model_runner.token_to_kv_pool.size,
         )

+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.model_runner.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.model_runner.is_hybrid is not None
+
+    def get_tokens_per_layer_info(self):
+        return (
+            self.model_runner.full_max_total_num_tokens,
+            self.model_runner.swa_max_total_num_tokens,
+        )
+
     def get_pad_input_ids_func(self):
         return getattr(self.model_runner.model, "pad_input_ids", None)
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -102,6 +102,17 @@ class TpModelWorkerClient:
     def get_worker_info(self):
         return self.worker.get_worker_info()

+    def get_tokens_per_layer_info(self):
+        return self.worker.get_tokens_per_layer_info()
+
+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.worker.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.worker.is_hybrid
+
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()
sglang/srt/mem_cache/allocator.py

@@ -57,11 +57,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
     def debug_print(self) -> str:
         return ""

-    def log_usage(self, evictable_size: int = 0):
-        num_used = self.size - (self.available_size() + evictable_size)
-        msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
-        return msg, num_used
-
     def available_size(self):
         return len(self.free_pages) * self.page_size
@@ -190,7 +185,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping

     def available_size(self):
-        return min(self.full_available_size(), self.swa_available_size())
+        raise NotImplementedError()

     def full_available_size(self):
         return self.full_attn_allocator.available_size()
@@ -214,16 +209,6 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         return msg

-    def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
-        used_full = self.size_full - (self.full_available_size() + full_evictable_size)
-        used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
-        msg = (
-            f"#token: full={used_full}, swa={used_swa}, "
-            f"token usage: full={used_full / self.size_full:.2f}, "
-            f"swa={used_swa / self.size_swa:.2f}, "
-        )
-        return msg, used_full
-
     def get_kvcache(self):
         return self._kvcache
@@ -541,6 +526,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self.is_not_in_free_group = True
         self.free_group = []

+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+

 def alloc_extend_kernel_ascend(
     prefix_lens,
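
The two new pass-through methods let the paged allocator participate in host offloading by delegating to its underlying KV cache. A toy stand-in showing the round trip this interface enables; `_ToyKVCache` below is illustrative, not sglang's real pool:

```python
import torch

class _ToyKVCache:
    """Illustrative stand-in for the real KV cache pool."""

    def __init__(self, num_slots: int, dim: int = 8):
        self.buf = torch.randn(num_slots, dim)

    def get_cpu_copy(self, indices):
        # Snapshot the selected KV slots into host memory.
        return self.buf[indices].clone()

    def load_cpu_copy(self, kv_cache_cpu, indices):
        # Restore a host snapshot back into the device slots.
        self.buf[indices] = kv_cache_cpu

cache = _ToyKVCache(128)
idx = torch.arange(16)
snapshot = cache.get_cpu_copy(idx)   # offload to host
cache.buf[idx] = 0                   # slots get reused in the meantime
cache.load_cpu_copy(snapshot, idx)   # bring the pages back
assert torch.equal(cache.buf[idx], snapshot)
```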
sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple

 import torch
@@ -56,15 +56,27 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def dec_lock_ref(self, node: Any):
+    def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         pass

     def evictable_size(self):
         return 0

+    def full_evictable_size(self):
+        return 0
+
+    def swa_evictable_size(self):
+        return 0
+
     def protected_size(self):
         return 0

+    def full_protected_size(self):
+        return 0
+
+    def swa_protected_size(self):
+        return 0
+
     def total_size(self):
         raise NotImplementedError()
sglang/srt/mem_cache/chunk_cache.py

@@ -61,7 +61,7 @@ class ChunkCache(BasePrefixCache):
     def inc_lock_ref(self, node: Any):
         return 0

-    def dec_lock_ref(self, node: Any):
+    def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         return 0

     def pretty_print(self):
@@ -80,7 +80,7 @@ class SWAChunkCache(ChunkCache):
         super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
         assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)

-    def evict(
+    def evict_swa(
         self,
         req: Req,
         prelen: int,
@@ -95,3 +95,6 @@ class SWAChunkCache(ChunkCache):
         ]
         self.token_to_kv_pool_allocator.free_swa(free_slots)
         req.evicted_seqlen_local = new_evicted_seqlen_local
+
+    def evict(self, num_tokens: int):
+        pass
sglang/srt/mem_cache/hicache_storage.py (new file)

@@ -0,0 +1,152 @@
+import hashlib
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
+    hasher = hashlib.sha256()
+
+    if prior_hash:
+        hasher.update(bytes.fromhex(prior_hash))
+
+    for t in token_ids:
+        hasher.update(t.to_bytes(4, byteorder="little", signed=False))
+
+    return hasher.hexdigest()
+
+
+class HiCacheStorage(ABC):
+    """
+    HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
+    It abstracts the underlying storage mechanism, allowing different implementations to be used.
+    """
+
+    # todo, translate tensor object access for different TP ranks
+    # potentially pass model and TP configs into storage backend
+    # todo, the page size of the storage backend does not have to be the same as the host memory pool
+
+    @abstractmethod
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        """
+        Retrieve the value associated with the given key.
+        Returns None if the key does not exist.
+        """
+        pass
+
+    @abstractmethod
+    def batch_get(
+        self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
+    ) -> List[torch.Tensor | None]:
+        """
+        Retrieve values for multiple keys.
+        Returns a list of tensors or None for each key.
+        """
+        pass
+
+    @abstractmethod
+    def set(self, key, value) -> bool:
+        """
+        Store the value associated with the given key.
+        Returns True if the operation was successful, False otherwise.
+        """
+        pass
+
+    @abstractmethod
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        """
+        Store multiple key-value pairs.
+        Returns True if all operations were successful, False otherwise.
+        """
+        pass
+
+    @abstractmethod
+    def exists(self, key: str) -> bool:
+        """
+        Check if the key exists in the storage.
+        Returns True if the key exists, False otherwise.
+        """
+        pass
+
+
+class HiCacheFile(HiCacheStorage):
+
+    def __init__(self, file_path: str = "/tmp/hicache"):
+        self.file_path = file_path
+        if not os.path.exists(self.file_path):
+            os.makedirs(self.file_path)
+            logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
+
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        try:
+            # todo: fixing the target_location logic to enable in-place loading
+            loaded_tensor = torch.load(tensor_path)
+            if isinstance(loaded_tensor, torch.Tensor):
+                return loaded_tensor
+            else:
+                logger.error(f"Loaded data for key {key} is not a tensor.")
+                return None
+        except FileNotFoundError:
+            return None
+
+    def batch_get(
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor]] = None,
+    ) -> List[torch.Tensor | None]:
+        return [
+            self.get(key, target_location)
+            for key, target_location in zip(
+                keys, target_locations or [None] * len(keys)
+            )
+        ]
+
+    def set(self, key: str, value: torch.Tensor) -> bool:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        if self.exists(key):
+            logger.debug(f"Key {key} already exists. Skipped.")
+            return True
+        try:
+            torch.save(value, tensor_path)
+            return True
+        except Exception as e:
+            logger.error(f"Failed to save tensor {key}: {e}")
+            return False
+
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        for key, value in zip(keys, values):
+            if not self.set(key, value):
+                return False
+        return True
+
+    def exists(self, key: str) -> bool:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        return os.path.exists(tensor_path)
+
+    def delete(self, key: str) -> None:
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        try:
+            os.remove(tensor_path)
+        except FileNotFoundError:
+            logger.warning(f"Key {key} does not exist. Cannot delete.")
+            return
+
+    def clear(self) -> None:
+        try:
+            for filename in os.listdir(self.file_path):
+                file_path = os.path.join(self.file_path, filename)
+                if os.path.isfile(file_path):
+                    os.remove(file_path)
+            logger.info("Cleared all entries in HiCacheFile storage.")
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheFile storage: {e}")
sglang/srt/mem_cache/hiradix_cache.py

@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_storage_backend: Optional[str] = None,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -49,6 +50,9 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
+        self.enable_storage = hicache_storage_backend is not None
+        # todo: customizable storage prefetch threshold
+        self.prefetch_threshold = 256

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
@@ -58,16 +62,22 @@
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
+            storage_backend=hicache_storage_backend,
+            prefetch_threshold=self.prefetch_threshold,
         )

         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
+        # record the ongoing prefetch requests
+        self.ongoing_prefetch = {}
+        self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
+        self.write_through_threshold_storage = 3
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -108,13 +118,30 @@

         return len(host_indices)

+    def write_backup_storage(self, node: TreeNode):
+        operation_id = self.cache_controller.write_storage(
+            node.host_value, node.key, node.parent.get_last_hash_value()
+        )
+        self.ongoing_backup[operation_id] = node
+        node.protect_host()
+
     def inc_hit_count(self, node: TreeNode):
-        if node.backuped or self.cache_controller.write_policy == "write_back":
+        if self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.hit_count >= self.write_through_threshold:
-            self.write_backup(node)
-            node.hit_count = 0
+
+        if not node.backuped:
+            if node.hit_count >= self.write_through_threshold:
+                # write to host if the node is not backuped
+                self.write_backup(node)
+        else:
+            if (
+                self.enable_storage
+                and (not node.backuped_storage)
+                and node.hit_count >= self.write_through_threshold_storage
+            ):
+                # if the node is backuped on host memory but not on storage
+                self.write_backup_storage(node)

     def writing_check(self, write_back=False):
         if write_back:
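
`inc_hit_count` now implements a two-tier write-through policy: a hot node is first backed up to host memory, and only once it is host-resident (and a storage backend is enabled) do further hits trigger a backup to storage. A toy model of the escalation, using the non-write-through default threshold of 3 at each tier from this diff; the names are stand-ins for the real tree nodes:

```python
WRITE_THROUGH_THRESHOLD = 3          # tier 1: device -> host memory
WRITE_THROUGH_THRESHOLD_STORAGE = 3  # tier 2: host memory -> storage backend

class ToyNode:
    def __init__(self):
        self.hit_count = 0
        self.backuped = False          # host-memory backup exists
        self.backuped_storage = False  # storage-backend backup exists

def inc_hit_count(node: ToyNode, enable_storage: bool = True):
    node.hit_count += 1
    if not node.backuped:
        if node.hit_count >= WRITE_THROUGH_THRESHOLD:
            node.backuped = True          # stands in for write_backup()
    elif enable_storage and not node.backuped_storage:
        if node.hit_count >= WRITE_THROUGH_THRESHOLD_STORAGE:
            node.backuped_storage = True  # stands in for write_backup_storage()

node = ToyNode()
for hit in range(1, 5):
    inc_hit_count(node)
    print(hit, node.backuped, node.backuped_storage)
# hit 3 promotes the node to host; hit 4 promotes it to storage
```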
@@ -221,6 +248,10 @@
             if not x.evicted:
                 continue

+            # node is protected from eviction as it has ongoing prefetch or backup to storage
+            if x.host_ref_counter > 0:
+                continue
+
             num_evicted += self.cache_controller.evict_host(x.host_value)

             for k, v in x.parent.children.items():
@@ -314,6 +345,85 @@
     def check_hicache_events(self):
         self.writing_check()
         self.loading_check()
+        if self.enable_storage:
+            self.check_revoked_prefetch()
+            self.check_backup_progress()
+
+    def check_revoked_prefetch(self):
+        queue_size = torch.tensor(
+            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            req_id = self.cache_controller.prefetch_revoke_queue.get()
+            if req_id in self.ongoing_prefetch:
+                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node.release_host()
+                self.cache_controller.mem_pool_host.free(host_indices)
+                del self.ongoing_prefetch[req_id]
+
+    def check_backup_progress(self):
+        queue_size = torch.tensor(
+            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
+            self.ongoing_backup[ack_id].hash_value = hash_value
+            self.ongoing_backup[ack_id].release_host()
+            del self.ongoing_backup[ack_id]
+
+    def check_prefetch_progress(self, req_id: str):
+        if req_id not in self.ongoing_prefetch:
+            # there is no ongoing prefetch for this request or it has been revoked
+            return
+
+        # todo: more policies for prefetch progress such as timeout
+        # the current policy is to prefetch with best effort and terminate when queuing is over
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+            req_id
+        ]
+        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
+            operation
+        )
+        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
+
+        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                min_completed_tokens,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        min_completed_tokens = min_completed_tokens.item()
+        fetched_token_ids = token_ids[:min_completed_tokens]
+        written_indices = host_indices[:min_completed_tokens]
+        matched_length = self._insert_helper_host(
+            last_host_node,
+            fetched_token_ids,
+            written_indices,
+            hash_value[:min_completed_tokens],
+        )
+
+        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
+        self.cache_controller.mem_pool_host.free(
+            host_indices[min_completed_tokens:completed_tokens]
+        )
+        last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]

     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
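
All three checks share the same synchronization idiom: each TP rank reduces its local progress counter with `ReduceOp.MIN`, so every rank applies the identical, most conservative update to its copy of the radix tree. A self-contained sketch of the pattern (single-process `gloo` group, purely illustrative):

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# e.g. number of tokens whose prefetch completed on this rank
local_progress = torch.tensor(42, dtype=torch.int)
if dist.get_world_size() > 1:
    # every rank ends up with the minimum value across all ranks
    dist.all_reduce(local_progress, op=dist.ReduceOp.MIN)
completed = local_progress.item()
print(completed)

dist.destroy_process_group()
```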
@@ -348,6 +458,71 @@
             host_hit_length=host_hit_length,
         )

+    def prefetch_from_storage(
+        self,
+        req_id: str,
+        last_host_node: TreeNode,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+            return
+
+        last_host_node.protect_host()
+        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        if host_indices is None:
+            self.evict_host(len(new_input_tokens))
+            host_indices = self.cache_controller.mem_pool_host.alloc(
+                len(new_input_tokens)
+            )
+        if host_indices is None:
+            last_host_node.release_host()
+            # not enough host memory to prefetch
+            return
+        operation = self.cache_controller.prefetch(
+            req_id, host_indices, new_input_tokens, last_hash
+        )
+        self.ongoing_prefetch[req_id] = (
+            last_host_node,
+            new_input_tokens,
+            host_indices,
+            operation,
+        )
+
+    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
+        node.last_access_time = time.monotonic()
+        if len(key) == 0:
+            return 0
+
+        child_key = self.get_child_key_fn(key)
+
+        matched_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.monotonic()
+            prefix_len = self.key_match_fn(node.key, key)
+            key = key[prefix_len:]
+            host_value = host_value[prefix_len:]
+            hash_value = hash_value[prefix_len:]
+            matched_length += prefix_len
+
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = None
+            new_node.host_value = host_value
+            new_node.hash_value = hash_value
+            node.children[child_key] = new_node
+        return matched_length
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)
sglang/srt/mem_cache/memory_pool.py

@@ -520,8 +520,13 @@ class SWAKVPool(KVCache):
             self.layers_mapping[global_layer_id] = (swa_layer_id, True)
         self.full_to_swa_index_mapping: Optional[torch.Tensor] = None

+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
     def get_kv_size_bytes(self):
-        raise NotImplementedError
+        k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
+        k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes()
+        return k_size + k_size_swa, v_size + v_size_swa

     def get_contiguous_buf_infos(self):
         full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
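
With `get_kv_size_bytes` now summing the full-attention and sliding-window pools, `mem_usage` reflects the hybrid pool's combined footprint. A back-of-envelope check, assuming `GB = 2**30` as the module's constant (an assumption here) and hypothetical byte counts:

```python
GB = 1 << 30  # assumed definition of the module's GB constant

full_k, full_v = 6 * GB, 6 * GB  # hypothetical full-attention K/V bytes
swa_k, swa_v = 2 * GB, 2 * GB    # hypothetical sliding-window K/V bytes

k_size, v_size = full_k + swa_k, full_v + swa_v
mem_usage = (k_size + v_size) / GB
print(mem_usage)  # 16.0 GB total KV footprint
```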
@@ -597,6 +602,16 @@
             layer_id_override=layer_id_pool,
         )

+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError("HiCache not supported for SWAKVPool.")
+
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError("HiCache not supported for SWAKVPool.")
+

 class AscendTokenToKVPool(MHATokenToKVPool):