sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,421 @@
|
|
1
|
+
"""Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes."""
|
2
|
+
|
3
|
+
import heapq
|
4
|
+
import time
|
5
|
+
from collections import defaultdict
|
6
|
+
from typing import TYPE_CHECKING, Any, List, Optional
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
11
|
+
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
12
|
+
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from sglang.srt.managers.schedule_batch import Req
|
16
|
+
else:
|
17
|
+
Req = Any # Placeholder for Req type when not type checking
|
18
|
+
|
19
|
+
|
20
|
+
class LoRAKey:
    """Composite radix-tree key: a LoRA adaptor id plus a token-id sequence."""

    def __init__(self, lora_id: str, token_ids: List[int]):
        # Identifier of the LoRA adaptor; expected to be the hash value of
        # the adaptor path.
        self.lora_id = lora_id
        # The token ids this key spans.
        self.token_ids = token_ids

    def __len__(self):
        # Key length is defined as the number of tokens, ignoring lora_id.
        return len(self.token_ids)
|
30
|
+
|
31
|
+
|
32
|
+
def get_child_key(key: "LoRAKey"):
    """Compute the child-dict key for *key*.

    Children are indexed by hash(lora_id + str(first token)), so a lookup
    only matches when both the adaptor id and the leading token id agree.
    A missing (None) lora_id hashes the first token alone.
    """
    first_token = str(key.token_ids[0])
    if key.lora_id is None:
        return hash(first_token)
    return hash(key.lora_id + first_token)
|
39
|
+
|
40
|
+
|
41
|
+
class LoRATreeNode:
    """A node of the LoRA radix tree; orderable by last access time for LRU eviction."""

    # Monotonically increasing id generator shared by all instances.
    counter = 0

    def __init__(self, id: Optional[int] = None):
        # Children keyed by get_child_key(...) hashes.
        self.children = defaultdict(LoRATreeNode)
        self.parent: "LoRATreeNode" = None
        self.key: "LoRAKey" = None
        # KV-cache indices held by this node; None means the node is evicted.
        self.value: Optional[torch.Tensor] = None
        # Number of in-flight requests pinning this node against eviction.
        self.lock_ref = 0
        self.last_access_time = time.monotonic()

        # The class counter advances even when an explicit id is supplied.
        if id is None:
            self.id = LoRATreeNode.counter
        else:
            self.id = id
        LoRATreeNode.counter += 1

    @property
    def evicted(self):
        # A node whose device value was freed is considered evicted.
        return self.value is None

    def __lt__(self, other: "LoRATreeNode"):
        # Heap ordering: least recently accessed node first.
        return self.last_access_time < other.last_access_time
|
62
|
+
|
63
|
+
|
64
|
+
def _key_match(key0: LoRAKey, key1: LoRAKey):
|
65
|
+
if key0.lora_id != key1.lora_id:
|
66
|
+
raise ValueError(
|
67
|
+
f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}"
|
68
|
+
)
|
69
|
+
i = 0
|
70
|
+
for k0, k1 in zip(key0.token_ids, key1.token_ids):
|
71
|
+
if k0 != k1:
|
72
|
+
break
|
73
|
+
i += 1
|
74
|
+
return i
|
75
|
+
|
76
|
+
|
77
|
+
class LoRARadixCache(BasePrefixCache):
    """RadixCache variant whose keys carry a LoRA adaptor id.

    Prefixes are shared only between requests that use the same LoRA
    adaptor: every tree key is a (lora_id, token_ids) pair and child
    lookups hash both parts (see ``get_child_key``). Only page_size == 1
    is supported.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
    ):
        """Create the cache.

        Args:
            req_to_token_pool: Pool mapping request slots to per-token KV indices.
            token_to_kv_pool_allocator: Allocator used to free KV slots; required.
            page_size: Must be 1; larger pages are not implemented here.
            disable: When True, caching degenerates to freeing resources directly.

        Raises:
            ValueError: If ``page_size > 1`` or the allocator is None.
        """
        if page_size > 1:
            raise ValueError("LoRARadixCache currently only supports page_size = 1")

        if token_to_kv_pool_allocator is None:
            raise ValueError(
                "token_to_kv_pool_allocator is required to run LoraRadixCache"
            )

        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        self.disable = disable
        self.device = self.token_to_kv_pool_allocator.device

        self.key_match_fn = _key_match
        self.get_child_key_fn = get_child_key
        self.reset()

    def reset(self):
        """Drop the whole tree and zero the size accounting."""
        self.root_node = LoRATreeNode()
        self.root_node.key = LoRAKey(lora_id="", token_ids=[])
        # The root holds no KV indices; helpers must tolerate value=None.
        self.root_node.value = None
        self.evictable_size_ = 0  # tokens held by unlocked nodes
        self.protected_size_ = 0  # tokens held by locked (in-use) nodes

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        """Unsupported: matching requires a lora_id as well; always raises."""
        raise ValueError(
            "LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead."
        )

    def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult:
        """Find the matching prefix from the lora radix tree.

        Args:
            key: A LoRAKey to find a matching prefix.
        Returns:
            A MatchResult holding a tensor of matching prefix KV indices and
            the last node that contains the prefix values. Note that
            this API can modify the internal state of the radix tree:
            the last node may be split (creating a new child) if the
            match ends inside a node's key.
        """
        # Disabled cache or empty key: report an empty match at the root.
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=torch.empty(
                    (0,),
                    dtype=torch.int64,
                    device=self.device,
                ),
                last_device_node=self.root_node,
                last_host_node=self.root_node,
            )

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = torch.empty((0,), dtype=torch.int64, device=self.device)
        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_node,
        )

    def insert(self, key: LoRAKey, value=None):
        """Insert ``key`` (with KV indices ``value``) and return the length
        of the prefix that was already present in the tree."""
        if self.disable:
            return 0

        # Default value mirrors the token ids (used by tests / bookkeeping).
        if value is None:
            value = [x for x in key.token_ids]
        return self._insert_helper(self.root_node, key, value)

    def cache_finished_req(self, req: Req):
        """Cache request when it finishes."""
        if self.disable:
            # No caching: free the KV slots and the request slot directly.
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
            ]
            self.token_to_kv_pool_allocator.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return

        # The last output token has no KV entry yet, hence [:-1].
        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        # page_size == 1, so no page alignment is needed beyond the copy.
        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

        # Radix Cache takes one ref in memory pool
        lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len])
        new_prefix_len = self.insert(lora_key, page_aligned_kv_indices)
        # KV slots now covered by an existing tree prefix are duplicates; free them.
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # Remove req slot release the cache lock
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node)

    def cache_unfinished_req(self, req: Req):
        """Cache request when it is unfinished (e.g. chunked prefill)."""
        if self.disable:
            return

        token_ids = req.fill_ids
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
        page_aligned_token_ids = token_ids[:page_aligned_len]

        # Radix Cache takes one ref in memory pool
        inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids)
        new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices)
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # The prefix indices could be updated, reuse it
        # NOTE(review): assumes MatchResult unpacks into 4 fields
        # (device_indices, last_device_node, last_host_node, host_hit_length)
        # — confirm against base_prefix_cache.MatchResult.
        new_indices, new_last_node, _, _ = self.match_prefix_with_lora_id(inserted_key)
        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],
        )

        # Move the lock from the old last node to the new one.
        self.dec_lock_ref(req.last_node)
        self.inc_lock_ref(new_last_node)

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
        """Print the tree structure and the total token count."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Total number of tokens stored in the tree (evicted nodes excluded)."""
        return self._total_size_helper()

    def evict(self, num_tokens: int):
        """Evict at least ``num_tokens`` tokens from unlocked leaves, LRU first."""
        if self.disable:
            return

        leaves = self._collect_leaves()
        heapq.heapify(leaves)  # LoRATreeNode.__lt__ orders by last access time

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x == self.root_node:
                break
            if x.lock_ref > 0:
                # Pinned by an in-flight request; skip (not re-pushed).
                continue

            self.token_to_kv_pool_allocator.free(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)

            # The parent may have become a leaf; make it evictable too.
            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_lock_ref(self, node: LoRATreeNode):
        """Pin ``node`` and all its ancestors; returns the (negative) change
        in evictable size."""
        if self.disable:
            return 0

        delta = 0
        while node != self.root_node:
            if node.lock_ref == 0:
                # First lock moves the node from evictable to protected.
                self.evictable_size_ -= len(node.value)
                self.protected_size_ += len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(self, node: LoRATreeNode):
        """Unpin ``node`` and all its ancestors; returns the (positive) change
        in evictable size."""
        if self.disable:
            return 0

        delta = 0
        while node != self.root_node:
            if node.lock_ref == 1:
                # Last unlock moves the node back from protected to evictable.
                self.evictable_size_ += len(node.value)
                self.protected_size_ -= len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        """Number of tokens that could currently be evicted."""
        return self.evictable_size_

    def protected_size(self):
        # protected size refers to the size of the cache that is locked
        return self.protected_size_

    def all_values_flatten(self):
        """Concatenate the KV-index tensors of every node into one tensor."""
        values = []

        def _dfs_helper(node: LoRATreeNode):
            for _, child in node.children.items():
                values.append(child.value)
                _dfs_helper(child)

        _dfs_helper(self.root_node)
        return torch.cat(values)

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey):
        """Walk the tree collecting matched value tensors; may split a node
        when the match ends inside its key. Returns (values, last_node)."""
        node.last_access_time = time.monotonic()

        child_key = self.get_child_key_fn(key)

        value = []
        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                # Partial match: split the child so the matched part becomes
                # its own node, and stop there.
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                node = new_node
                break
            else:
                # Full match of this child's key; descend with the remainder.
                value.append(child.value)
                node = child
                key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int):
        # new_node -> child
        new_node = LoRATreeNode()
        key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len])
        key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:])
        new_node.children = {self.get_child_key_fn(key_split_2): child}
        new_node.parent = child.parent
        # The new intermediate node inherits the child's lock count.
        new_node.lock_ref = child.lock_ref
        new_node.key = key_split_1
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = key_split_2
        child.value = child.value[split_len:]
        # Re-link the parent to the new node under the original child key.
        new_node.parent.children[self.get_child_key_fn(key)] = new_node

        return new_node

    def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value):
        """Insert ``key``/``value`` below ``node``; returns the length of the
        prefix that already existed (whose KV slots are duplicates)."""
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        total_prefix_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            total_prefix_length += prefix_len
            key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
            value = value[prefix_len:]

            if prefix_len < len(node.key):
                # Partial overlap: split so the shared prefix is one node.
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        # Whatever remains of the key becomes a brand-new leaf.
        if len(key):
            new_node = LoRATreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)
        return total_prefix_length

    def _print_helper(self, node: LoRATreeNode, indent: int):
        """Prints the radix tree in a human-readable format."""
        stack = [(node, indent)]
        while stack:
            current_node, current_indent = stack.pop()
            print(
                " " * current_indent,
                len(current_node.key),
                current_node.key.token_ids[:10],
                f"r={current_node.lock_ref}",
            )
            for key, child in current_node.children.items():
                stack.append((child, current_indent + 2))

                # Invariant: every child is filed under the hash of its own key.
                assert key == self.get_child_key_fn(
                    child.key
                ), f"{key=}, {self.get_child_key_fn(child.key)=}"

    def _delete_leaf(self, node):
        """Unlink ``node`` from its parent and shrink the evictable size."""
        for k, v in node.parent.children.items():
            if v == node:
                break
        del node.parent.children[k]
        self.evictable_size_ -= len(node.key)

    def _total_size_helper(self):
        """Sum the value lengths of all non-evicted nodes, iteratively."""
        total_size = 0
        stack = [self.root_node]
        while stack:
            current_node = stack.pop()
            # Fix: the root's value is None (see reset), so an unguarded
            # len(current_node.value) raised TypeError, crashing total_size()
            # and pretty_print().
            if current_node.value is not None:
                total_size += len(current_node.value)
            for child in current_node.children.values():
                if child.evicted:
                    continue
                stack.append(child)
        return total_size

    def _collect_leaves(self):
        """Return every leaf node (including the root when the tree is empty)."""
        ret_list = []
        stack = [self.root_node]

        while stack:
            cur_node = stack.pop()
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            else:
                stack.extend(cur_node.children.values())

        return ret_list
|
@@ -358,6 +358,7 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
358
358
|
dst_v=device_pool.v_buffer[layer_id],
|
359
359
|
src_indices=host_indices,
|
360
360
|
dst_indices=device_indices,
|
361
|
+
layer_id=layer_id,
|
361
362
|
item_size=self.token_stride_size,
|
362
363
|
src_layout_dim=self.layout_dim,
|
363
364
|
)
|
@@ -471,27 +472,26 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
471
472
|
* self.dtype.itemsize
|
472
473
|
)
|
473
474
|
for index in range(0, len(indices), self.page_size):
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
ptr_list.append(k_ptr)
|
489
|
-
ptr_list.append(v_ptr)
|
490
|
-
key_ = keys[index // self.page_size]
|
491
|
-
key_list.append(f"{key_}_{layer_id}_k")
|
492
|
-
key_list.append(f"{key_}_{layer_id}_v")
|
475
|
+
k_ptr = (
|
476
|
+
kv_buffer_data_ptr
|
477
|
+
+ indices[index]
|
478
|
+
* self.layer_num
|
479
|
+
* self.head_num
|
480
|
+
* self.head_dim
|
481
|
+
* self.dtype.itemsize
|
482
|
+
)
|
483
|
+
v_ptr = k_ptr + v_offset
|
484
|
+
ptr_list.append(k_ptr)
|
485
|
+
ptr_list.append(v_ptr)
|
486
|
+
key_ = keys[index // self.page_size]
|
487
|
+
key_list.append(f"{key_}_k")
|
488
|
+
key_list.append(f"{key_}_v")
|
493
489
|
element_size = (
|
494
|
-
self.
|
490
|
+
self.layer_num
|
491
|
+
* self.dtype.itemsize
|
492
|
+
* self.page_size
|
493
|
+
* self.head_num
|
494
|
+
* self.head_dim
|
495
495
|
)
|
496
496
|
element_size_list = [element_size] * len(key_list)
|
497
497
|
return key_list, ptr_list, element_size_list
|
@@ -585,6 +585,7 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
585
585
|
dst=device_pool.kv_buffer[layer_id],
|
586
586
|
src_indices=host_indices,
|
587
587
|
dst_indices=device_indices,
|
588
|
+
layer_id=layer_id,
|
588
589
|
item_size=self.token_stride_size,
|
589
590
|
src_layout_dim=self.layout_dim,
|
590
591
|
)
|
@@ -685,22 +686,19 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
685
686
|
key_list = []
|
686
687
|
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
687
688
|
for index in range(0, len(indices), self.page_size):
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
)
|
699
|
-
ptr_list.append(k_ptr)
|
700
|
-
key_ = keys[index // self.page_size]
|
701
|
-
key_list.append(f"{key_}_{layer_id}_k")
|
689
|
+
k_ptr = (
|
690
|
+
kv_buffer_data_ptr
|
691
|
+
+ indices[index]
|
692
|
+
* self.layer_num
|
693
|
+
* (self.kv_lora_rank + self.qk_rope_head_dim)
|
694
|
+
* self.dtype.itemsize
|
695
|
+
)
|
696
|
+
ptr_list.append(k_ptr)
|
697
|
+
key_ = keys[index // self.page_size]
|
698
|
+
key_list.append(f"{key_}_k")
|
702
699
|
element_size = (
|
703
|
-
self.
|
700
|
+
self.layer_num
|
701
|
+
* self.dtype.itemsize
|
704
702
|
* self.page_size
|
705
703
|
* (self.kv_lora_rank + self.qk_rope_head_dim)
|
706
704
|
)
|
@@ -62,6 +62,7 @@ class TreeNode:
|
|
62
62
|
self.host_value: Optional[torch.Tensor] = None
|
63
63
|
# store hash values of each pages
|
64
64
|
self.hash_value: Optional[List[str]] = None
|
65
|
+
self.backuped_storage = False
|
65
66
|
|
66
67
|
self.id = TreeNode.counter if id is None else id
|
67
68
|
TreeNode.counter += 1
|
@@ -74,10 +75,6 @@ class TreeNode:
|
|
74
75
|
def backuped(self):
|
75
76
|
return self.host_value is not None
|
76
77
|
|
77
|
-
@property
|
78
|
-
def backuped_storage(self):
|
79
|
-
return self.hash_value is not None and len(self.hash_value) > 0
|
80
|
-
|
81
78
|
def protect_host(self):
|
82
79
|
"""Protect the host value from eviction."""
|
83
80
|
self.host_ref_counter += 1
|
@@ -498,7 +495,7 @@ class RadixCache(BasePrefixCache):
|
|
498
495
|
# One BlockStored per ``page_size`` chunk.
|
499
496
|
if self.enable_kv_cache_events:
|
500
497
|
# First chunk links to the last page of the parent node (if any).
|
501
|
-
if node.parent is None:
|
498
|
+
if node.parent is None or node != self.root_node:
|
502
499
|
parent_block_hash = None
|
503
500
|
else:
|
504
501
|
last_page_start = (
|