sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/swa_radix_cache.py (new file)
@@ -0,0 +1,1025 @@
+ from __future__ import annotations
+
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """
+ The radix tree data structure for managing the hybrid (full and SWA) KV cache.
+ """
+
+ import heapq
+ import time
+ from collections import defaultdict
+ from functools import partial
+ from typing import TYPE_CHECKING, List, Optional, Tuple
+
+ import torch
+
+ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
+ if TYPE_CHECKING:
+     from sglang.srt.managers.schedule_batch import Req
+
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class TreeNode:
+
+     counter = 0
+     swa_uuid_counter = 1
+
+     def __init__(self, id: Optional[int] = None):
+         self.children = defaultdict(TreeNode)
+         self.parent: TreeNode = None
+         self.key: List[int] = None
+         self.value: Optional[torch.Tensor] = None
+         # swa_tombstone is used to indicate the kv indices have been freed for swa layers
+         self.swa_tombstone = False
+         # invariant: for any node, if swa_lock_ref is locked, full_lock_ref must be locked;
+         # if full_lock_ref is locked, swa_lock_ref doesn't need to be locked. So,
+         # full_lock_ref is always >= swa_lock_ref.
+         self.full_lock_ref = 0
+         self.swa_lock_ref = 0
+         # last access time is only used for sanity check. LRU is maintained by the lru list.
+         self.last_access_time = time.monotonic()
+
+         self.hit_count = 0
+         # indicating the node is loading KV cache from host
+         self.loading = False
+         # store the host indices of KV cache
+         self.host_value = None
+
+         # for lru list, invariant:
+         # 1. prev has greater last_access_time
+         # 2. next has smaller last_access_time
+         self.prev = None
+         self.next = None
+         self.swa_prev = None
+         self.swa_next = None
+
+         self.id = TreeNode.counter if id is None else id
+         TreeNode.counter += 1
+         self.swa_uuid = None
+
+     @property
+     def evicted(self):
+         return self.value is None
+
+     @property
+     def backuped(self):
+         return self.host_value is not None
+
+     def __lt__(self, other: "TreeNode"):
+         return self.last_access_time < other.last_access_time
+
+
+ def _key_match_page_size1(key0: List, key1: List):
+     i = 0
+     for k0, k1 in zip(key0, key1):
+         if k0 != k1:
+             break
+         i += 1
+     return i
+
+
+ def _key_match_paged(key0: List, key1: List, page_size: int):
+     min_len = min(len(key0), len(key1))
+
+     i = 0
+     while i < min_len:
+         if key0[i : i + page_size] != key1[i : i + page_size]:
+             break
+         i += page_size
+
+     return i
+
+
+ def gen_swa_uuid() -> int:
+     TreeNode.swa_uuid_counter += 1
+     return TreeNode.swa_uuid_counter
+
+
+ class LRUList:
+     def __init__(self, swa: bool = False):
+         self.swa = swa
+         if self.swa:
+             self.prv = "swa_prev"
+             self.nxt = "swa_next"
+             self.lock_ref = "swa_lock_ref"
+         else:
+             self.prv = "prev"
+             self.nxt = "next"
+             self.lock_ref = "full_lock_ref"
+         # Initialize dummy head and tail nodes
+         self.head = TreeNode()  # Most recently used side
+         self.tail = TreeNode()  # Least recently used side
+         setattr(self.head, self.nxt, self.tail)  # self.head.next = self.tail
+         setattr(self.tail, self.prv, self.head)  # self.tail.prev = self.head
+         self.cache = {}
+
+     def _add_node(self, node):
+         """Helper to add node right after head (most recently used)"""
+         self._add_node_after(self.head, node)
+
+     def _add_node_after(self, old_node, new_node):
+         """Helper to add node right after old_node"""
+         setattr(new_node, self.prv, old_node)  # new_node.prev = old_node
+         setattr(
+             new_node, self.nxt, getattr(old_node, self.nxt)
+         )  # new_node.next = old_node.next
+         setattr(
+             getattr(old_node, self.nxt), self.prv, new_node
+         )  # old_node.next.prev = new_node
+         setattr(old_node, self.nxt, new_node)  # old_node.next = new_node
+
+     def _remove_node(self, node):
+         """Helper to remove node from linked list"""
+         setattr(
+             getattr(node, self.prv), self.nxt, getattr(node, self.nxt)
+         )  # node.prev.next = node.next
+         setattr(
+             getattr(node, self.nxt), self.prv, getattr(node, self.prv)
+         )  # node.next.prev = node.prev
+
+     def _get_lru(self) -> Optional[TreeNode]:
+         """
+         Get the least recently used node
+         """
+         if len(self.cache) == 0:
+             return None
+         return getattr(self.tail, self.prv)
+
+     def reset_node_mru(self, node):
+         """
+         Move a (existing) node to most recently used position
+         """
+         assert node.id in self.cache, f"Resetting node {node.id=} not in lru list"
+         assert (
+             not self.swa or not node.swa_tombstone
+         ), f"Resetting swa tombstone node in swa lru list: {node.id=}"
+         self._remove_node(node)
+         self._add_node(node)
+
+     def reset_node_and_parents_mru(self, node, root_node):
+         """
+         Move an (existing) node and its parents to most recently used position. Child node is
+         more recently used than parent node.
+         """
+         prev_node = self.head
+         while node != root_node:
+             # for swa lru list, only reset non-tombstone nodes
+             if not self.swa or not node.swa_tombstone:
+                 assert (
+                     node.id in self.cache
+                 ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru"
+                 self._remove_node(node)
+                 self._add_node_after(prev_node, node)
+                 prev_node = node
+             node = node.parent
+
+     def insert_mru(self, node):
+         """
+         Insert a (new) node as most recently used
+         """
+         assert (
+             not self.swa or not node.swa_tombstone
+         ), f"Inserting swa tombstone node in swa lru list: {node.id=}"
+         assert (
+             node.id not in self.cache
+         ), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}"
+         self.cache[node.id] = node
+         self._add_node(node)
+
+     def remove_node(self, node: TreeNode):
+         """
+         Remove node from lru list
+         """
+         assert node.id in self.cache, f"Removing node {node.id=} not in lru list"
+         assert (
+             not self.swa or not node.swa_tombstone
+         ), f"Removing swa tombstone node from swa lru list: {node.id=}"
+         del self.cache[node.id]
+         self._remove_node(node)
+
+     def get_lru_no_lock(self) -> Optional[TreeNode]:
+         """
+         Get the least recently used node that is not locked
+         """
+         return self.get_prev_no_lock(self.tail, check_id=False)
+
+     def get_leaf_lru_no_lock(self) -> Optional[TreeNode]:
+         """
+         Get the least recently used leaf node that is not locked
+         """
+         return self.get_prev_leaf_no_lock(self.tail, check_id=False)
+
+     def get_prev_no_lock(
+         self, node: TreeNode, check_id: bool = True
+     ) -> Optional[TreeNode]:
+         """
+         Get the previous (i.e. more recently used) node that is not locked
+         """
+         if check_id:
+             assert (
+                 node.id in self.cache
+             ), f"Getting prev of node {node.id=} not in lru list"
+         x = getattr(node, self.prv)  # x = node.prev
+         while getattr(x, self.lock_ref) > 0:
+             x = getattr(x, self.prv)  # x = x.prev
+         # if x is the head, it means there is no node in the lru list without lock
+         if x == self.head:
+             return None
+         return x
+
+     def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True):
+         """
+         Get the previous (i.e. more recently used) leaf node that is not locked
+         """
+         if check_id:
+             assert (
+                 node.id in self.cache
+             ), f"Getting prev of node {node.id=} not in lru list"
+         x = getattr(node, self.prv)  # x = node.prev
+         while getattr(x, self.lock_ref) > 0 or len(x.children) > 0:
+             x = getattr(x, self.prv)  # x = x.prev
+         # if x is the head, it means there is no leaf node in the lru list without lock
+         if x == self.head:
+             return None
+         return x
+
+     def in_list(self, node: Optional[TreeNode]):
+         """
+         Check if the node is in the lru list
+         """
+         if not node:
+             return False
+         return node.id in self.cache
+
+     # Note: this is expensive, only use for debug
+     def sanity_check_evictable_size(self):
+         """
+         Check the evictable size (i.e. the size of the nodes that are not locked)
+         """
+         node = self.get_lru_no_lock()
+         evictable_size = 0
+         while self.in_list(node):
+             evictable_size += len(node.value)
+             node = self.get_prev_no_lock(node)
+         return evictable_size
+
+     # Note: this is expensive, only use for debug or idle check
+     def sanity_check(self, tree_cache: "SWARadixCache"):
+         """
+         Check if the lru list is valid by rebuilding the lru list from the tree, heapifying it, and
+         checking if the lru list is valid.
+         """
+         try:
+             if self.swa:
+                 nodes = tree_cache._collect_nontombstone_nodes()
+             else:
+                 nodes = tree_cache._collect_all_nodes()
+             total_nodes = len(nodes)
+             total_lru_plus_1 = len(self.cache) + 1
+             # heapify based on last_access_time
+             heapq.heapify(nodes)
+             # the root node is not in the lru list
+             assert (
+                 len(nodes) == len(self.cache) + 1
+             ), f"len(nodes): {len(nodes)} != len(self.cache) + 1: {len(self.cache) + 1}"
+
+             x_lru = self._get_lru()
+             while len(nodes):
+                 x = heapq.heappop(nodes)
+                 if x == tree_cache.root_node:
+                     # root node is not in the lru list
+                     continue
+                 assert (
+                     x == x_lru
+                 ), f"Incorrect LRU list, {self.swa=}, x: {x.id=} != x_lru: {x_lru.id=}"
+                 assert (
+                     x_lru.full_lock_ref == 0
+                 ), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.swa_uuid=}, {x_lru.id=}"
+                 assert (
+                     x_lru.swa_lock_ref == 0
+                 ), f"x_lru should not be locked when idle, {x_lru.swa_lock_ref=}, {x_lru.swa_uuid=}, {x_lru.id=}"
+                 x_lru = getattr(x, self.prv)
+
+             if self.swa:
+                 evictable_size = tree_cache.swa_evictable_size()
+                 lru_list_evictable_size = tree_cache.swa_lru_list_evictable_size()
+             else:
+                 evictable_size = tree_cache.full_evictable_size()
+                 lru_list_evictable_size = tree_cache.full_lru_list_evictable_size()
+
+             assert (
+                 evictable_size == lru_list_evictable_size
+             ), f"{self.swa=}, total nodes: {total_nodes}, total lru plus 1: {total_lru_plus_1}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}"
+         except Exception as e:
+             msg = f"SWA Radix tree sanity check failed, ping @hanming-lu: {e}"
+             logger.error(msg)
+             raise Exception(msg)
+
+
+ class SWARadixCache(BasePrefixCache):
+     def __init__(
+         self,
+         req_to_token_pool: ReqToTokenPool,
+         token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
+         sliding_window_size: int,
+         page_size: int,
+         disable: bool = False,
+     ):
+         assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
+         self.req_to_token_pool = req_to_token_pool
+         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+         self.page_size = page_size
+         self.disable = disable
+
+         if self.token_to_kv_pool_allocator:
+             self.device = self.token_to_kv_pool_allocator.device
+         else:
+             self.device = torch.device("cpu")
+
+         if self.page_size == 1:
+             self.key_match_fn = _key_match_page_size1
+             self.get_child_key_fn = lambda key: key[0]
+         else:
+             self.key_match_fn = partial(_key_match_paged, page_size=page_size)
+             self.get_child_key_fn = lambda key: tuple(key[:page_size])
+
+         self.sliding_window_size = sliding_window_size
+         self.reset()
+
+     ##### Public API #####
+
+     def reset(self) -> None:
+         self.root_node = TreeNode()
+         self.root_node.key = []
+         self.root_node.value = []
+         self.root_node.full_lock_ref = 1
+         self.root_node.swa_lock_ref = 1
+         self.full_evictable_size_ = 0
+         self.swa_evictable_size_ = 0
+         self.full_protected_size_ = 0
+         self.swa_protected_size_ = 0
+         # LRU lists are used to maintain the order of eviction of the nodes in the tree
+         self.full_lru_list = LRUList(swa=False)
+         self.swa_lru_list = LRUList(swa=True)
+
+     def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
+         """Find the matching prefix from the radix tree.
+         Args:
+             key: A list of token IDs to find a matching prefix.
+         Returns:
+             A tuple of a tensor of matching prefix token IDs and
+             the last node that contains the prefix values. Note that
+             this API can modify the internal state of the Radix tree.
+             The last node create a new child if the prefix is shorter
+             than the last node's value.
+         """
+         if self.disable or len(key) == 0:
+             return MatchResult(
+                 device_indices=torch.empty(
+                     (0,),
+                     dtype=torch.int64,
+                     device=self.device,
+                 ),
+                 last_device_node=self.root_node,
+                 last_host_node=self.root_node,
+             )
+
+         if self.page_size != 1:
+             page_aligned_len = len(key) // self.page_size * self.page_size
+             key = key[:page_aligned_len]
+
+         value, last_node = self._match_prefix_helper(key)
+         if value:
+             value = torch.cat(value)
+         else:
+             value = torch.empty((0,), dtype=torch.int64, device=self.device)
+         return MatchResult(
+             device_indices=value,
+             last_device_node=last_node,
+             last_host_node=last_node,
+         )
+
+     def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int:
+         if self.disable:
+             return 0
+
+         if value is None:
+             value = [x for x in key]
+         return self._insert_helper(self.root_node, key, value, prev_prefix_len)
+
+     def cache_finished_req(self, req: Req) -> None:
+         """Cache request when it finishes."""
+         if self.disable:
+             kv_indices = self.req_to_token_pool.req_to_token[
+                 req.req_pool_idx,
+                 : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
+             ]
+             self.token_to_kv_pool_allocator.free(kv_indices)
+             self.req_to_token_pool.free(req.req_pool_idx)
+             return
+
+         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+         kv_indices = self.req_to_token_pool.req_to_token[
+             req.req_pool_idx, : len(token_ids)
+         ]
+
+         if self.page_size != 1:
+             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
+             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
+         else:
+             page_aligned_len = len(kv_indices)
+             page_aligned_kv_indices = kv_indices.clone()
+
+         # Radix Cache takes one ref in memory pool
+         # insert the token_ids and kv_indices into the radix tree
+         # Note: the insert function already frees the overlapped kv_indices
+         new_prefix_len = self.insert(
+             token_ids[:page_aligned_len],
+             page_aligned_kv_indices,
+             len(req.prefix_indices),
+         )
+
+         # Remove req slot release the cache lock
+         self.req_to_token_pool.free(req.req_pool_idx)
+         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
+
+     def cache_unfinished_req(self, req: Req) -> None:
+         """Cache request when it is unfinished."""
+         if self.disable:
+             kv_indices = self.req_to_token_pool.req_to_token[
+                 req.req_pool_idx, : len(req.fill_ids)
+             ]
+
+             # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
+             req.prefix_indices = kv_indices
+             return
+
+         token_ids = req.fill_ids
+         kv_indices = self.req_to_token_pool.req_to_token[
+             req.req_pool_idx, : len(token_ids)
+         ]
+
+         if self.page_size != 1:
+             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
+             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+         else:
+             page_aligned_len = len(kv_indices)
+             page_aligned_kv_indices = kv_indices.clone()
+         page_aligned_token_ids = token_ids[:page_aligned_len]
+
+         # Radix Cache takes one ref in memory pool
+         # Note: the insert function already frees the overlapped kv_indices
+         new_prefix_len = self.insert(
+             page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices)
+         )
+
+         # The prefix indices could be updated, reuse it
+         new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
+         assert len(req.prefix_indices) <= len(
+             new_indices
+         ), f"{req.prefix_indices=}, {new_indices=}"
+         assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
+         self.req_to_token_pool.write(
+             (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+             new_indices[len(req.prefix_indices) :],
+         )
+
+         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
+         swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
+
+         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
+         if self.page_size != 1:
+             req.prefix_indices = torch.cat(
+                 [new_indices, kv_indices[len(new_indices) :]]
+             )
+         else:
+             req.prefix_indices = new_indices
+         req.last_node = new_last_node
+         req.swa_uuid_for_lock = swa_uuid_for_lock
+
+     def pretty_print(self) -> None:
+         self._print_helper(self.root_node, 0)
+         total_size, total_swa_size = self._total_size_helper()
+         print(f"#full_tokens: {total_size}, #swa_tokens: {total_swa_size}")
+
+     def total_size(self) -> Tuple[int, int]:
+         return self._total_size_helper()
+
+     def evict(self, full_num_tokens: int, swa_num_tokens: int = 0) -> None:
+         if self.disable:
+             return
+
+         full_num_evicted = 0
+         swa_num_evicted = 0
+         if full_num_tokens > 0:
+             # get the least recently used leaf node that is not locked
+             x = self.full_lru_list.get_leaf_lru_no_lock()
+
+             while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x):
+                 assert (
+                     x != self.root_node
+                 ), f"root node should not exist in full lru list, {x.id=}"
+                 assert x.full_lock_ref == 0, f"node is in use, {x.id=}"
+
+                 # 1. free node kv indices, evict full and swa tokens
+                 self.token_to_kv_pool_allocator.free(x.value)
+                 full_num_evicted += len(x.value)
+                 swa_num_evicted += len(x.value)
+
+                 # 2. get the next leaf, update the lru lists
+                 x_next = self.full_lru_list.get_prev_leaf_no_lock(x)
+                 self.full_lru_list.remove_node(x)
+                 self.swa_lru_list.remove_node(x)
+
+                 # 3. delete the leaf node
+                 self._delete_leaf(x)
+
+                 # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone
+                 x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x)
+                 full_num_evicted += leaf_full_num_evicted
+
+                 # 5. if parent has no more children, it is a leaf. It is possible that this node is lru, so
+                 # we need to get the first leaf node in the lru list
+                 if len(x.parent.children) == 0:
+                     x_next = self.full_lru_list.get_leaf_lru_no_lock()
+
+                 x = x_next
+
+         if swa_num_evicted < swa_num_tokens:
+             # get the least recently used node that is not locked, doesn't have to be a leaf
+             x = self.swa_lru_list.get_lru_no_lock()
+
+             # evict lru leaf nodes until swa_num_tokens is reached
+             while swa_num_evicted < swa_num_tokens and (self.swa_lru_list.in_list(x)):
+                 assert not x.swa_tombstone, f"duplicate swa tombstone node, {x.id=}"
+                 assert x != self.root_node, f"root node is not evictable, {x.id=}"
+                 assert x.swa_lock_ref == 0, f"node is in use by swa kv indices, {x.id=}"
+
+                 if len(x.children) > 0:
+                     # 1. an internal node, free swa tokens.
+                     self.token_to_kv_pool_allocator.free_swa(x.value)
+                     swa_num_evicted += len(x.value)
+
+                     # 2. get the next node, update the lru lists
+                     x_next = self.swa_lru_list.get_prev_no_lock(x)
+                     self.swa_lru_list.remove_node(x)
+
+                     # 3. tombstone the node
+                     self._tombstone_internal_node(x)
+                 else:
+                     assert (
+                         x.full_lock_ref == 0
+                     ), f"leaf node with full lock must also have swa lock, {x.id=}"
+                     # 1. a leaf node, free full and swa tokens
+                     self.token_to_kv_pool_allocator.free(x.value)
+                     full_num_evicted += len(x.value)
+                     swa_num_evicted += len(x.value)
+
+                     # 2. get the next node, update the lru lists
+                     x_next = self.swa_lru_list.get_prev_no_lock(x)
+                     self.full_lru_list.remove_node(x)
+                     self.swa_lru_list.remove_node(x)
+
+                     # 3. delete the leaf node
+                     self._delete_leaf(x)
+
+                     # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone
+                     self._iteratively_delete_tombstone_leaf(x)
+
+                 x = x_next
+
+     def inc_lock_ref(self, node: TreeNode) -> Optional[int]:
+         """
+         Increment the lock reference count for the node. Returns the swa_uuid_for_lock, which needs
+         to be passed to dec_lock_ref.
+         It locks the full_lock_ref for nodes between the [last node, root), exclusive.
+         It locks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive.
+         """
+         if self.disable:
+             return None
+
+         swa_lock_size = 0
+         swa_uuid_for_lock = None
+         while node != self.root_node:
+             # lock full from node to root
+             assert (
+                 node.full_lock_ref >= 0
+             ), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
+             if node.full_lock_ref == 0:
+                 self.full_evictable_size_ -= len(node.value)
+                 self.full_protected_size_ += len(node.value)
+             node.full_lock_ref += 1
+
+             # lock swa if we have not reached the sliding window size.
+             # When we reach the sliding window size, we will set the swa_uuid_for_lock.
+             # caller needs to pass the swa_uuid_for_lock to dec_lock_ref
+             if swa_lock_size < self.sliding_window_size:
+                 assert (
+                     not node.swa_tombstone
+                 ), f"inc_lock_swa on swa_tombstone node, {node.id=}"
+                 if node.swa_lock_ref == 0:
+                     self.swa_evictable_size_ -= len(node.value)
+                     self.swa_protected_size_ += len(node.value)
+                 node.swa_lock_ref += 1
+                 swa_lock_size += len(node.value)
+                 if swa_lock_size >= self.sliding_window_size:
+                     if node.swa_uuid is None:
+                         node.swa_uuid = gen_swa_uuid()
+                     swa_uuid_for_lock = node.swa_uuid
+             node = node.parent
+         return swa_uuid_for_lock
+
+     def dec_lock_ref(self, node: TreeNode, swa_uuid_for_lock: Optional[int] = None):
+         """
+         Decrement the lock reference count for the node.
+         It unlocks the full_lock_ref for nodes between the [last node, root), exclusive.
+         It unlocks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive.
+         If swa_uuid_for_lock is None, it unlocks to the root, exclusive.
+         """
+         if self.disable:
+             return
+
+         dec_lock_swa = True
+         while node != self.root_node:
+             assert (
+                 node.full_lock_ref > 0
+             ), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
+             if node.full_lock_ref == 1:
+                 self.full_evictable_size_ += len(node.value)
+                 self.full_protected_size_ -= len(node.value)
+             node.full_lock_ref -= 1
+
+             if dec_lock_swa:
+                 assert (
+                     not node.swa_tombstone
+                 ), f"dec_lock_ref on swa_tombstone node, {node.id=}"
+                 assert (
+                     node.swa_lock_ref > 0
+                 ), f"dec_lock_ref on node with {node.swa_lock_ref=}, {node.id=}"
+
+                 if node.swa_lock_ref == 1:
+                     self.swa_evictable_size_ += len(node.value)
+                     self.swa_protected_size_ -= len(node.value)
+                 node.swa_lock_ref -= 1
+                 if swa_uuid_for_lock and node.swa_uuid == swa_uuid_for_lock:
+                     dec_lock_swa = False
+
+             node = node.parent
+
+     def sanity_check(self):
+         self.full_lru_list.sanity_check(self)
+         self.swa_lru_list.sanity_check(self)
+
+     def evictable_size(self) -> Tuple[int, int]:
+         # Note: use full_evictable_size() and swa_evictable_size() instead.
+         raise NotImplementedError
+
+     def full_evictable_size(self) -> int:
+         return self.full_evictable_size_
+
+     def swa_evictable_size(self) -> int:
+         return self.swa_evictable_size_
+
+     # Note: this is expensive, only use for debug
+     def full_lru_list_evictable_size(self) -> int:
+         return self.full_lru_list.sanity_check_evictable_size()
+
+     # Note: this is expensive, only use for debug
+     def swa_lru_list_evictable_size(self) -> int:
+         return self.swa_lru_list.sanity_check_evictable_size()
+
+     def protected_size(self) -> Tuple[int, int]:
+         # Note: use full_protected_size() and swa_protected_size() instead.
+         raise NotImplementedError
+
+     def full_protected_size(self) -> int:
+         # protected size refers to the size of the full cache that is locked
+         return self.full_protected_size_
+
+     def swa_protected_size(self) -> int:
+         # protected size refers to the size of the swa cache that is locked
+         return self.swa_protected_size_
+
+     def all_values_flatten(self) -> torch.Tensor:
+         values = []
+
+         def _dfs_helper(node: TreeNode):
+             for _, child in node.children.items():
+                 values.append(child.value)
+                 _dfs_helper(child)
+
+         _dfs_helper(self.root_node)
+         return torch.cat(values)
+
+     ##### Internal Helper Functions #####
+
+     def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]:
+         """
+         SWA prefix matching helper. It factors in the sliding window size such that
+         the matched node is guaranteed to either 1. connected to root without swa tombstone,
+         or 2. the number of matching tokens from the matched node to the last swa tombstone
+         node is greater than or equal to the sliding window size.
+         """
+         node = self.root_node
+         child_key = self.get_child_key_fn(key)
+
+         value = []
+         # for path connected to root without tombstone, always match, so set to inf
+         match_len_since_tombstone = float("inf")
+         best_value_len = 0
+         best_last_node = node
+         while len(key) > 0 and child_key in node.children.keys():
+             child = node.children[child_key]
+
+             # update best_value_len and best_last_node if needed
+             if (
+                 child.swa_tombstone
+                 and match_len_since_tombstone >= self.sliding_window_size
+             ):
+                 best_value_len = len(value)
+                 best_last_node = node
+                 match_len_since_tombstone = 0
+
+             prefix_len = self.key_match_fn(child.key, key)
+             if prefix_len < len(child.key):
+                 new_node = self._split_node(child.key, child, prefix_len)
+                 value.append(new_node.value)
+                 if not new_node.swa_tombstone:
+                     match_len_since_tombstone += len(new_node.value)
+                 node = new_node
+                 break
+             else:
+                 value.append(child.value)
+                 if not child.swa_tombstone:
+                     match_len_since_tombstone += len(child.value)
+                 node = child
+                 key = key[prefix_len:]
+
+                 if len(key):
+                     child_key = self.get_child_key_fn(key)
+
+         # handle best_value_len and best_last_node, for the case that last node is fully matched
+         if match_len_since_tombstone >= self.sliding_window_size:
+             best_value_len = len(value)
+             best_last_node = node
+
+         # update time for matched nodes, and make nodes closer to root to be least recently used
+         # this allows swa to evict nodes closer to root first
+         self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
+         self.swa_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
+
+         # This last_access_time is for sanity check, can be deleted after validation in production
+         cur_time = time.monotonic()
+         while node:
+             node.last_access_time = cur_time
+             cur_time -= 0.0001
+             node = node.parent
+
+         return value[:best_value_len], best_last_node
+
+     def _split_node(self, key: List[int], child: TreeNode, split_len: int) -> TreeNode:
+         # new_node -> child
+         new_node = TreeNode()
+         new_node.children = {self.get_child_key_fn(key[split_len:]): child}
+         new_node.parent = child.parent
+         new_node.swa_tombstone = child.swa_tombstone
+         new_node.full_lock_ref = child.full_lock_ref
+         new_node.swa_lock_ref = child.swa_lock_ref
+         new_node.key = child.key[:split_len]
+         new_node.value = child.value[:split_len]
+         # parent inherits the swa_uuid from child for swa lock ref
+         new_node.swa_uuid = child.swa_uuid
+         child.swa_uuid = None
+         # child time should be later than parent's time for swa tombstone
+         child.last_access_time = time.monotonic()
+
+         # remove the child from the lru lists because it is being split
+         self.full_lru_list.remove_node(child)
+         if not new_node.swa_tombstone:
+             self.swa_lru_list.remove_node(child)
+         child.parent = new_node
+         child.key = child.key[split_len:]
+         child.value = child.value[split_len:]
+         new_node.parent.children[self.get_child_key_fn(key)] = new_node
+
+         # insert the new node and child into the lru lists, insert
+         # parent first so that parent is after child in the lru list
+         self.full_lru_list.insert_mru(new_node)
+         self.full_lru_list.insert_mru(child)
+         if not new_node.swa_tombstone:
+             self.swa_lru_list.insert_mru(new_node)
+             self.swa_lru_list.insert_mru(child)
+         return new_node
+
+     def _insert_helper(
+         self, node: TreeNode, key: List, value, update_kv_after_len: int
+     ) -> int:
+         # Update the last access time from root to leaf, so that
+         # swa will tombstone the node closer to root first
+         node.last_access_time = time.monotonic()
+         if node != self.root_node:
+             self.full_lru_list.reset_node_mru(node)
+             if not node.swa_tombstone:
+                 self.swa_lru_list.reset_node_mru(node)
+         if len(key) == 0:
+             return 0
+
+         child_key = self.get_child_key_fn(key)
+
+         total_prefix_length = 0
+         while len(key) > 0 and child_key in node.children.keys():
+             node = node.children[child_key]
+             node.last_access_time = time.monotonic()
+             self.full_lru_list.reset_node_mru(node)
+             if not node.swa_tombstone:
+                 self.swa_lru_list.reset_node_mru(node)
+             prefix_len = self.key_match_fn(node.key, key)
+
+             if prefix_len < len(node.key):
+                 new_node = self._split_node(node.key, node, prefix_len)
+                 node = new_node
+
+             # if tombstone after update_kv_after_len, update node.value to be the input value.
+             # This is needed because it is possible that the last sliding window size tokens
+             # contains tombstone. If this is the case and we don't update the kv value, then
+             # the prefill prefix matching will stuck.
+             if update_kv_after_len < total_prefix_length + prefix_len:
+                 first_diff_idx = max(0, update_kv_after_len - total_prefix_length)
+                 if node.swa_tombstone:
+                     assert (
+                         node.swa_lock_ref == 0
+                     ), f"tombstone swa_lock_ref should always be 0, {node.full_lock_ref=}, {node.swa_lock_ref=}, {node.id=}"
+                     self.token_to_kv_pool_allocator.free(node.value[first_diff_idx:])
+                     node.value = value[:prefix_len]
+                     node.swa_tombstone = False
+
+                     # insert the node into the lru lists
+                     self.swa_lru_list.insert_mru(node)
+
+                     self.swa_evictable_size_ += len(node.value)
+                 else:
+                     self.token_to_kv_pool_allocator.free(
+                         value[first_diff_idx:prefix_len]
+                     )
+
+             total_prefix_length += prefix_len
+             key = key[prefix_len:]
+             value = value[prefix_len:]
+
+             if len(key):
+                 child_key = self.get_child_key_fn(key)
+
+         if len(key):
+             new_node = TreeNode()
+             new_node.parent = node
+             new_node.key = key
+             new_node.value = value
+             self.full_lru_list.insert_mru(new_node)
+             self.swa_lru_list.insert_mru(new_node)
+             node.children[child_key] = new_node
+             self.full_evictable_size_ += len(value)
+             self.swa_evictable_size_ += len(value)
+         return total_prefix_length
+
+     def _iteratively_delete_tombstone_leaf(
+         self, node: TreeNode
+     ) -> Tuple[TreeNode, int]:
+         full_num_evicted = 0
+         while node.parent.swa_tombstone and len(node.parent.children) == 0:
+             # root node is not evictable
+             if node.parent == self.root_node:
+                 break
+             # if locked, means node is in use, skip
+             if node.parent.full_lock_ref > 0:
+                 break
+             assert (
+                 node.parent.swa_lock_ref == 0
+             ), f"tombstone swa_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.swa_lock_ref=}, {node.parent.id=}"
+             # delete tombstone node evicts full tokens
+             self.token_to_kv_pool_allocator.free(node.parent.value)
+             full_num_evicted += len(node.parent.value)
+             self.full_lru_list.remove_node(node.parent)
+             self._delete_tombstone_leaf(node.parent)
+             node = node.parent
+
+         return node, full_num_evicted
+
+     def _delete_leaf(self, node: TreeNode) -> None:
+         assert (
+             not node.swa_tombstone
+         ), f"Invariant violated: leaf node is a tombstone, {node.id=}"
+         assert len(node.children) == 0, f"leaf node has children, {node.id=}"
+         for k, v in node.parent.children.items():
+             if v == node:
+                 break
+         del node.parent.children[k]
+         self.full_evictable_size_ -= len(node.key)
+         self.swa_evictable_size_ -= len(node.key)
+
+     def _tombstone_internal_node(self, node: TreeNode) -> None:
+         assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}"
+         node.swa_tombstone = True
+         self.swa_evictable_size_ -= len(node.key)
+
+     def _delete_tombstone_leaf(self, node: TreeNode) -> None:
+         assert (
+             node.swa_tombstone
+         ), f"Deleting a unexpected non-tombstone leaf node, {node.id=}"
+         assert len(node.children) == 0, f"leaf node has children, {node.id=}"
+         for k, v in node.parent.children.items():
+             if v == node:
+                 break
+         del node.parent.children[k]
+         self.full_evictable_size_ -= len(node.key)
+
+     def _collect_leaves(self) -> List[TreeNode]:
+         ret_list = []
+         stack = [self.root_node]
+
+         while stack:
+             cur_node = stack.pop()
+             if len(cur_node.children) == 0:
+                 ret_list.append(cur_node)
+             else:
+                 stack.extend(cur_node.children.values())
+
+         return ret_list
+
+     def _collect_nontombstone_nodes(self) -> List[TreeNode]:
+         ret_list = []
+         stack = [self.root_node]
+
+         while stack:
+             cur_node = stack.pop()
+             if not cur_node.swa_tombstone:
+                 ret_list.append(cur_node)
+             stack.extend(cur_node.children.values())
+
+         return ret_list
+
+     def _collect_all_nodes(self) -> List[TreeNode]:
+         ret_list = []
+         stack = [self.root_node]
+         while stack:
+             cur_node = stack.pop()
+             ret_list.append(cur_node)
+             stack.extend(cur_node.children.values())
+         return ret_list
+
+     def _print_helper(self, node: TreeNode, indent: int) -> None:
+         """Prints the radix tree in a human-readable format."""
+         stack = [(node, indent)]
+         while stack:
+             current_node, current_indent = stack.pop()
+             print(
+                 " " * current_indent,
+                 current_node.id,
+                 len(current_node.key),
+                 f"fr={current_node.full_lock_ref}",
+                 f"sr={current_node.swa_lock_ref}",
+                 f"fll={self.full_lru_list.in_list(current_node)}",
+                 f"sll={self.swa_lru_list.in_list(current_node)}",
+                 f"ts={current_node.swa_tombstone}",
+             )
+             for key, child in current_node.children.items():
+                 stack.append((child, current_indent + 2))
+
+                 assert key == self.get_child_key_fn(
+                     child.key
+                 ), f"{key=}, {self.get_child_key_fn(child.key)=}"
+
+     def _total_size_helper(self) -> Tuple[int, int]:
+         total_size = 0
+         total_swa_size = 0
+         stack = [self.root_node]
+         while stack:
+             current_node = stack.pop()
+             total_size += len(current_node.value)
+             if not current_node.swa_tombstone:
+                 total_swa_size += len(current_node.value)
+             for child in current_node.children.values():
+                 if child.evicted:
+                     continue
+                 stack.append(child)
+         return total_size, total_swa_size
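
The new swa_radix_cache.py module shown above can be exercised in isolation. Below is a minimal sketch (illustrative only, not part of the diff; it assumes sglang 0.4.9.post4 is installed and uses made-up token lists) of the paged key matching that SWARadixCache selects when page_size > 1: keys only match in whole-page increments, so a mismatch anywhere inside a page discards that entire page.

    from sglang.srt.mem_cache.swa_radix_cache import (
        _key_match_page_size1,
        _key_match_paged,
    )

    key0 = [1, 2, 3, 4, 5, 6, 7, 8]
    key1 = [1, 2, 3, 4, 5, 9, 7, 8]

    # token-level matching stops at the first differing token
    assert _key_match_page_size1(key0, key1) == 5
    # paged matching only counts whole pages, so the partially matching second page is dropped
    assert _key_match_paged(key0, key1, page_size=4) == 4

A second sketch, under the same assumptions, exercises the doubly linked LRUList that orders TreeNodes for eviction: newly inserted or re-touched nodes move to the most-recently-used side, and get_lru_no_lock() skips locked nodes when picking an eviction candidate.

    from sglang.srt.mem_cache.swa_radix_cache import LRUList, TreeNode

    lru = LRUList(swa=False)
    a, b, c = TreeNode(), TreeNode(), TreeNode()
    for node in (a, b, c):
        lru.insert_mru(node)           # c ends up most recently used

    assert lru.get_lru_no_lock() is a  # a was inserted first

    lru.reset_node_mru(a)              # touching a moves it to the MRU side
    assert lru.get_lru_no_lock() is b

    b.full_lock_ref = 1                # locked nodes are never eviction candidates
    assert lru.get_lru_no_lock() is c

SWARadixCache keeps two such lists (one for full-attention KV, one for sliding-window KV) so that SWA entries can be tombstoned independently of the full entries while both sides still evict in LRU order.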