sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
The expanded hunk below is the new file sglang/srt/mem_cache/swa_radix_cache.py (1,025 added lines), which implements the radix tree for the hybrid full-attention / sliding-window-attention (SWA) KV cache.

@@ -0,0 +1,1025 @@

```python
from __future__ import annotations

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The radix tree data structure for managing the hybrid (full and SWA) KV cache.
"""

import heapq
import time
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch

from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req

import logging

logger = logging.getLogger(__name__)


class TreeNode:

    counter = 0
    swa_uuid_counter = 1

    def __init__(self, id: Optional[int] = None):
        self.children = defaultdict(TreeNode)
        self.parent: TreeNode = None
        self.key: List[int] = None
        self.value: Optional[torch.Tensor] = None
        # swa_tombstone is used to indicate the kv indices have been freed for swa layers
        self.swa_tombstone = False
        # invariant: for any node, if swa_lock_ref is locked, full_lock_ref must be locked;
        # if full_lock_ref is locked, swa_lock_ref doesn't need to be locked. So,
        # full_lock_ref is always >= swa_lock_ref.
        self.full_lock_ref = 0
        self.swa_lock_ref = 0
        # last access time is only used for sanity check. LRU is maintained by the lru list.
        self.last_access_time = time.monotonic()

        self.hit_count = 0
        # indicating the node is loading KV cache from host
        self.loading = False
        # store the host indices of KV cache
        self.host_value = None

        # for lru list, invariant:
        # 1. prev has greater last_access_time
        # 2. next has smaller last_access_time
        self.prev = None
        self.next = None
        self.swa_prev = None
        self.swa_next = None

        self.id = TreeNode.counter if id is None else id
        TreeNode.counter += 1
        self.swa_uuid = None

    @property
    def evicted(self):
        return self.value is None

    @property
    def backuped(self):
        return self.host_value is not None

    def __lt__(self, other: "TreeNode"):
        return self.last_access_time < other.last_access_time


def _key_match_page_size1(key0: List, key1: List):
    i = 0
    for k0, k1 in zip(key0, key1):
        if k0 != k1:
            break
        i += 1
    return i


def _key_match_paged(key0: List, key1: List, page_size: int):
    min_len = min(len(key0), len(key1))

    i = 0
    while i < min_len:
        if key0[i : i + page_size] != key1[i : i + page_size]:
            break
        i += page_size

    return i


def gen_swa_uuid() -> int:
    TreeNode.swa_uuid_counter += 1
    return TreeNode.swa_uuid_counter
```
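A quick way to see what the two key-match helpers above do is to run them on short token lists. This is a minimal illustrative snippet, not part of the diff; it assumes sglang 0.4.9.post4 (and its torch dependency) is installed so that the module-private helpers can be imported from the path shown in the imports above.

```python
# Illustrative only: exercises the two key-match helpers defined above.
from sglang.srt.mem_cache.swa_radix_cache import (
    _key_match_page_size1,
    _key_match_paged,
)

a = [1, 2, 3, 4, 5, 6]
b = [1, 2, 3, 4, 9, 9]

# Token-by-token match: the first four tokens agree.
assert _key_match_page_size1(a, b) == 4

# Page-aligned match with page_size=4: the first page agrees, the second differs,
# so the match stops at the page boundary after 4 tokens.
assert _key_match_paged(a, b, page_size=4) == 4

# With page_size=3, the second page ([4, 5, 6] vs [4, 9, 9]) differs,
# so only the first full page (3 tokens) counts.
assert _key_match_paged(a, b, page_size=3) == 3
```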
The LRUList helper (same file, continued):

```python
class LRUList:
    def __init__(self, swa: bool = False):
        self.swa = swa
        if self.swa:
            self.prv = "swa_prev"
            self.nxt = "swa_next"
            self.lock_ref = "swa_lock_ref"
        else:
            self.prv = "prev"
            self.nxt = "next"
            self.lock_ref = "full_lock_ref"
        # Initialize dummy head and tail nodes
        self.head = TreeNode()  # Most recently used side
        self.tail = TreeNode()  # Least recently used side
        setattr(self.head, self.nxt, self.tail)  # self.head.next = self.tail
        setattr(self.tail, self.prv, self.head)  # self.tail.prev = self.head
        self.cache = {}

    def _add_node(self, node):
        """Helper to add node right after head (most recently used)"""
        self._add_node_after(self.head, node)

    def _add_node_after(self, old_node, new_node):
        """Helper to add node right after old_node"""
        setattr(new_node, self.prv, old_node)  # new_node.prev = old_node
        setattr(
            new_node, self.nxt, getattr(old_node, self.nxt)
        )  # new_node.next = old_node.next
        setattr(
            getattr(old_node, self.nxt), self.prv, new_node
        )  # old_node.next.prev = new_node
        setattr(old_node, self.nxt, new_node)  # old_node.next = new_node

    def _remove_node(self, node):
        """Helper to remove node from linked list"""
        setattr(
            getattr(node, self.prv), self.nxt, getattr(node, self.nxt)
        )  # node.prev.next = node.next
        setattr(
            getattr(node, self.nxt), self.prv, getattr(node, self.prv)
        )  # node.next.prev = node.prev

    def _get_lru(self) -> Optional[TreeNode]:
        """
        Get the least recently used node
        """
        if len(self.cache) == 0:
            return None
        return getattr(self.tail, self.prv)

    def reset_node_mru(self, node):
        """
        Move a (existing) node to most recently used position
        """
        assert node.id in self.cache, f"Resetting node {node.id=} not in lru list"
        assert (
            not self.swa or not node.swa_tombstone
        ), f"Resetting swa tombstone node in swa lru list: {node.id=}"
        self._remove_node(node)
        self._add_node(node)

    def reset_node_and_parents_mru(self, node, root_node):
        """
        Move an (existing) node and its parents to most recently used position. Child node is
        more recently used than parent node.
        """
        prev_node = self.head
        while node != root_node:
            # for swa lru list, only reset non-tombstone nodes
            if not self.swa or not node.swa_tombstone:
                assert (
                    node.id in self.cache
                ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru"
                self._remove_node(node)
                self._add_node_after(prev_node, node)
                prev_node = node
            node = node.parent

    def insert_mru(self, node):
        """
        Insert a (new) node as most recently used
        """
        assert (
            not self.swa or not node.swa_tombstone
        ), f"Inserting swa tombstone node in swa lru list: {node.id=}"
        assert (
            node.id not in self.cache
        ), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}"
        self.cache[node.id] = node
        self._add_node(node)

    def remove_node(self, node: TreeNode):
        """
        Remove node from lru list
        """
        assert node.id in self.cache, f"Removing node {node.id=} not in lru list"
        assert (
            not self.swa or not node.swa_tombstone
        ), f"Removing swa tombstone node from swa lru list: {node.id=}"
        del self.cache[node.id]
        self._remove_node(node)

    def get_lru_no_lock(self) -> Optional[TreeNode]:
        """
        Get the least recently used node that is not locked
        """
        return self.get_prev_no_lock(self.tail, check_id=False)

    def get_leaf_lru_no_lock(self) -> Optional[TreeNode]:
        """
        Get the least recently used leaf node that is not locked
        """
        return self.get_prev_leaf_no_lock(self.tail, check_id=False)

    def get_prev_no_lock(
        self, node: TreeNode, check_id: bool = True
    ) -> Optional[TreeNode]:
        """
        Get the previous (i.e. more recently used) node that is not locked
        """
        if check_id:
            assert (
                node.id in self.cache
            ), f"Getting prev of node {node.id=} not in lru list"
        x = getattr(node, self.prv)  # x = node.prev
        while getattr(x, self.lock_ref) > 0:
            x = getattr(x, self.prv)  # x = x.prev
        # if x is the head, it means there is no node in the lru list without lock
        if x == self.head:
            return None
        return x

    def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True):
        """
        Get the previous (i.e. more recently used) leaf node that is not locked
        """
        if check_id:
            assert (
                node.id in self.cache
            ), f"Getting prev of node {node.id=} not in lru list"
        x = getattr(node, self.prv)  # x = node.prev
        while getattr(x, self.lock_ref) > 0 or len(x.children) > 0:
            x = getattr(x, self.prv)  # x = x.prev
        # if x is the head, it means there is no leaf node in the lru list without lock
        if x == self.head:
            return None
        return x

    def in_list(self, node: Optional[TreeNode]):
        """
        Check if the node is in the lru list
        """
        if not node:
            return False
        return node.id in self.cache

    # Note: this is expensive, only use for debug
    def sanity_check_evictable_size(self):
        """
        Check the evictable size (i.e. the size of the nodes that are not locked)
        """
        node = self.get_lru_no_lock()
        evictable_size = 0
        while self.in_list(node):
            evictable_size += len(node.value)
            node = self.get_prev_no_lock(node)
        return evictable_size

    # Note: this is expensive, only use for debug or idle check
    def sanity_check(self, tree_cache: "SWARadixCache"):
        """
        Check if the lru list is valid by rebuilding the lru list from the tree, heapifying it, and
        checking if the lru list is valid.
        """
        try:
            if self.swa:
                nodes = tree_cache._collect_nontombstone_nodes()
            else:
                nodes = tree_cache._collect_all_nodes()
            total_nodes = len(nodes)
            total_lru_plus_1 = len(self.cache) + 1
            # heapify based on last_access_time
            heapq.heapify(nodes)
            # the root node is not in the lru list
            assert (
                len(nodes) == len(self.cache) + 1
            ), f"len(nodes): {len(nodes)} != len(self.cache) + 1: {len(self.cache) + 1}"

            x_lru = self._get_lru()
            while len(nodes):
                x = heapq.heappop(nodes)
                if x == tree_cache.root_node:
                    # root node is not in the lru list
                    continue
                assert (
                    x == x_lru
                ), f"Incorrect LRU list, {self.swa=}, x: {x.id=} != x_lru: {x_lru.id=}"
                assert (
                    x_lru.full_lock_ref == 0
                ), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.swa_uuid=}, {x_lru.id=}"
                assert (
                    x_lru.swa_lock_ref == 0
                ), f"x_lru should not be locked when idle, {x_lru.swa_lock_ref=}, {x_lru.swa_uuid=}, {x_lru.id=}"
                x_lru = getattr(x, self.prv)

            if self.swa:
                evictable_size = tree_cache.swa_evictable_size()
                lru_list_evictable_size = tree_cache.swa_lru_list_evictable_size()
            else:
                evictable_size = tree_cache.full_evictable_size()
                lru_list_evictable_size = tree_cache.full_lru_list_evictable_size()

            assert (
                evictable_size == lru_list_evictable_size
            ), f"{self.swa=}, total nodes: {total_nodes}, total lru plus 1: {total_lru_plus_1}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}"
        except Exception as e:
            msg = f"SWA Radix tree sanity check failed, ping @hanming-lu: {e}"
            logger.error(msg)
            raise Exception(msg)
```
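One detail worth noting in LRUList is that it never stores its own copy of the ordering: the links are written onto the TreeNode objects themselves, and the `swa` flag only selects which attribute triple (`prev`/`next`/`full_lock_ref` versus `swa_prev`/`swa_next`/`swa_lock_ref`) is read and written via `getattr`/`setattr`. The same node can therefore occupy independent positions in the full-attention and SWA LRU lists. A minimal illustrative sketch (same import assumption as above, not part of the diff):

```python
# Illustrative only: two LRU lists sharing the same TreeNode objects.
from sglang.srt.mem_cache.swa_radix_cache import LRUList, TreeNode

full_lru = LRUList(swa=False)  # links through node.prev / node.next
swa_lru = LRUList(swa=True)    # links through node.swa_prev / node.swa_next

a, b = TreeNode(), TreeNode()

# Insert in one order for the full list...
full_lru.insert_mru(a)
full_lru.insert_mru(b)         # b is now MRU, a is LRU
# ...and the opposite order for the SWA list.
swa_lru.insert_mru(b)
swa_lru.insert_mru(a)          # a is now MRU, b is LRU

assert full_lru.get_lru_no_lock() is a
assert swa_lru.get_lru_no_lock() is b

# Locked nodes (lock_ref > 0) are skipped when searching for an eviction victim.
a.full_lock_ref = 1            # set directly here purely for illustration
assert full_lru.get_lru_no_lock() is b
```

The SWARadixCache class that uses these two lists follows (same file, continued).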
```python
class SWARadixCache(BasePrefixCache):
    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
        sliding_window_size: int,
        page_size: int,
        disable: bool = False,
    ):
        assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        self.disable = disable

        if self.token_to_kv_pool_allocator:
            self.device = self.token_to_kv_pool_allocator.device
        else:
            self.device = torch.device("cpu")

        if self.page_size == 1:
            self.key_match_fn = _key_match_page_size1
            self.get_child_key_fn = lambda key: key[0]
        else:
            self.key_match_fn = partial(_key_match_paged, page_size=page_size)
            self.get_child_key_fn = lambda key: tuple(key[:page_size])

        self.sliding_window_size = sliding_window_size
        self.reset()

    ##### Public API #####

    def reset(self) -> None:
        self.root_node = TreeNode()
        self.root_node.key = []
        self.root_node.value = []
        self.root_node.full_lock_ref = 1
        self.root_node.swa_lock_ref = 1
        self.full_evictable_size_ = 0
        self.swa_evictable_size_ = 0
        self.full_protected_size_ = 0
        self.swa_protected_size_ = 0
        # LRU lists are used to maintain the order of eviction of the nodes in the tree
        self.full_lru_list = LRUList(swa=False)
        self.swa_lru_list = LRUList(swa=True)

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        """Find the matching prefix from the radix tree.
        Args:
            key: A list of token IDs to find a matching prefix.
        Returns:
            A tuple of a tensor of matching prefix token IDs and
            the last node that contains the prefix values. Note that
            this API can modify the internal state of the Radix tree.
            The last node create a new child if the prefix is shorter
            than the last node's value.
        """
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=torch.empty(
                    (0,),
                    dtype=torch.int64,
                    device=self.device,
                ),
                last_device_node=self.root_node,
                last_host_node=self.root_node,
            )

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
            key = key[:page_aligned_len]

        value, last_node = self._match_prefix_helper(key)
        if value:
            value = torch.cat(value)
        else:
            value = torch.empty((0,), dtype=torch.int64, device=self.device)
        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_node,
        )

    def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int:
        if self.disable:
            return 0

        if value is None:
            value = [x for x in key]
        return self._insert_helper(self.root_node, key, value, prev_prefix_len)

    def cache_finished_req(self, req: Req) -> None:
        """Cache request when it finishes."""
        if self.disable:
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx,
                : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
            ]
            self.token_to_kv_pool_allocator.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return

        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        if self.page_size != 1:
            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
            self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
        else:
            page_aligned_len = len(kv_indices)
            page_aligned_kv_indices = kv_indices.clone()

        # Radix Cache takes one ref in memory pool
        # insert the token_ids and kv_indices into the radix tree
        # Note: the insert function already frees the overlapped kv_indices
        new_prefix_len = self.insert(
            token_ids[:page_aligned_len],
            page_aligned_kv_indices,
            len(req.prefix_indices),
        )

        # Remove req slot release the cache lock
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)

    def cache_unfinished_req(self, req: Req) -> None:
        """Cache request when it is unfinished."""
        if self.disable:
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : len(req.fill_ids)
            ]

            # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
            req.prefix_indices = kv_indices
            return

        token_ids = req.fill_ids
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        if self.page_size != 1:
            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
        else:
            page_aligned_len = len(kv_indices)
            page_aligned_kv_indices = kv_indices.clone()
        page_aligned_token_ids = token_ids[:page_aligned_len]

        # Radix Cache takes one ref in memory pool
        # Note: the insert function already frees the overlapped kv_indices
        new_prefix_len = self.insert(
            page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices)
        )

        # The prefix indices could be updated, reuse it
        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
        assert len(req.prefix_indices) <= len(
            new_indices
        ), f"{req.prefix_indices=}, {new_indices=}"
        assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],
        )

        self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
        swa_uuid_for_lock = self.inc_lock_ref(new_last_node)

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        if self.page_size != 1:
            req.prefix_indices = torch.cat(
                [new_indices, kv_indices[len(new_indices) :]]
            )
        else:
            req.prefix_indices = new_indices
        req.last_node = new_last_node
        req.swa_uuid_for_lock = swa_uuid_for_lock

    def pretty_print(self) -> None:
        self._print_helper(self.root_node, 0)
        total_size, total_swa_size = self._total_size_helper()
        print(f"#full_tokens: {total_size}, #swa_tokens: {total_swa_size}")

    def total_size(self) -> Tuple[int, int]:
        return self._total_size_helper()

    def evict(self, full_num_tokens: int, swa_num_tokens: int = 0) -> None:
        if self.disable:
            return

        full_num_evicted = 0
        swa_num_evicted = 0
        if full_num_tokens > 0:
            # get the least recently used leaf node that is not locked
            x = self.full_lru_list.get_leaf_lru_no_lock()

            while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x):
                assert (
                    x != self.root_node
                ), f"root node should not exist in full lru list, {x.id=}"
                assert x.full_lock_ref == 0, f"node is in use, {x.id=}"

                # 1. free node kv indices, evict full and swa tokens
                self.token_to_kv_pool_allocator.free(x.value)
                full_num_evicted += len(x.value)
                swa_num_evicted += len(x.value)

                # 2. get the next leaf, update the lru lists
                x_next = self.full_lru_list.get_prev_leaf_no_lock(x)
                self.full_lru_list.remove_node(x)
                self.swa_lru_list.remove_node(x)

                # 3. delete the leaf node
                self._delete_leaf(x)

                # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone
                x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x)
                full_num_evicted += leaf_full_num_evicted

                # 5. if parent has no more children, it is a leaf. It is possible that this node is lru, so
                # we need to get the first leaf node in the lru list
                if len(x.parent.children) == 0:
                    x_next = self.full_lru_list.get_leaf_lru_no_lock()

                x = x_next

        if swa_num_evicted < swa_num_tokens:
            # get the least recently used node that is not locked, doesn't have to be a leaf
            x = self.swa_lru_list.get_lru_no_lock()

            # evict lru leaf nodes until swa_num_tokens is reached
            while swa_num_evicted < swa_num_tokens and (self.swa_lru_list.in_list(x)):
                assert not x.swa_tombstone, f"duplicate swa tombstone node, {x.id=}"
                assert x != self.root_node, f"root node is not evictable, {x.id=}"
                assert x.swa_lock_ref == 0, f"node is in use by swa kv indices, {x.id=}"

                if len(x.children) > 0:
                    # 1. an internal node, free swa tokens.
                    self.token_to_kv_pool_allocator.free_swa(x.value)
                    swa_num_evicted += len(x.value)

                    # 2. get the next node, update the lru lists
                    x_next = self.swa_lru_list.get_prev_no_lock(x)
                    self.swa_lru_list.remove_node(x)

                    # 3. tombstone the node
                    self._tombstone_internal_node(x)
                else:
                    assert (
                        x.full_lock_ref == 0
                    ), f"leaf node with full lock must also have swa lock, {x.id=}"
                    # 1. a leaf node, free full and swa tokens
                    self.token_to_kv_pool_allocator.free(x.value)
                    full_num_evicted += len(x.value)
                    swa_num_evicted += len(x.value)

                    # 2. get the next node, update the lru lists
                    x_next = self.swa_lru_list.get_prev_no_lock(x)
                    self.full_lru_list.remove_node(x)
                    self.swa_lru_list.remove_node(x)

                    # 3. delete the leaf node
                    self._delete_leaf(x)

                    # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone
                    self._iteratively_delete_tombstone_leaf(x)

                x = x_next

    def inc_lock_ref(self, node: TreeNode) -> Optional[int]:
        """
        Increment the lock reference count for the node. Returns the swa_uuid_for_lock, which needs
        to be passed to dec_lock_ref.
        It locks the full_lock_ref for nodes between the [last node, root), exclusive.
        It locks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive.
        """
        if self.disable:
            return None

        swa_lock_size = 0
        swa_uuid_for_lock = None
        while node != self.root_node:
            # lock full from node to root
            assert (
                node.full_lock_ref >= 0
            ), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
            if node.full_lock_ref == 0:
                self.full_evictable_size_ -= len(node.value)
                self.full_protected_size_ += len(node.value)
            node.full_lock_ref += 1

            # lock swa if we have not reached the sliding window size.
            # When we reach the sliding window size, we will set the swa_uuid_for_lock.
            # caller needs to pass the swa_uuid_for_lock to dec_lock_ref
            if swa_lock_size < self.sliding_window_size:
                assert (
                    not node.swa_tombstone
                ), f"inc_lock_swa on swa_tombstone node, {node.id=}"
                if node.swa_lock_ref == 0:
                    self.swa_evictable_size_ -= len(node.value)
                    self.swa_protected_size_ += len(node.value)
                node.swa_lock_ref += 1
                swa_lock_size += len(node.value)
                if swa_lock_size >= self.sliding_window_size:
                    if node.swa_uuid is None:
                        node.swa_uuid = gen_swa_uuid()
                    swa_uuid_for_lock = node.swa_uuid
            node = node.parent
        return swa_uuid_for_lock

    def dec_lock_ref(self, node: TreeNode, swa_uuid_for_lock: Optional[int] = None):
        """
        Decrement the lock reference count for the node.
        It unlocks the full_lock_ref for nodes between the [last node, root), exclusive.
        It unlocks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive.
        If swa_uuid_for_lock is None, it unlocks to the root, exclusive.
        """
        if self.disable:
            return

        dec_lock_swa = True
        while node != self.root_node:
            assert (
                node.full_lock_ref > 0
            ), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
            if node.full_lock_ref == 1:
                self.full_evictable_size_ += len(node.value)
                self.full_protected_size_ -= len(node.value)
            node.full_lock_ref -= 1

            if dec_lock_swa:
                assert (
                    not node.swa_tombstone
                ), f"dec_lock_ref on swa_tombstone node, {node.id=}"
                assert (
                    node.swa_lock_ref > 0
                ), f"dec_lock_ref on node with {node.swa_lock_ref=}, {node.id=}"

                if node.swa_lock_ref == 1:
                    self.swa_evictable_size_ += len(node.value)
                    self.swa_protected_size_ -= len(node.value)
                node.swa_lock_ref -= 1
                if swa_uuid_for_lock and node.swa_uuid == swa_uuid_for_lock:
                    dec_lock_swa = False

            node = node.parent

    def sanity_check(self):
        self.full_lru_list.sanity_check(self)
        self.swa_lru_list.sanity_check(self)

    def evictable_size(self) -> Tuple[int, int]:
        # Note: use full_evictable_size() and swa_evictable_size() instead.
        raise NotImplementedError

    def full_evictable_size(self) -> int:
        return self.full_evictable_size_

    def swa_evictable_size(self) -> int:
        return self.swa_evictable_size_

    # Note: this is expensive, only use for debug
    def full_lru_list_evictable_size(self) -> int:
        return self.full_lru_list.sanity_check_evictable_size()

    # Note: this is expensive, only use for debug
    def swa_lru_list_evictable_size(self) -> int:
        return self.swa_lru_list.sanity_check_evictable_size()

    def protected_size(self) -> Tuple[int, int]:
        # Note: use full_protected_size() and swa_protected_size() instead.
        raise NotImplementedError

    def full_protected_size(self) -> int:
        # protected size refers to the size of the full cache that is locked
        return self.full_protected_size_

    def swa_protected_size(self) -> int:
        # protected size refers to the size of the swa cache that is locked
        return self.swa_protected_size_

    def all_values_flatten(self) -> torch.Tensor:
        values = []

        def _dfs_helper(node: TreeNode):
            for _, child in node.children.items():
                values.append(child.value)
                _dfs_helper(child)

        _dfs_helper(self.root_node)
        return torch.cat(values)

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]:
        """
        SWA prefix matching helper. It factors in the sliding window size such that
        the matched node is guaranteed to either 1. connected to root without swa tombstone,
        or 2. the number of matching tokens from the matched node to the last swa tombstone
        node is greater than or equal to the sliding window size.
        """
        node = self.root_node
        child_key = self.get_child_key_fn(key)

        value = []
        # for path connected to root without tombstone, always match, so set to inf
        match_len_since_tombstone = float("inf")
        best_value_len = 0
        best_last_node = node
        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]

            # update best_value_len and best_last_node if needed
            if (
                child.swa_tombstone
                and match_len_since_tombstone >= self.sliding_window_size
            ):
                best_value_len = len(value)
                best_last_node = node
                match_len_since_tombstone = 0

            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                if not new_node.swa_tombstone:
                    match_len_since_tombstone += len(new_node.value)
                node = new_node
                break
            else:
                value.append(child.value)
                if not child.swa_tombstone:
                    match_len_since_tombstone += len(child.value)
                node = child
                key = key[prefix_len:]

                if len(key):
                    child_key = self.get_child_key_fn(key)

        # handle best_value_len and best_last_node, for the case that last node is fully matched
        if match_len_since_tombstone >= self.sliding_window_size:
            best_value_len = len(value)
            best_last_node = node

        # update time for matched nodes, and make nodes closer to root to be least recently used
        # this allows swa to evict nodes closer to root first
        self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
        self.swa_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)

        # This last_access_time is for sanity check, can be deleted after validation in production
        cur_time = time.monotonic()
        while node:
            node.last_access_time = cur_time
            cur_time -= 0.0001
            node = node.parent

        return value[:best_value_len], best_last_node

    def _split_node(self, key: List[int], child: TreeNode, split_len: int) -> TreeNode:
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.swa_tombstone = child.swa_tombstone
        new_node.full_lock_ref = child.full_lock_ref
        new_node.swa_lock_ref = child.swa_lock_ref
        new_node.key = child.key[:split_len]
        new_node.value = child.value[:split_len]
        # parent inherits the swa_uuid from child for swa lock ref
        new_node.swa_uuid = child.swa_uuid
        child.swa_uuid = None
        # child time should be later than parent's time for swa tombstone
        child.last_access_time = time.monotonic()

        # remove the child from the lru lists because it is being split
        self.full_lru_list.remove_node(child)
        if not new_node.swa_tombstone:
            self.swa_lru_list.remove_node(child)
        child.parent = new_node
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node

        # insert the new node and child into the lru lists, insert
        # parent first so that parent is after child in the lru list
        self.full_lru_list.insert_mru(new_node)
        self.full_lru_list.insert_mru(child)
        if not new_node.swa_tombstone:
            self.swa_lru_list.insert_mru(new_node)
            self.swa_lru_list.insert_mru(child)
        return new_node

    def _insert_helper(
        self, node: TreeNode, key: List, value, update_kv_after_len: int
    ) -> int:
        # Update the last access time from root to leaf, so that
        # swa will tombstone the node closer to root first
        node.last_access_time = time.monotonic()
        if node != self.root_node:
            self.full_lru_list.reset_node_mru(node)
            if not node.swa_tombstone:
                self.swa_lru_list.reset_node_mru(node)
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        total_prefix_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            self.full_lru_list.reset_node_mru(node)
            if not node.swa_tombstone:
                self.swa_lru_list.reset_node_mru(node)
            prefix_len = self.key_match_fn(node.key, key)

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            # if tombstone after update_kv_after_len, update node.value to be the input value.
            # This is needed because it is possible that the last sliding window size tokens
            # contains tombstone. If this is the case and we don't update the kv value, then
            # the prefill prefix matching will stuck.
            if update_kv_after_len < total_prefix_length + prefix_len:
                first_diff_idx = max(0, update_kv_after_len - total_prefix_length)
                if node.swa_tombstone:
                    assert (
                        node.swa_lock_ref == 0
                    ), f"tombstone swa_lock_ref should always be 0, {node.full_lock_ref=}, {node.swa_lock_ref=}, {node.id=}"
                    self.token_to_kv_pool_allocator.free(node.value[first_diff_idx:])
                    node.value = value[:prefix_len]
                    node.swa_tombstone = False

                    # insert the node into the lru lists
                    self.swa_lru_list.insert_mru(node)

                    self.swa_evictable_size_ += len(node.value)
                else:
                    self.token_to_kv_pool_allocator.free(
                        value[first_diff_idx:prefix_len]
                    )

            total_prefix_length += prefix_len
            key = key[prefix_len:]
            value = value[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            self.full_lru_list.insert_mru(new_node)
            self.swa_lru_list.insert_mru(new_node)
            node.children[child_key] = new_node
            self.full_evictable_size_ += len(value)
            self.swa_evictable_size_ += len(value)
        return total_prefix_length

    def _iteratively_delete_tombstone_leaf(
        self, node: TreeNode
    ) -> Tuple[TreeNode, int]:
        full_num_evicted = 0
        while node.parent.swa_tombstone and len(node.parent.children) == 0:
            # root node is not evictable
            if node.parent == self.root_node:
                break
            # if locked, means node is in use, skip
            if node.parent.full_lock_ref > 0:
                break
            assert (
                node.parent.swa_lock_ref == 0
            ), f"tombstone swa_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.swa_lock_ref=}, {node.parent.id=}"
            # delete tombstone node evicts full tokens
            self.token_to_kv_pool_allocator.free(node.parent.value)
            full_num_evicted += len(node.parent.value)
            self.full_lru_list.remove_node(node.parent)
            self._delete_tombstone_leaf(node.parent)
            node = node.parent

        return node, full_num_evicted

    def _delete_leaf(self, node: TreeNode) -> None:
        assert (
            not node.swa_tombstone
        ), f"Invariant violated: leaf node is a tombstone, {node.id=}"
        assert len(node.children) == 0, f"leaf node has children, {node.id=}"
        for k, v in node.parent.children.items():
            if v == node:
                break
        del node.parent.children[k]
        self.full_evictable_size_ -= len(node.key)
        self.swa_evictable_size_ -= len(node.key)

    def _tombstone_internal_node(self, node: TreeNode) -> None:
        assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}"
        node.swa_tombstone = True
        self.swa_evictable_size_ -= len(node.key)

    def _delete_tombstone_leaf(self, node: TreeNode) -> None:
        assert (
            node.swa_tombstone
        ), f"Deleting a unexpected non-tombstone leaf node, {node.id=}"
        assert len(node.children) == 0, f"leaf node has children, {node.id=}"
        for k, v in node.parent.children.items():
            if v == node:
                break
        del node.parent.children[k]
        self.full_evictable_size_ -= len(node.key)

    def _collect_leaves(self) -> List[TreeNode]:
        ret_list = []
        stack = [self.root_node]

        while stack:
            cur_node = stack.pop()
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            else:
                stack.extend(cur_node.children.values())

        return ret_list

    def _collect_nontombstone_nodes(self) -> List[TreeNode]:
        ret_list = []
        stack = [self.root_node]

        while stack:
            cur_node = stack.pop()
            if not cur_node.swa_tombstone:
                ret_list.append(cur_node)
            stack.extend(cur_node.children.values())

        return ret_list

    def _collect_all_nodes(self) -> List[TreeNode]:
        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            ret_list.append(cur_node)
            stack.extend(cur_node.children.values())
        return ret_list

    def _print_helper(self, node: TreeNode, indent: int) -> None:
        """Prints the radix tree in a human-readable format."""
        stack = [(node, indent)]
        while stack:
            current_node, current_indent = stack.pop()
            print(
                " " * current_indent,
                current_node.id,
                len(current_node.key),
                f"fr={current_node.full_lock_ref}",
                f"sr={current_node.swa_lock_ref}",
                f"fll={self.full_lru_list.in_list(current_node)}",
                f"sll={self.swa_lru_list.in_list(current_node)}",
                f"ts={current_node.swa_tombstone}",
            )
            for key, child in current_node.children.items():
                stack.append((child, current_indent + 2))

                assert key == self.get_child_key_fn(
                    child.key
                ), f"{key=}, {self.get_child_key_fn(child.key)=}"

    def _total_size_helper(self) -> Tuple[int, int]:
        total_size = 0
        total_swa_size = 0
        stack = [self.root_node]
        while stack:
            current_node = stack.pop()
            total_size += len(current_node.value)
            if not current_node.swa_tombstone:
                total_swa_size += len(current_node.value)
            for child in current_node.children.values():
                if child.evicted:
                    continue
                stack.append(child)
        return total_size, total_swa_size
```