sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/lora_radix_cache.py (new file)
@@ -0,0 +1,421 @@
+ """Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes."""
+
+ import heapq
+ import time
+ from collections import defaultdict
+ from typing import TYPE_CHECKING, Any, List, Optional
+
+ import torch
+
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
+ if TYPE_CHECKING:
+     from sglang.srt.managers.schedule_batch import Req
+ else:
+     Req = Any  # Placeholder for Req type when not type checking
+
+
+ class LoRAKey:
+
+     def __init__(self, lora_id: str, token_ids: List[int]):
+         self.lora_id = (
+             lora_id  # lora_id of adaptor, should be hash value of adaptor path
+         )
+         self.token_ids = token_ids  # token_ids of the key
+
+     def __len__(self):
+         return len(self.token_ids)
+
+
+ def get_child_key(key: LoRAKey):
+     # Here the key of children dict is the hash of lora_id + str(token_ids[0])
+     # So the child key can be matched only when lora_id and token_ids[0] are the same
+     if key.lora_id is None:
+         return hash(str(key.token_ids[0]))
+     else:
+         return hash(key.lora_id + str(key.token_ids[0]))
+
+
+ class LoRATreeNode:
+
+     counter = 0
+
+     def __init__(self, id: Optional[int] = None):
+         self.children = defaultdict(LoRATreeNode)
+         self.parent: LoRATreeNode = None
+         self.key: LoRAKey = None
+         self.value: Optional[torch.Tensor] = None
+         self.lock_ref = 0
+         self.last_access_time = time.monotonic()
+
+         self.id = LoRATreeNode.counter if id is None else id
+         LoRATreeNode.counter += 1
+
+     @property
+     def evicted(self):
+         return self.value is None
+
+     def __lt__(self, other: "LoRATreeNode"):
+         return self.last_access_time < other.last_access_time
+
+
+ def _key_match(key0: LoRAKey, key1: LoRAKey):
+     if key0.lora_id != key1.lora_id:
+         raise ValueError(
+             f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}"
+         )
+     i = 0
+     for k0, k1 in zip(key0.token_ids, key1.token_ids):
+         if k0 != k1:
+             break
+         i += 1
+     return i
+
+
+ class LoRARadixCache(BasePrefixCache):
+
+     def __init__(
+         self,
+         req_to_token_pool: ReqToTokenPool,
+         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+         page_size: int,
+         disable: bool = False,
+     ):
+         if page_size > 1:
+             raise ValueError("LoRARadixCache currently only supports page_size = 1")
+
+         if token_to_kv_pool_allocator is None:
+             raise ValueError(
+                 "token_to_kv_pool_allocator is required to run LoraRadixCache"
+             )
+
+         self.req_to_token_pool = req_to_token_pool
+         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+         self.page_size = page_size
+         self.disable = disable
+         self.device = self.token_to_kv_pool_allocator.device
+
+         self.key_match_fn = _key_match
+         self.get_child_key_fn = get_child_key
+         self.reset()
+
+     def reset(self):
+         self.root_node = LoRATreeNode()
+         self.root_node.key = LoRAKey(lora_id="", token_ids=[])
+         self.root_node.value = None
+         self.evictable_size_ = 0
+         self.protected_size_ = 0
+
+     def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
+         raise ValueError(
+             "LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead."
+         )
+
+     def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult:
+         """Find the matching prefix from the lora radix tree.
+         Args:
+             key: A LoRAKey to find a matching prefix.
+         Returns:
+             A tuple of a tensor of matching prefix token IDs and
+             the last node that contains the prefix values. Note that
+             this API can modify the internal state of the Radix tree.
+             The last node create a new child if the prefix is shorter
+             than the last node's value.
+         """
+         if self.disable or len(key) == 0:
+             return MatchResult(
+                 device_indices=torch.empty(
+                     (0,),
+                     dtype=torch.int64,
+                     device=self.device,
+                 ),
+                 last_device_node=self.root_node,
+                 last_host_node=self.root_node,
+             )
+
+         value, last_node = self._match_prefix_helper(self.root_node, key)
+         if value:
+             value = torch.cat(value)
+         else:
+             value = torch.empty((0,), dtype=torch.int64, device=self.device)
+         return MatchResult(
+             device_indices=value,
+             last_device_node=last_node,
+             last_host_node=last_node,
+         )
+
+     def insert(self, key: LoRAKey, value=None):
+         if self.disable:
+             return 0
+
+         if value is None:
+             value = [x for x in key.token_ids]
+         return self._insert_helper(self.root_node, key, value)
+
+     def cache_finished_req(self, req: Req):
+         """Cache request when it finishes."""
+         if self.disable:
+             kv_indices = self.req_to_token_pool.req_to_token[
+                 req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
+             ]
+             self.token_to_kv_pool_allocator.free(kv_indices)
+             self.req_to_token_pool.free(req.req_pool_idx)
+             return
+
+         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+         kv_indices = self.req_to_token_pool.req_to_token[
+             req.req_pool_idx, : len(token_ids)
+         ]
+
+         page_aligned_len = len(kv_indices)
+         page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
+
+         # Radix Cache takes one ref in memory pool
+         lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len])
+         new_prefix_len = self.insert(lora_key, page_aligned_kv_indices)
+         self.token_to_kv_pool_allocator.free(
+             kv_indices[len(req.prefix_indices) : new_prefix_len]
+         )
+
+         # Remove req slot release the cache lock
+         self.req_to_token_pool.free(req.req_pool_idx)
+         self.dec_lock_ref(req.last_node)
+
+     def cache_unfinished_req(self, req: Req):
+         """Cache request when it is unfinished."""
+         if self.disable:
+             return
+
+         token_ids = req.fill_ids
+         kv_indices = self.req_to_token_pool.req_to_token[
+             req.req_pool_idx, : len(token_ids)
+         ]
+
+         page_aligned_len = len(kv_indices)
+         page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
+         page_aligned_token_ids = token_ids[:page_aligned_len]
+
+         # Radix Cache takes one ref in memory pool
+         inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids)
+         new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices)
+         self.token_to_kv_pool_allocator.free(
+             kv_indices[len(req.prefix_indices) : new_prefix_len]
+         )
+
+         # The prefix indices could be updated, reuse it
+         new_indices, new_last_node, _, _ = self.match_prefix_with_lora_id(inserted_key)
+         self.req_to_token_pool.write(
+             (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+             new_indices[len(req.prefix_indices) :],
+         )
+
+         self.dec_lock_ref(req.last_node)
+         self.inc_lock_ref(new_last_node)
+
+         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
+         req.prefix_indices = new_indices
+         req.last_node = new_last_node
+
+     def pretty_print(self):
+         self._print_helper(self.root_node, 0)
+         print(f"#tokens: {self.total_size()}")
+
+     def total_size(self):
+         return self._total_size_helper()
+
+     def evict(self, num_tokens: int):
+         if self.disable:
+             return
+
+         leaves = self._collect_leaves()
+         heapq.heapify(leaves)
+
+         num_evicted = 0
+         while num_evicted < num_tokens and len(leaves):
+             x = heapq.heappop(leaves)
+
+             if x == self.root_node:
+                 break
+             if x.lock_ref > 0:
+                 continue
+
+             self.token_to_kv_pool_allocator.free(x.value)
+             num_evicted += len(x.value)
+             self._delete_leaf(x)
+
+             if len(x.parent.children) == 0:
+                 heapq.heappush(leaves, x.parent)
+
+     def inc_lock_ref(self, node: LoRATreeNode):
+         if self.disable:
+             return 0
+
+         delta = 0
+         while node != self.root_node:
+             if node.lock_ref == 0:
+                 self.evictable_size_ -= len(node.value)
+                 self.protected_size_ += len(node.value)
+                 delta -= len(node.value)
+             node.lock_ref += 1
+             node = node.parent
+         return delta
+
+     def dec_lock_ref(self, node: LoRATreeNode):
+         if self.disable:
+             return 0
+
+         delta = 0
+         while node != self.root_node:
+             if node.lock_ref == 1:
+                 self.evictable_size_ += len(node.value)
+                 self.protected_size_ -= len(node.value)
+                 delta += len(node.value)
+             node.lock_ref -= 1
+             node = node.parent
+         return delta
+
+     def evictable_size(self):
+         return self.evictable_size_
+
+     def protected_size(self):
+         # protected size refers to the size of the cache that is locked
+         return self.protected_size_
+
+     def all_values_flatten(self):
+         values = []
+
+         def _dfs_helper(node: LoRATreeNode):
+             for _, child in node.children.items():
+                 values.append(child.value)
+                 _dfs_helper(child)
+
+         _dfs_helper(self.root_node)
+         return torch.cat(values)
+
+     ##### Internal Helper Functions #####
+
+     def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey):
+         node.last_access_time = time.monotonic()
+
+         child_key = self.get_child_key_fn(key)
+
+         value = []
+         while len(key) > 0 and child_key in node.children.keys():
+             child = node.children[child_key]
+             child.last_access_time = time.monotonic()
+             prefix_len = self.key_match_fn(child.key, key)
+             if prefix_len < len(child.key):
+                 new_node = self._split_node(child.key, child, prefix_len)
+                 value.append(new_node.value)
+                 node = new_node
+                 break
+             else:
+                 value.append(child.value)
+                 node = child
+                 key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
+
+                 if len(key):
+                     child_key = self.get_child_key_fn(key)
+
+         return value, node
+
+     def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int):
+         # new_node -> child
+         new_node = LoRATreeNode()
+         key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len])
+         key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:])
+         new_node.children = {self.get_child_key_fn(key_split_2): child}
+         new_node.parent = child.parent
+         new_node.lock_ref = child.lock_ref
+         new_node.key = key_split_1
+         new_node.value = child.value[:split_len]
+         child.parent = new_node
+         child.key = key_split_2
+         child.value = child.value[split_len:]
+         new_node.parent.children[self.get_child_key_fn(key)] = new_node
+
+         return new_node
+
+     def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value):
+         node.last_access_time = time.monotonic()
+         if len(key) == 0:
+             return 0
+
+         child_key = self.get_child_key_fn(key)
+
+         total_prefix_length = 0
+         while len(key) > 0 and child_key in node.children.keys():
+             node = node.children[child_key]
+             node.last_access_time = time.monotonic()
+             prefix_len = self.key_match_fn(node.key, key)
+             total_prefix_length += prefix_len
+             key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
+             value = value[prefix_len:]
+
+             if prefix_len < len(node.key):
+                 new_node = self._split_node(node.key, node, prefix_len)
+                 node = new_node
+
+             if len(key):
+                 child_key = self.get_child_key_fn(key)
+
+         if len(key):
+             new_node = LoRATreeNode()
+             new_node.parent = node
+             new_node.key = key
+             new_node.value = value
+             node.children[child_key] = new_node
+             self.evictable_size_ += len(value)
+         return total_prefix_length
+
+     def _print_helper(self, node: LoRATreeNode, indent: int):
+         """Prints the radix tree in a human-readable format."""
+         stack = [(node, indent)]
+         while stack:
+             current_node, current_indent = stack.pop()
+             print(
+                 " " * current_indent,
+                 len(current_node.key),
+                 current_node.key.token_ids[:10],
+                 f"r={current_node.lock_ref}",
+             )
+             for key, child in current_node.children.items():
+                 stack.append((child, current_indent + 2))
+
+                 assert key == self.get_child_key_fn(
+                     child.key
+                 ), f"{key=}, {self.get_child_key_fn(child.key)=}"
+
+     def _delete_leaf(self, node):
+         for k, v in node.parent.children.items():
+             if v == node:
+                 break
+         del node.parent.children[k]
+         self.evictable_size_ -= len(node.key)
+
+     def _total_size_helper(self):
+         total_size = 0
+         stack = [self.root_node]
+         while stack:
+             current_node = stack.pop()
+             total_size += len(current_node.value)
+             for child in current_node.children.values():
+                 if child.evicted:
+                     continue
+                 stack.append(child)
+         return total_size
+
+     def _collect_leaves(self):
+         ret_list = []
+         stack = [self.root_node]
+
+         while stack:
+             cur_node = stack.pop()
+             if len(cur_node.children) == 0:
+                 ret_list.append(cur_node)
+             else:
+                 stack.extend(cur_node.children.values())
+
+         return ret_list
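Aside on the new file above: the only structural difference from RadixCache is that node keys carry a lora_id, so a prefix cached under one adapter is never returned to a request running a different adapter (or the base model). A minimal standalone sketch of that keying rule follows; it mirrors get_child_key, but the adapter names and token ids are made up for illustration and nothing here touches the real KV pool.

    # Illustration only: hypothetical adapter names and token ids.
    from typing import List, Optional

    def child_key(lora_id: Optional[str], token_ids: List[int]) -> int:
        # Same rule as get_child_key: first token alone keys the base model,
        # lora_id + first token keys an adapter-specific subtree.
        if lora_id is None:
            return hash(str(token_ids[0]))
        return hash(lora_id + str(token_ids[0]))

    prompt = [101, 2023, 2003]  # identical token prefix for all three requests
    print(child_key(None, prompt))         # base model
    print(child_key("adapter-a", prompt))  # LoRA adapter A
    print(child_key("adapter-b", prompt))  # LoRA adapter B
    # The three child keys differ, so identical prompts land in separate
    # subtrees and cached KV entries are never shared across adapters.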
sglang/srt/mem_cache/memory_pool_host.py
@@ -358,6 +358,7 @@ class MHATokenToKVPoolHost(HostKVCache):
      dst_v=device_pool.v_buffer[layer_id],
      src_indices=host_indices,
      dst_indices=device_indices,
+     layer_id=layer_id,
      item_size=self.token_stride_size,
      src_layout_dim=self.layout_dim,
  )
@@ -471,27 +472,26 @@ class MHATokenToKVPoolHost(HostKVCache):
              * self.dtype.itemsize
          )
          for index in range(0, len(indices), self.page_size):
-             for layer_id in range(self.layer_num):
-                 k_ptr = (
-                     kv_buffer_data_ptr
-                     + indices[index]
-                     * self.head_num
-                     * self.head_dim
-                     * self.dtype.itemsize
-                     + layer_id
-                     * self.size
-                     * self.head_num
-                     * self.head_dim
-                     * self.dtype.itemsize
-                 )
-                 v_ptr = k_ptr + v_offset
-                 ptr_list.append(k_ptr)
-                 ptr_list.append(v_ptr)
-                 key_ = keys[index // self.page_size]
-                 key_list.append(f"{key_}_{layer_id}_k")
-                 key_list.append(f"{key_}_{layer_id}_v")
+             k_ptr = (
+                 kv_buffer_data_ptr
+                 + indices[index]
+                 * self.layer_num
+                 * self.head_num
+                 * self.head_dim
+                 * self.dtype.itemsize
+             )
+             v_ptr = k_ptr + v_offset
+             ptr_list.append(k_ptr)
+             ptr_list.append(v_ptr)
+             key_ = keys[index // self.page_size]
+             key_list.append(f"{key_}_k")
+             key_list.append(f"{key_}_v")
          element_size = (
-             self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+             self.layer_num
+             * self.dtype.itemsize
+             * self.page_size
+             * self.head_num
+             * self.head_dim
          )
          element_size_list = [element_size] * len(key_list)
          return key_list, ptr_list, element_size_list
@@ -585,6 +585,7 @@ class MLATokenToKVPoolHost(HostKVCache):
      dst=device_pool.kv_buffer[layer_id],
      src_indices=host_indices,
      dst_indices=device_indices,
+     layer_id=layer_id,
      item_size=self.token_stride_size,
      src_layout_dim=self.layout_dim,
  )
@@ -618,7 +619,7 @@ class MLATokenToKVPoolHost(HostKVCache):
          elif self.layout == "page_first":
              transfer_kv_all_layer_mla_lf_pf(
                  src_layers=device_pool.data_ptrs,
-                 dst_k=self.kv_buffer,
+                 dst=self.kv_buffer,
                  src_indices=device_indices,
                  dst_indices=host_indices,
                  item_size=self.token_stride_size,
@@ -685,22 +686,19 @@ class MLATokenToKVPoolHost(HostKVCache):
          key_list = []
          kv_buffer_data_ptr = self.kv_buffer.data_ptr()
          for index in range(0, len(indices), self.page_size):
-             for layer_id in range(self.layer_num):
-                 k_ptr = (
-                     kv_buffer_data_ptr
-                     + indices[index]
-                     * (self.kv_lora_rank + self.qk_rope_head_dim)
-                     * self.dtype.itemsize
-                     + layer_id
-                     * self.size
-                     * (self.kv_lora_rank + self.qk_rope_head_dim)
-                     * self.dtype.itemsize
-                 )
-                 ptr_list.append(k_ptr)
-                 key_ = keys[index // self.page_size]
-                 key_list.append(f"{key_}_{layer_id}_k")
+             k_ptr = (
+                 kv_buffer_data_ptr
+                 + indices[index]
+                 * self.layer_num
+                 * (self.kv_lora_rank + self.qk_rope_head_dim)
+                 * self.dtype.itemsize
+             )
+             ptr_list.append(k_ptr)
+             key_ = keys[index // self.page_size]
+             key_list.append(f"{key_}_k")
          element_size = (
-             self.dtype.itemsize
+             self.layer_num
+             * self.dtype.itemsize
              * self.page_size
              * (self.kv_lora_rank + self.qk_rope_head_dim)
          )
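The two get_buffer_meta hunks above move the host buffer addressing from one entry per (page, layer) to one entry per page: in the page-first host layout all layers of a token are contiguous, so the page offset is simply index * layer_num * (per-token width) * itemsize, the per-entry element size grows by a factor of layer_num, and the number of keys shrinks by the same factor. A quick arithmetic check of that trade-off for the MHA case, using illustrative shapes rather than any real model config:

    # Hypothetical shapes, for illustration only.
    layer_num, head_num, head_dim = 32, 8, 128
    page_size, itemsize = 1, 2  # one token per page, fp16/bf16

    # Old scheme: one key per (page, layer); each element covers one layer.
    old_element_size = itemsize * page_size * head_num * head_dim
    old_keys_per_page = 2 * layer_num  # "<key>_<layer_id>_k" and "_v"

    # New scheme: one key per page; each element covers all layers at once.
    new_element_size = layer_num * itemsize * page_size * head_num * head_dim
    new_keys_per_page = 2  # "<key>_k" and "<key>_v"

    assert new_element_size == old_element_size * layer_num
    print(old_keys_per_page, new_keys_per_page, new_element_size)  # 64 2 65536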
sglang/srt/mem_cache/multimodal_cache.py
@@ -1,24 +1,46 @@
+ import logging
+ from collections import OrderedDict
  from typing import Dict

  import torch

+ # Set up logging for cache behavior
+ logger = logging.getLogger(__name__)
+

  class MultiModalCache:
-     """MultiModalCache is used to store vlm encoder results"""
+     """MultiModalCache is used to store vlm encoder results with LRU eviction"""

      def __init__(
          self,
          max_size: int,
      ):
          self.max_size = max_size
-         self.mm_cache: Dict[int, torch.Tensor] = {}
+         self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
          self.current_size = 0

+     def _allocate(self, embedding_size: int) -> bool:
+         """Allocate space by evicting least recently used entries"""
+         evictions = 0
+         while self.current_size + embedding_size > self.max_size and self.mm_cache:
+             _, old_embedding = self.mm_cache.popitem(last=False)
+             evicted_size = self._get_tensor_size(old_embedding)
+             self.current_size -= evicted_size
+             evictions += evicted_size
+
+         if evictions > 0:
+             logger.debug(
+                 f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
+             )
+
+         if self.current_size + embedding_size > self.max_size:
+             return False
+         return True
+
      def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
-         if mm_hash in self.mm_cache:
-             return True
          data_size = self._get_tensor_size(embedding)
-         if self.current_size + data_size > self.max_size:
+         # Lazy free cache if not enough space
+         if not self._allocate(data_size):
              return False
          self.mm_cache[mm_hash] = embedding
          self.current_size += data_size
@@ -28,14 +50,12 @@ class MultiModalCache:
          return mm_hash in self.mm_cache

      def get(self, mm_hash: int) -> torch.Tensor:
-         return self.mm_cache.get(mm_hash)
-
-     def free(self, mm_hash: int) -> bool:
-         if mm_hash not in self.mm_cache:
-             return False
-         old_embedding = self.mm_cache.pop(mm_hash)
-         self.current_size -= self._get_tensor_size(old_embedding)
-         return True
+         """Get embedding and update LRU order"""
+         if mm_hash in self.mm_cache:
+             # Move to end (most recently used)
+             self.mm_cache.move_to_end(mm_hash)
+             return self.mm_cache[mm_hash]
+         return None

      def clear(self):
          self.mm_cache.clear()
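With these two hunks MultiModalCache becomes a byte-budgeted LRU: get() moves a hit to the back of the OrderedDict, and _allocate() pops victims from the front until the incoming embedding fits (or gives up and put() returns False). A toy sketch of the same OrderedDict pattern, independent of the sglang classes and using made-up integer sizes instead of tensors:

    from collections import OrderedDict

    cache: OrderedDict[int, int] = OrderedDict()  # hash -> pretend byte size
    max_size, current = 10, 0

    def put(h: int, size: int) -> bool:
        global current
        # Evict least-recently-used entries (front of the dict) until it fits.
        while current + size > max_size and cache:
            _, old = cache.popitem(last=False)
            current -= old
        if current + size > max_size:
            return False
        cache[h] = size
        current += size
        return True

    def get(h: int):
        if h in cache:
            cache.move_to_end(h)  # mark as most recently used
            return cache[h]
        return None

    put(1, 4); put(2, 4); get(1); put(3, 4)
    print(list(cache))  # [1, 3] -- entry 2 was the least recently used victim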
sglang/srt/mem_cache/radix_cache.py
@@ -62,6 +62,7 @@ class TreeNode:
          self.host_value: Optional[torch.Tensor] = None
          # store hash values of each pages
          self.hash_value: Optional[List[str]] = None
+         self.backuped_storage = False

          self.id = TreeNode.counter if id is None else id
          TreeNode.counter += 1
@@ -74,10 +75,6 @@ class TreeNode:
      def backuped(self):
          return self.host_value is not None

-     @property
-     def backuped_storage(self):
-         return self.hash_value is not None and len(self.hash_value) > 0
-
      def protect_host(self):
          """Protect the host value from eviction."""
          self.host_ref_counter += 1
@@ -498,7 +495,7 @@ class RadixCache(BasePrefixCache):
          # One BlockStored per ``page_size`` chunk.
          if self.enable_kv_cache_events:
              # First chunk links to the last page of the parent node (if any).
-             if node.parent is None:
+             if node.parent is None or node != self.root_node:
                  parent_block_hash = None
              else:
                  last_page_start = (
sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
@@ -96,6 +96,8 @@ class Hf3fsClient:
          )
          self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
          self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+         self.shm_r.unlink()
+         self.shm_w.unlink()

          self.rlock = threading.RLock()
          self.wlock = threading.RLock()
@@ -176,8 +178,6 @@ class Hf3fsClient:
          del self.iov_w
          self.shm_r.close()
          self.shm_w.close()
-         self.shm_r.unlink()
-         self.shm_w.unlink()

      def flush(self) -> None:
          os.fsync(self.file)
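Moving shm_r.unlink() / shm_w.unlink() from close() to immediately after creation follows the usual POSIX shared-memory idiom: drop the name as soon as the mapping exists, keep using the segment through the open handle, and let the OS reclaim it once the last handle closes, so no stale segment is leaked if close() is never reached. A minimal sketch of that idiom with the standard library (independent of Hf3fsClient; the behavior described is the POSIX one):

    from multiprocessing import shared_memory

    shm = shared_memory.SharedMemory(create=True, size=4096)
    shm.unlink()  # name is gone immediately; nothing left behind after a crash

    shm.buf[:5] = b"hello"     # the mapping stays fully usable via this handle
    print(bytes(shm.buf[:5]))  # b'hello'
    shm.close()                # memory reclaimed once all handles are closed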