sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,283 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2025 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""
|
17
|
+
Page-aligned memory pool.
|
18
|
+
"""
|
19
|
+
|
20
|
+
import torch
|
21
|
+
import triton
|
22
|
+
import triton.language as tl
|
23
|
+
|
24
|
+
from sglang.srt.mem_cache.memory_pool import KVCache
|
25
|
+
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
26
|
+
|
27
|
+
|
28
|
+
# One Triton program per request. Computes the KV-cache slot index for every
# token a request appends during extend (prefill continuation), keeping each
# request's allocation page-aligned.
#
# Outputs:
#   out_indices - one slot index per newly extended token, packed back-to-back
#                 in request order across the whole batch
#   ret_values  - a single int64 written by the last program:
#                 (total newly allocated pages) << 32 | (total extended tokens)
@triton.jit
def alloc_extend_kernel(
    pre_lens_ptr,  # per-request prefix (already cached) length
    seq_lens_ptr,  # per-request total sequence length after the extend
    last_loc_ptr,  # per-request last occupied slot of its partial page
    free_page_ptr,  # free page ids to draw new pages from
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,  # power-of-2 upper bound on batch size
    page_size: tl.constexpr,
    max_num_extend_tokens: tl.constexpr,  # power-of-2 bound on total extend tokens
):
    pid = tl.program_id(0)

    # Masked load of requests 0..pid (masked-out lanes read 0) so that tl.sum
    # below yields an inclusive prefix sum without a separate scan pass.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
    extend_lens = seq_lens - pre_lens

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = tl.load(pre_lens_ptr + pid)
    extend_len = seq_len - pre_len

    # Inclusive sum minus own contribution = this request's exclusive offset
    # into the packed out_indices buffer.
    sum_extend_lens = tl.sum(extend_lens)
    output_start_loc = sum_extend_lens - extend_len

    # Pages needed per request (ceil-div before/after the extend).
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # Same exclusive-prefix trick for this request's offset into free_page_ptr.
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value: only the last program (which sees the full batch in its
    # masked loads) publishes the packed totals.
    if pid == tl.num_programs(0) - 1:
        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
            tl.int64
        )
        tl.store(ret_values, merged_value)

    # Part 1: fill the old partial page, continuing right after last_loc.
    last_loc = tl.load(last_loc_ptr + pid)
    num_part1 = (
        min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
    )
    offset_one_page = tl.arange(0, page_size)
    tl.store(
        out_indices + output_start_loc + offset_one_page,
        last_loc + 1 + offset_one_page,
        mask=offset_one_page < num_part1,
    )
    if pre_len + num_part1 == seq_len:
        return

    # Part 2: fill the new full pages. Each output slot is
    # page_id * page_size + offset-within-page.
    num_part2 = (
        seq_len // page_size * page_size
        - (pre_len + page_size - 1) // page_size * page_size
    )

    offset_many_page = tl.arange(0, max_num_extend_tokens)
    page_start = tl.load(
        free_page_ptr + new_page_start_loc + offset_many_page // page_size,
        mask=offset_many_page < num_part2,
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + offset_many_page,
        page_start * page_size + offset_many_page % page_size,
        mask=offset_many_page < num_part2,
    )
    if pre_len + num_part1 + num_part2 == seq_len:
        return

    # Part 3: fill the leading slots of the final (new, partial) page.
    num_part3 = seq_len - seq_len // page_size * page_size
    start_loc = tl.load(
        free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
        start_loc * page_size + offset_one_page,
        mask=offset_one_page < num_part3,
    )
|
114
|
+
|
115
|
+
|
116
|
+
# One Triton program per request. Decode appends exactly one token per request
# (pre_len = seq_len - 1), so each request needs either zero new pages or one.
#
# Outputs:
#   out_indices - one slot index per request
#   ret_values  - total number of newly allocated pages (written by the last
#                 program, which sees the whole batch in its masked loads)
@triton.jit
def alloc_decode_kernel(
    seq_lens_ptr,  # per-request sequence length after the decode step
    last_loc_ptr,  # per-request last occupied slot of its partial page
    free_page_ptr,  # free page ids to draw new pages from
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,  # power-of-2 upper bound on batch size
    page_size: tl.constexpr,
):
    pid = tl.program_id(0)

    # Masked load of requests 0..pid; masked-out lanes read 0 and the tl.where
    # keeps pre_lens == seq_lens there, so they contribute 0 new pages below.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = seq_len - 1

    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # 1 when this request's new token starts a fresh page, else 0; inclusive
    # sum minus own contribution gives the exclusive offset into free_page_ptr.
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value
    if pid == tl.num_programs(0) - 1:
        tl.store(ret_values, sum_num_new_pages)

    if num_page_start_loc_self == 0:
        # Token fits in the current partial page: next slot after last_loc.
        last_loc = tl.load(last_loc_ptr + pid)
        tl.store(out_indices + pid, last_loc + 1)
    else:
        # Token starts a new page: first slot of the freshly taken page.
        page = tl.load(free_page_ptr + new_page_start_loc)
        tl.store(out_indices + pid, page * page_size)
|
155
|
+
|
156
|
+
|
157
|
+
class PagedTokenToKVPoolAllocator:
    """
    An allocator managing the indices to kv cache data.

    This class has the same interface as `TokenToKVPoolAllocator` but the output
    of one request is always page-aligned.

    TODO: fuse last_loc into the kernel.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = page_size
        self.num_pages = size // page_size

        # int64 tensor of free page ids; populated by clear().
        self.free_pages = None
        # True when free() should return pages immediately; False while a
        # free_group_begin()/free_group_end() batch is being collected.
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()
        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")

        self._kvcache = kvcache
        # 0-dim scratch tensor that the allocation kernels write results into.
        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)

    def available_size(self):
        """Number of free token slots (whole free pages times page size)."""
        return len(self.free_pages) * self.page_size

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate page-aligned slots for the extend tokens of a batch.

        Returns a flat int64 tensor with one slot index per new token, or
        None if there are not enough free pages.
        """
        if self.debug_mode:
            # A request's next slot must sit at the page offset implied by
            # its prefix length.
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        bs = len(prefix_lens)
        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int64, device=self.device
        )
        alloc_extend_kernel[(bs,)](
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
            next_power_of_2(extend_num_tokens),
        )

        # The kernel packs (num_new_pages << 32 | num_extend_tokens).
        merged_value = self.ret_values.item()
        num_new_pages = merged_value >> 32
        if num_new_pages > len(self.free_pages):
            # Not enough free pages; the caller must treat this as failed.
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one page-aligned slot per request for a decode step.

        Returns an int64 tensor of shape (batch,), or None if there are not
        enough free pages.
        """
        if self.debug_mode:
            # Decode appends one token, so last_loc corresponds to seq_len - 1.
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        bs = len(seq_lens)
        out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
        alloc_decode_kernel[(bs,)](
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
        )

        num_new_pages = self.ret_values.item()
        if num_new_pages > len(self.free_pages):
            # Not enough free pages; the caller must treat this as failed.
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def free(self, free_index: torch.Tensor):
        """Return token slots to the pool (their pages go back to the free list).

        Inside a free group the indices are only buffered; they are actually
        released by free_group_end().
        """
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            free_page_indices = torch.unique(free_index // self.page_size)
            self.free_pages = torch.cat((free_page_indices, self.free_pages))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self):
        """Start buffering free() calls so they can be released in one batch."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Release every index buffered since free_group_begin()."""
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))

    def clear(self):
        """Reset the allocator: all pages free, no pending free group."""
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_pages = torch.arange(
            1, self.num_pages + 1, dtype=torch.int64, device=self.device
        )
        # Bug fix: the original assigned `self.is_in_free_group = False`, an
        # attribute no other method reads. Reset the flag that free() actually
        # checks so a clear() always leaves grouping disabled.
        self.is_not_in_free_group = True
        self.free_group = []
@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache.
|
|
22
22
|
import heapq
|
23
23
|
import time
|
24
24
|
from collections import defaultdict
|
25
|
-
from
|
25
|
+
from functools import partial
|
26
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple
|
26
27
|
|
27
28
|
import torch
|
28
29
|
|
@@ -67,7 +68,7 @@ class TreeNode:
|
|
67
68
|
return self.last_access_time < other.last_access_time
|
68
69
|
|
69
70
|
|
70
|
-
def
|
71
|
+
def _key_match_page_size1(key0: List, key1: List):
|
71
72
|
i = 0
|
72
73
|
for k0, k1 in zip(key0, key1):
|
73
74
|
if k0 != k1:
|
@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List):
|
|
76
77
|
return i
|
77
78
|
|
78
79
|
|
80
|
+
def _key_match_paged(key0: List, key1: List, page_size: int):
|
81
|
+
min_len = min(len(key0), len(key1))
|
82
|
+
|
83
|
+
i = 0
|
84
|
+
while i < min_len:
|
85
|
+
if key0[i : i + page_size] != key1[i : i + page_size]:
|
86
|
+
break
|
87
|
+
i += page_size
|
88
|
+
|
89
|
+
return i
|
90
|
+
|
91
|
+
|
79
92
|
class RadixCache(BasePrefixCache):
|
80
93
|
def __init__(
|
81
94
|
self,
|
82
95
|
req_to_token_pool: ReqToTokenPool,
|
83
96
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
97
|
+
page_size: int,
|
84
98
|
disable: bool = False,
|
85
99
|
):
|
86
100
|
self.req_to_token_pool = req_to_token_pool
|
87
101
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
102
|
+
self.page_size = page_size
|
88
103
|
self.disable = disable
|
104
|
+
|
105
|
+
if self.token_to_kv_pool_allocator:
|
106
|
+
self.device = self.token_to_kv_pool_allocator.device
|
107
|
+
else:
|
108
|
+
self.device = torch.device("cpu")
|
109
|
+
|
110
|
+
if self.page_size == 1:
|
111
|
+
self.key_match_fn = _key_match_page_size1
|
112
|
+
self.get_child_key_fn = lambda key: key[0]
|
113
|
+
else:
|
114
|
+
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
115
|
+
self.get_child_key_fn = lambda key: tuple(key[:page_size])
|
89
116
|
self.reset()
|
90
117
|
|
91
118
|
##### Public API #####
|
@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache):
|
|
109
136
|
The last node create a new child if the prefix is shorter
|
110
137
|
than the last node's value.
|
111
138
|
"""
|
112
|
-
if self.disable:
|
113
|
-
return
|
139
|
+
if self.disable or len(key) == 0:
|
140
|
+
return (
|
141
|
+
torch.empty(
|
142
|
+
(0,),
|
143
|
+
dtype=torch.int32,
|
144
|
+
device=self.device,
|
145
|
+
),
|
146
|
+
self.root_node,
|
147
|
+
)
|
148
|
+
|
149
|
+
if self.page_size != 1:
|
150
|
+
page_aligned_len = len(key) // self.page_size * self.page_size
|
151
|
+
key = key[:page_aligned_len]
|
114
152
|
|
115
153
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
116
154
|
if value:
|
117
155
|
value = torch.concat(value)
|
118
156
|
else:
|
119
|
-
value = torch.
|
157
|
+
value = torch.empty((0,), dtype=torch.int32, device=self.device)
|
120
158
|
return value, last_node
|
121
159
|
|
122
160
|
def insert(self, key: List, value=None):
|
@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache):
|
|
127
165
|
value = [x for x in key]
|
128
166
|
return self._insert_helper(self.root_node, key, value)
|
129
167
|
|
130
|
-
def cache_finished_req(self, req: Req
|
168
|
+
def cache_finished_req(self, req: Req):
|
131
169
|
"""Cache request when it finishes."""
|
132
170
|
if self.disable:
|
133
|
-
if token_ids is None:
|
134
|
-
token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
|
135
|
-
else:
|
136
|
-
token_ids_len = len(token_ids)
|
137
|
-
|
138
171
|
kv_indices = self.req_to_token_pool.req_to_token[
|
139
|
-
req.req_pool_idx, :
|
172
|
+
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
|
140
173
|
]
|
141
174
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
142
175
|
self.req_to_token_pool.free(req.req_pool_idx)
|
143
176
|
return
|
144
177
|
|
145
|
-
|
146
|
-
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
178
|
+
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
147
179
|
kv_indices = self.req_to_token_pool.req_to_token[
|
148
180
|
req.req_pool_idx, : len(token_ids)
|
149
181
|
]
|
150
182
|
|
183
|
+
if self.page_size != 1:
|
184
|
+
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
185
|
+
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
186
|
+
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
187
|
+
else:
|
188
|
+
page_aligned_len = len(kv_indices)
|
189
|
+
page_aligned_kv_indices = kv_indices.clone()
|
190
|
+
|
151
191
|
# Radix Cache takes one ref in memory pool
|
152
|
-
new_prefix_len = self.insert(
|
192
|
+
new_prefix_len = self.insert(
|
193
|
+
token_ids[:page_aligned_len], page_aligned_kv_indices
|
194
|
+
)
|
153
195
|
self.token_to_kv_pool_allocator.free(
|
154
196
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
155
197
|
)
|
@@ -158,27 +200,32 @@ class RadixCache(BasePrefixCache):
|
|
158
200
|
self.req_to_token_pool.free(req.req_pool_idx)
|
159
201
|
self.dec_lock_ref(req.last_node)
|
160
202
|
|
161
|
-
def cache_unfinished_req(self, req: Req
|
203
|
+
def cache_unfinished_req(self, req: Req):
|
162
204
|
"""Cache request when it is unfinished."""
|
163
205
|
if self.disable:
|
164
206
|
return
|
165
207
|
|
166
|
-
|
167
|
-
token_ids = req.fill_ids
|
168
|
-
|
208
|
+
token_ids = req.fill_ids
|
169
209
|
kv_indices = self.req_to_token_pool.req_to_token[
|
170
210
|
req.req_pool_idx, : len(token_ids)
|
171
211
|
]
|
172
212
|
|
213
|
+
if self.page_size != 1:
|
214
|
+
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
215
|
+
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
216
|
+
else:
|
217
|
+
page_aligned_len = len(kv_indices)
|
218
|
+
page_aligned_kv_indices = kv_indices.clone()
|
219
|
+
page_aligned_token_ids = token_ids[:page_aligned_len]
|
220
|
+
|
173
221
|
# Radix Cache takes one ref in memory pool
|
174
|
-
new_prefix_len = self.insert(
|
222
|
+
new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
|
175
223
|
self.token_to_kv_pool_allocator.free(
|
176
224
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
177
225
|
)
|
178
226
|
|
179
227
|
# The prefix indices could be updated, reuse it
|
180
|
-
new_indices, new_last_node = self.match_prefix(
|
181
|
-
assert len(new_indices) == len(token_ids)
|
228
|
+
new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
|
182
229
|
self.req_to_token_pool.write(
|
183
230
|
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
184
231
|
new_indices[len(req.prefix_indices) :],
|
@@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache):
|
|
186
233
|
|
187
234
|
self.dec_lock_ref(req.last_node)
|
188
235
|
self.inc_lock_ref(new_last_node)
|
189
|
-
|
236
|
+
|
237
|
+
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
238
|
+
if self.page_size != 1:
|
239
|
+
req.prefix_indices = torch.cat(
|
240
|
+
[new_indices, kv_indices[len(new_indices) :]]
|
241
|
+
)
|
242
|
+
else:
|
243
|
+
req.prefix_indices = new_indices
|
190
244
|
req.last_node = new_last_node
|
191
245
|
|
192
246
|
def pretty_print(self):
|
@@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache):
|
|
196
250
|
def total_size(self):
|
197
251
|
return self._total_size_helper()
|
198
252
|
|
199
|
-
def evict(self, num_tokens: int
|
253
|
+
def evict(self, num_tokens: int):
|
200
254
|
if self.disable:
|
201
255
|
return
|
202
256
|
|
@@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache):
|
|
212
266
|
if x.lock_ref > 0:
|
213
267
|
continue
|
214
268
|
|
215
|
-
|
269
|
+
self.token_to_kv_pool_allocator.free(x.value)
|
216
270
|
num_evicted += len(x.value)
|
217
271
|
self._delete_leaf(x)
|
218
272
|
|
@@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache):
|
|
254
308
|
# protected size refers to the size of the cache that is locked
|
255
309
|
return self.protected_size_
|
256
310
|
|
311
|
+
def all_values_flatten(self):
|
312
|
+
values = []
|
313
|
+
|
314
|
+
def _dfs_helper(node: TreeNode):
|
315
|
+
for _, child in node.children.items():
|
316
|
+
values.append(child.value)
|
317
|
+
_dfs_helper(child)
|
318
|
+
|
319
|
+
_dfs_helper(self.root_node)
|
320
|
+
return torch.concat(values)
|
321
|
+
|
257
322
|
##### Internal Helper Functions #####
|
258
323
|
|
259
324
|
def _match_prefix_helper(self, node: TreeNode, key: List):
|
260
325
|
node.last_access_time = time.time()
|
326
|
+
|
327
|
+
child_key = self.get_child_key_fn(key)
|
328
|
+
|
261
329
|
value = []
|
262
|
-
while len(key) > 0 and
|
263
|
-
child = node.children[
|
330
|
+
while len(key) > 0 and child_key in node.children.keys():
|
331
|
+
child = node.children[child_key]
|
264
332
|
child.last_access_time = time.time()
|
265
|
-
prefix_len =
|
333
|
+
prefix_len = self.key_match_fn(child.key, key)
|
266
334
|
if prefix_len < len(child.key):
|
267
335
|
new_node = self._split_node(child.key, child, prefix_len)
|
268
336
|
value.append(new_node.value)
|
@@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache):
|
|
272
340
|
value.append(child.value)
|
273
341
|
node = child
|
274
342
|
key = key[prefix_len:]
|
343
|
+
|
344
|
+
if len(key):
|
345
|
+
child_key = self.get_child_key_fn(key)
|
346
|
+
|
275
347
|
return value, node
|
276
348
|
|
277
349
|
def _split_node(self, key, child: TreeNode, split_len: int):
|
278
350
|
# new_node -> child
|
279
351
|
new_node = TreeNode()
|
280
|
-
new_node.children = {key[split_len]: child}
|
352
|
+
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
281
353
|
new_node.parent = child.parent
|
282
354
|
new_node.lock_ref = child.lock_ref
|
283
355
|
new_node.key = child.key[:split_len]
|
@@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache):
|
|
285
357
|
child.parent = new_node
|
286
358
|
child.key = child.key[split_len:]
|
287
359
|
child.value = child.value[split_len:]
|
288
|
-
new_node.parent.children[key
|
360
|
+
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
289
361
|
return new_node
|
290
362
|
|
291
363
|
def _insert_helper(self, node: TreeNode, key: List, value):
|
@@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache):
|
|
293
365
|
if len(key) == 0:
|
294
366
|
return 0
|
295
367
|
|
368
|
+
child_key = self.get_child_key_fn(key)
|
369
|
+
|
296
370
|
total_prefix_length = 0
|
297
|
-
while len(key) > 0 and
|
298
|
-
node = node.children[
|
371
|
+
while len(key) > 0 and child_key in node.children.keys():
|
372
|
+
node = node.children[child_key]
|
299
373
|
node.last_access_time = time.time()
|
300
|
-
prefix_len =
|
374
|
+
prefix_len = self.key_match_fn(node.key, key)
|
301
375
|
total_prefix_length += prefix_len
|
302
376
|
key = key[prefix_len:]
|
303
377
|
value = value[prefix_len:]
|
@@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache):
|
|
306
380
|
new_node = self._split_node(node.key, node, prefix_len)
|
307
381
|
node = new_node
|
308
382
|
|
383
|
+
if len(key):
|
384
|
+
child_key = self.get_child_key_fn(key)
|
385
|
+
|
309
386
|
if len(key):
|
310
387
|
new_node = TreeNode()
|
311
388
|
new_node.parent = node
|
312
389
|
new_node.key = key
|
313
390
|
new_node.value = value
|
314
|
-
node.children[
|
391
|
+
node.children[child_key] = new_node
|
315
392
|
self.evictable_size_ += len(value)
|
316
393
|
return total_prefix_length
|
317
394
|
|
@@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache):
|
|
326
403
|
current_node.key[:10],
|
327
404
|
f"r={current_node.lock_ref}",
|
328
405
|
)
|
329
|
-
for
|
406
|
+
for key, child in current_node.children.items():
|
330
407
|
stack.append((child, current_indent + 2))
|
331
408
|
|
409
|
+
assert key == self.get_child_key_fn(
|
410
|
+
child.key
|
411
|
+
), f"{key=}, {self.get_child_key_fn(child.key)=}"
|
412
|
+
|
332
413
|
def _delete_leaf(self, node):
|
333
414
|
for k, v in node.parent.children.items():
|
334
415
|
if v == node:
|
@@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache):
|
|
363
444
|
|
364
445
|
|
365
446
|
if __name__ == "__main__":
|
366
|
-
tree = RadixCache(None, None, False)
|
447
|
+
tree = RadixCache(None, None, page_size=1, disable=False)
|
367
448
|
|
368
449
|
tree.insert("Hello")
|
369
450
|
tree.insert("Hello")
|