sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +95 -49
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +72 -8
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +33 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +212 -117
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +258 -782
- sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
- sglang/srt/managers/tokenizer_manager.py +7 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +63 -34
- sglang/srt/mem_cache/memory_pool.py +112 -46
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/metrics/collector.py +8 -0
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +12 -8
- sglang/srt/model_executor/model_runner.py +153 -134
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +25 -19
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +37 -15
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +19 -20
- sglang/srt/speculative/build_eagle_tree.py +6 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
- sglang/srt/speculative/eagle_utils.py +2 -1
- sglang/srt/speculative/eagle_worker.py +109 -38
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,283 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2025 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""
|
17
|
+
Page-aligned memory pool.
|
18
|
+
"""
|
19
|
+
|
20
|
+
import torch
|
21
|
+
import triton
|
22
|
+
import triton.language as tl
|
23
|
+
|
24
|
+
from sglang.srt.mem_cache.memory_pool import KVCache
|
25
|
+
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
26
|
+
|
27
|
+
|
28
|
+
# Per-request KV slot assignment for an extend batch, one program per request
# (the host launches it with grid `(bs,)`). Program `pid` writes the slots for
# its tokens `pre_len .. seq_len - 1` into `out_indices`, starting right after
# the tokens of requests `0 .. pid - 1`. Slots are taken, in order, from:
#   1. the unused tail of the request's current partially filled page,
#      continuing from `last_loc + 1`;
#   2. whole new pages read from `free_page_ptr`;
#   3. the head of one more new page for a trailing partial page.
# New pages are consumed from `free_page_ptr` in request order. The last
# program stores `(total_new_pages << 32) | total_extend_tokens` into
# `ret_values` so the host can read both totals with a single sync.
# NOTE(review): `free_page_ptr` is read without a bounds check; the host only
# compares the page count against the free list after the kernel has run, so
# a failing allocation may read past the end of the free list (the output is
# then discarded by the host).
@triton.jit
def alloc_extend_kernel(
    pre_lens_ptr,
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
    max_num_extend_tokens: tl.constexpr,
):
    pid = tl.program_id(0)

    # Prefix reductions over requests 0..pid. Masked-out lanes are loaded as 0:
    # without `other`, their contents are undefined, and two independent
    # undefined loads are not guaranteed to cancel in `seq_lens - pre_lens`,
    # which would corrupt every sum below.
    load_offset = tl.arange(0, bs_upper)
    load_mask = load_offset <= pid
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_mask, other=0)
    pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_mask, other=0)
    extend_lens = seq_lens - pre_lens

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = tl.load(pre_lens_ptr + pid)
    extend_len = seq_len - pre_len

    # Inclusive prefix sum minus this request's own length gives the offset
    # of its first output slot in `out_indices`.
    sum_extend_lens = tl.sum(extend_lens)
    output_start_loc = sum_extend_lens - extend_len

    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # New pages needed by this request, and where its share of them starts
    # in `free_page_ptr`.
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value: the last program's prefix sums cover the whole batch.
    if pid == tl.num_programs(0) - 1:
        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
            tl.int64
        )
        tl.store(ret_values, merged_value)

    # Part 1: fill the old partial page
    last_loc = tl.load(last_loc_ptr + pid)
    num_part1 = (
        min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
    )
    offset_one_page = tl.arange(0, page_size)
    tl.store(
        out_indices + output_start_loc + offset_one_page,
        last_loc + 1 + offset_one_page,
        mask=offset_one_page < num_part1,
    )
    if pre_len + num_part1 == seq_len:
        return

    # Part 2: fill the new full pages
    num_part2 = (
        seq_len // page_size * page_size
        - (pre_len + page_size - 1) // page_size * page_size
    )

    offset_many_page = tl.arange(0, max_num_extend_tokens)
    page_start = tl.load(
        free_page_ptr + new_page_start_loc + offset_many_page // page_size,
        mask=offset_many_page < num_part2,
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + offset_many_page,
        page_start * page_size + offset_many_page % page_size,
        mask=offset_many_page < num_part2,
    )
    if pre_len + num_part1 + num_part2 == seq_len:
        return

    # Part 3: fill the new partial page (the last of this request's new pages)
    num_part3 = seq_len - seq_len // page_size * page_size
    start_loc = tl.load(
        free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
        start_loc * page_size + offset_one_page,
        mask=offset_one_page < num_part3,
    )
|
114
|
+
|
115
|
+
|
116
|
+
# Per-request KV slot assignment for a decode step, one program per request
# (the host launches it with grid `(bs,)`). Each request gains exactly one
# token. The token either goes into the slot after `last_loc` (room is left in
# the current page) or into the first slot of a new page taken from
# `free_page_ptr`. New pages are consumed in request order. The last program
# stores the total number of new pages into `ret_values`.
# NOTE(review): as in the extend path, `free_page_ptr` is read before the host
# checks that enough free pages exist.
@triton.jit
def alloc_decode_kernel(
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    pid = tl.program_id(0)

    # Prefix reduction over requests 0..pid. Masked-out lanes are loaded as 0
    # instead of being left undefined.
    load_offset = tl.arange(0, bs_upper)
    load_mask = load_offset <= pid
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_mask, other=0)
    # Before this step each active request held one token fewer; masked lanes
    # get before == after, so they contribute no new pages.
    pre_lens = tl.where(load_mask, seq_lens - 1, seq_lens)

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = seq_len - 1

    # A request needs a new page exactly when its new token starts a page.
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # 0 or 1 for this request, and the index of its new page (if any) in
    # `free_page_ptr`.
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value: the last program's prefix sum covers the whole batch.
    if pid == tl.num_programs(0) - 1:
        tl.store(ret_values, sum_num_new_pages)

    if num_page_start_loc_self == 0:
        # Room left in the current page: use the slot after the last one.
        last_loc = tl.load(last_loc_ptr + pid)
        tl.store(out_indices + pid, last_loc + 1)
    else:
        # The new token opens a fresh page and takes its first slot.
        page = tl.load(free_page_ptr + new_page_start_loc)
        tl.store(out_indices + pid, page * page_size)
|
155
|
+
|
156
|
+
|
157
|
+
class PagedTokenToKVPoolAllocator:
    """
    An allocator managing the indices to kv cache data.

    This class has the same interface as `TokenToKVPoolAllocator` but the output
    of one request is always page-aligned.

    Free space is tracked as a 1-D int64 tensor of free page ids
    (`self.free_pages`). The slot index of offset `j` in page `p` is
    `p * page_size + j`, and page 0 is never handed out (see `clear`).

    TODO: fuse last_loc into the kernel.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = page_size
        # Any remainder of `size` that does not fill a whole page is unused.
        self.num_pages = size // page_size

        self.free_pages = None
        # `free` releases pages immediately unless a free group is open.
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()
        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")

        self._kvcache = kvcache
        # Scalar device buffer the allocation kernels write their totals into.
        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)

    def available_size(self):
        """Return the number of free token slots (free pages * page_size)."""
        return len(self.free_pages) * self.page_size

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate slots for the new tokens of a batch of requests.

        Args:
            prefix_lens: per-request number of tokens that already have slots.
            seq_lens: per-request total number of tokens after this extend.
            last_loc: per-request slot index of the last already-allocated
                token; its in-page offset must agree with `prefix_lens`
                (asserted when SGLANG_DEBUG_MEMORY_POOL is set).
            extend_num_tokens: total number of new tokens,
                `sum(seq_lens - prefix_lens)`; the length of the result.

        Returns:
            An int64 tensor of `extend_num_tokens` slot indices with the
            requests laid out back to back, or `None` if there are not
            enough free pages (no pages are consumed in that case).
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        bs = len(prefix_lens)
        if bs == 0:
            # With an empty grid the kernel does not run, so `ret_values`
            # would still hold the previous call's page count and that many
            # pages would be dropped from the free list.
            return torch.empty((0,), dtype=torch.int64, device=self.device)

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int64, device=self.device
        )
        alloc_extend_kernel[(bs,)](
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
            next_power_of_2(extend_num_tokens),
        )

        # High 32 bits: new pages needed; low 32 bits: total extend tokens.
        merged_value = self.ret_values.item()
        num_new_pages = merged_value >> 32
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for a decode step.

        Args:
            seq_lens: per-request total number of tokens including the new one.
            last_loc: per-request slot index of the previous token.

        Returns:
            An int64 tensor with one slot index per request, or `None` if
            there are not enough free pages (no pages are consumed then).
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        bs = len(seq_lens)
        if bs == 0:
            # Same stale-`ret_values` hazard as in `alloc_extend`.
            return torch.empty((0,), dtype=torch.int64, device=self.device)

        out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
        alloc_decode_kernel[(bs,)](
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
        )

        num_new_pages = self.ret_values.item()
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def free(self, free_index: torch.Tensor):
        """Return every page touched by `free_index` to the free list.

        Pages are freed whole. Freed pages are placed at the front of the free
        list, so they are handed out again first. Between `free_group_begin`
        and `free_group_end` the indices are only buffered.
        """
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            free_page_indices = torch.unique(free_index // self.page_size)
            self.free_pages = torch.cat((free_page_indices, self.free_pages))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self):
        """Start buffering `free` calls until `free_group_end`."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Stop buffering and release everything freed since `free_group_begin`."""
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))

    def clear(self):
        """Mark every page as free and discard any open free group."""
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_pages = torch.arange(
            1, self.num_pages + 1, dtype=torch.int64, device=self.device
        )
        # `free` checks `is_not_in_free_group`; reset that flag (not a
        # differently named one) so a clear() issued inside a free group does
        # not leave later frees buffered forever.
        self.is_not_in_free_group = True
        self.free_group = []
|
@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache.
|
|
22
22
|
import heapq
|
23
23
|
import time
|
24
24
|
from collections import defaultdict
|
25
|
-
from
|
25
|
+
from functools import partial
|
26
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple
|
26
27
|
|
27
28
|
import torch
|
28
29
|
|
@@ -67,7 +68,7 @@ class TreeNode:
|
|
67
68
|
return self.last_access_time < other.last_access_time
|
68
69
|
|
69
70
|
|
70
|
-
def
|
71
|
+
def _key_match_page_size1(key0: List, key1: List):
|
71
72
|
i = 0
|
72
73
|
for k0, k1 in zip(key0, key1):
|
73
74
|
if k0 != k1:
|
@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List):
|
|
76
77
|
return i
|
77
78
|
|
78
79
|
|
80
|
+
def _key_match_paged(key0: List, key1: List, page_size: int):
|
81
|
+
min_len = min(len(key0), len(key1))
|
82
|
+
|
83
|
+
i = 0
|
84
|
+
while i < min_len:
|
85
|
+
if key0[i : i + page_size] != key1[i : i + page_size]:
|
86
|
+
break
|
87
|
+
i += page_size
|
88
|
+
|
89
|
+
return i
|
90
|
+
|
91
|
+
|
79
92
|
class RadixCache(BasePrefixCache):
|
80
93
|
def __init__(
    self,
    req_to_token_pool: ReqToTokenPool,
    token_to_kv_pool_allocator: TokenToKVPoolAllocator,
    page_size: int,
    disable: bool = False,
):
    """Create an empty radix tree over token ids.

    Args:
        req_to_token_pool: stored as `self.req_to_token_pool`.
        token_to_kv_pool_allocator: stored as is; when truthy, its `device`
            becomes `self.device`, otherwise the CPU device is used.
        page_size: tokens per page. With 1, a child is keyed by its first
            token; otherwise by a tuple of its first `page_size` tokens, and
            keys are matched page by page.
        disable: stored as `self.disable`.
    """
    self.req_to_token_pool = req_to_token_pool
    self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
    self.page_size = page_size
    self.disable = disable

    allocator = self.token_to_kv_pool_allocator
    self.device = allocator.device if allocator else torch.device("cpu")

    if page_size != 1:
        # Children are keyed by their first full page of tokens.
        self.key_match_fn = partial(_key_match_paged, page_size=page_size)
        self.get_child_key_fn = lambda key: tuple(key[:page_size])
    else:
        # One token per page: a child is keyed by its first token.
        self.key_match_fn = _key_match_page_size1
        self.get_child_key_fn = lambda key: key[0]

    self.reset()
|
90
117
|
|
91
118
|
##### Public API #####
|
@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache):
|
|
109
136
|
The last node create a new child if the prefix is shorter
|
110
137
|
than the last node's value.
|
111
138
|
"""
|
112
|
-
if self.disable:
|
113
|
-
return
|
139
|
+
if self.disable or len(key) == 0:
|
140
|
+
return (
|
141
|
+
torch.empty(
|
142
|
+
(0,),
|
143
|
+
dtype=torch.int32,
|
144
|
+
device=self.device,
|
145
|
+
),
|
146
|
+
self.root_node,
|
147
|
+
)
|
148
|
+
|
149
|
+
if self.page_size != 1:
|
150
|
+
page_aligned_len = len(key) // self.page_size * self.page_size
|
151
|
+
key = key[:page_aligned_len]
|
114
152
|
|
115
153
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
116
154
|
if value:
|
117
155
|
value = torch.concat(value)
|
118
156
|
else:
|
119
|
-
value = torch.
|
157
|
+
value = torch.empty((0,), dtype=torch.int32, device=self.device)
|
120
158
|
return value, last_node
|
121
159
|
|
122
160
|
def insert(self, key: List, value=None):
|
@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache):
|
|
127
165
|
value = [x for x in key]
|
128
166
|
return self._insert_helper(self.root_node, key, value)
|
129
167
|
|
130
|
-
def cache_finished_req(self, req: Req
|
168
|
+
def cache_finished_req(self, req: Req):
|
131
169
|
"""Cache request when it finishes."""
|
132
170
|
if self.disable:
|
133
|
-
if token_ids is None:
|
134
|
-
token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
|
135
|
-
else:
|
136
|
-
token_ids_len = len(token_ids)
|
137
|
-
|
138
171
|
kv_indices = self.req_to_token_pool.req_to_token[
|
139
|
-
req.req_pool_idx, :
|
172
|
+
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
|
140
173
|
]
|
141
174
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
142
175
|
self.req_to_token_pool.free(req.req_pool_idx)
|
143
176
|
return
|
144
177
|
|
145
|
-
|
146
|
-
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
178
|
+
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
147
179
|
kv_indices = self.req_to_token_pool.req_to_token[
|
148
180
|
req.req_pool_idx, : len(token_ids)
|
149
181
|
]
|
150
182
|
|
183
|
+
if self.page_size != 1:
|
184
|
+
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
185
|
+
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
186
|
+
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
187
|
+
else:
|
188
|
+
page_aligned_len = len(kv_indices)
|
189
|
+
page_aligned_kv_indices = kv_indices.clone()
|
190
|
+
|
151
191
|
# Radix Cache takes one ref in memory pool
|
152
|
-
new_prefix_len = self.insert(
|
192
|
+
new_prefix_len = self.insert(
|
193
|
+
token_ids[:page_aligned_len], page_aligned_kv_indices
|
194
|
+
)
|
153
195
|
self.token_to_kv_pool_allocator.free(
|
154
196
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
155
197
|
)
|
@@ -158,27 +200,32 @@ class RadixCache(BasePrefixCache):
|
|
158
200
|
self.req_to_token_pool.free(req.req_pool_idx)
|
159
201
|
self.dec_lock_ref(req.last_node)
|
160
202
|
|
161
|
-
def cache_unfinished_req(self, req: Req
|
203
|
+
def cache_unfinished_req(self, req: Req):
|
162
204
|
"""Cache request when it is unfinished."""
|
163
205
|
if self.disable:
|
164
206
|
return
|
165
207
|
|
166
|
-
|
167
|
-
token_ids = req.fill_ids
|
168
|
-
|
208
|
+
token_ids = req.fill_ids
|
169
209
|
kv_indices = self.req_to_token_pool.req_to_token[
|
170
210
|
req.req_pool_idx, : len(token_ids)
|
171
211
|
]
|
172
212
|
|
213
|
+
if self.page_size != 1:
|
214
|
+
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
215
|
+
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
216
|
+
else:
|
217
|
+
page_aligned_len = len(kv_indices)
|
218
|
+
page_aligned_kv_indices = kv_indices.clone()
|
219
|
+
page_aligned_token_ids = token_ids[:page_aligned_len]
|
220
|
+
|
173
221
|
# Radix Cache takes one ref in memory pool
|
174
|
-
new_prefix_len = self.insert(
|
222
|
+
new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
|
175
223
|
self.token_to_kv_pool_allocator.free(
|
176
224
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
177
225
|
)
|
178
226
|
|
179
227
|
# The prefix indices could be updated, reuse it
|
180
|
-
new_indices, new_last_node = self.match_prefix(
|
181
|
-
assert len(new_indices) == len(token_ids)
|
228
|
+
new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
|
182
229
|
self.req_to_token_pool.write(
|
183
230
|
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
184
231
|
new_indices[len(req.prefix_indices) :],
|
@@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache):
|
|
186
233
|
|
187
234
|
self.dec_lock_ref(req.last_node)
|
188
235
|
self.inc_lock_ref(new_last_node)
|
189
|
-
|
236
|
+
|
237
|
+
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
238
|
+
if self.page_size != 1:
|
239
|
+
req.prefix_indices = torch.cat(
|
240
|
+
[new_indices, kv_indices[len(new_indices) :]]
|
241
|
+
)
|
242
|
+
else:
|
243
|
+
req.prefix_indices = new_indices
|
190
244
|
req.last_node = new_last_node
|
191
245
|
|
192
246
|
def pretty_print(self):
|
@@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache):
|
|
196
250
|
def total_size(self):
    """Return the result of `self._total_size_helper()` (the metric is defined there)."""
    size_helper = self._total_size_helper
    return size_helper()
|
198
252
|
|
199
|
-
def evict(self, num_tokens: int
|
253
|
+
def evict(self, num_tokens: int):
|
200
254
|
if self.disable:
|
201
255
|
return
|
202
256
|
|
@@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache):
|
|
212
266
|
if x.lock_ref > 0:
|
213
267
|
continue
|
214
268
|
|
215
|
-
|
269
|
+
self.token_to_kv_pool_allocator.free(x.value)
|
216
270
|
num_evicted += len(x.value)
|
217
271
|
self._delete_leaf(x)
|
218
272
|
|
@@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache):
|
|
254
308
|
# protected size refers to the size of the cache that is locked
|
255
309
|
return self.protected_size_
|
256
310
|
|
311
|
+
def all_values_flatten(self):
    """Concatenate the `value` tensors of every node below the root.

    Nodes are visited depth-first in pre-order: a node before its children,
    and siblings in `children` dict order. The traversal uses an explicit
    stack, so deep trees cannot hit Python's recursion limit. A tree with no
    nodes yields an empty tensor instead of raising from `torch.concat([])`.
    """
    values = []
    # Push children reversed so the first child is popped first, matching
    # a recursive pre-order walk.
    stack = list(reversed(list(self.root_node.children.values())))
    while stack:
        node = stack.pop()
        values.append(node.value)
        stack.extend(reversed(list(node.children.values())))

    if not values:
        return torch.empty((0,), dtype=torch.int32, device=self.device)
    return torch.concat(values)
|
321
|
+
|
257
322
|
##### Internal Helper Functions #####
|
258
323
|
|
259
324
|
def _match_prefix_helper(self, node: TreeNode, key: List):
|
260
325
|
node.last_access_time = time.time()
|
326
|
+
|
327
|
+
child_key = self.get_child_key_fn(key)
|
328
|
+
|
261
329
|
value = []
|
262
|
-
while len(key) > 0 and
|
263
|
-
child = node.children[
|
330
|
+
while len(key) > 0 and child_key in node.children.keys():
|
331
|
+
child = node.children[child_key]
|
264
332
|
child.last_access_time = time.time()
|
265
|
-
prefix_len =
|
333
|
+
prefix_len = self.key_match_fn(child.key, key)
|
266
334
|
if prefix_len < len(child.key):
|
267
335
|
new_node = self._split_node(child.key, child, prefix_len)
|
268
336
|
value.append(new_node.value)
|
@@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache):
|
|
272
340
|
value.append(child.value)
|
273
341
|
node = child
|
274
342
|
key = key[prefix_len:]
|
343
|
+
|
344
|
+
if len(key):
|
345
|
+
child_key = self.get_child_key_fn(key)
|
346
|
+
|
275
347
|
return value, node
|
276
348
|
|
277
349
|
def _split_node(self, key, child: TreeNode, split_len: int):
|
278
350
|
# new_node -> child
|
279
351
|
new_node = TreeNode()
|
280
|
-
new_node.children = {key[split_len]: child}
|
352
|
+
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
281
353
|
new_node.parent = child.parent
|
282
354
|
new_node.lock_ref = child.lock_ref
|
283
355
|
new_node.key = child.key[:split_len]
|
@@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache):
|
|
285
357
|
child.parent = new_node
|
286
358
|
child.key = child.key[split_len:]
|
287
359
|
child.value = child.value[split_len:]
|
288
|
-
new_node.parent.children[key
|
360
|
+
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
289
361
|
return new_node
|
290
362
|
|
291
363
|
def _insert_helper(self, node: TreeNode, key: List, value):
|
@@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache):
|
|
293
365
|
if len(key) == 0:
|
294
366
|
return 0
|
295
367
|
|
368
|
+
child_key = self.get_child_key_fn(key)
|
369
|
+
|
296
370
|
total_prefix_length = 0
|
297
|
-
while len(key) > 0 and
|
298
|
-
node = node.children[
|
371
|
+
while len(key) > 0 and child_key in node.children.keys():
|
372
|
+
node = node.children[child_key]
|
299
373
|
node.last_access_time = time.time()
|
300
|
-
prefix_len =
|
374
|
+
prefix_len = self.key_match_fn(node.key, key)
|
301
375
|
total_prefix_length += prefix_len
|
302
376
|
key = key[prefix_len:]
|
303
377
|
value = value[prefix_len:]
|
@@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache):
|
|
306
380
|
new_node = self._split_node(node.key, node, prefix_len)
|
307
381
|
node = new_node
|
308
382
|
|
383
|
+
if len(key):
|
384
|
+
child_key = self.get_child_key_fn(key)
|
385
|
+
|
309
386
|
if len(key):
|
310
387
|
new_node = TreeNode()
|
311
388
|
new_node.parent = node
|
312
389
|
new_node.key = key
|
313
390
|
new_node.value = value
|
314
|
-
node.children[
|
391
|
+
node.children[child_key] = new_node
|
315
392
|
self.evictable_size_ += len(value)
|
316
393
|
return total_prefix_length
|
317
394
|
|
@@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache):
|
|
326
403
|
current_node.key[:10],
|
327
404
|
f"r={current_node.lock_ref}",
|
328
405
|
)
|
329
|
-
for
|
406
|
+
for key, child in current_node.children.items():
|
330
407
|
stack.append((child, current_indent + 2))
|
331
408
|
|
409
|
+
assert key == self.get_child_key_fn(
|
410
|
+
child.key
|
411
|
+
), f"{key=}, {self.get_child_key_fn(child.key)=}"
|
412
|
+
|
332
413
|
def _delete_leaf(self, node):
|
333
414
|
for k, v in node.parent.children.items():
|
334
415
|
if v == node:
|
@@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache):
|
|
363
444
|
|
364
445
|
|
365
446
|
if __name__ == "__main__":
|
366
|
-
tree = RadixCache(None, None, False)
|
447
|
+
tree = RadixCache(None, None, page_size=1, disable=False)
|
367
448
|
|
368
449
|
tree.insert("Hello")
|
369
450
|
tree.insert("Hello")
|
sglang/srt/metrics/collector.py
CHANGED
@@ -121,6 +121,12 @@ class TokenizerMetricsCollector:
|
|
121
121
|
labelnames=labels.keys(),
|
122
122
|
)
|
123
123
|
|
124
|
+
self.cached_tokens_total = Counter(
|
125
|
+
name="sglang:cached_tokens_total",
|
126
|
+
documentation="Number of cached prompt tokens.",
|
127
|
+
labelnames=labels.keys(),
|
128
|
+
)
|
129
|
+
|
124
130
|
self.num_requests_total = Counter(
|
125
131
|
name="sglang:num_requests_total",
|
126
132
|
documentation="Number of requests processed.",
|
@@ -245,10 +251,12 @@ class TokenizerMetricsCollector:
|
|
245
251
|
self,
|
246
252
|
prompt_tokens: int,
|
247
253
|
generation_tokens: int,
|
254
|
+
cached_tokens: int,
|
248
255
|
e2e_latency: float,
|
249
256
|
):
|
250
257
|
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
|
251
258
|
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
|
259
|
+
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
|
252
260
|
self.num_requests_total.labels(**self.labels).inc(1)
|
253
261
|
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
|
254
262
|
if generation_tokens >= 1:
|
@@ -35,7 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
35
35
|
)
|
36
36
|
from sglang.srt.utils import is_hip
|
37
37
|
|
38
|
-
|
38
|
+
_is_hip = is_hip()
|
39
39
|
|
40
40
|
if TYPE_CHECKING:
|
41
41
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
@@ -119,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|
119
119
|
else:
|
120
120
|
capture_bs = list(range(1, 33))
|
121
121
|
|
122
|
-
if
|
122
|
+
if _is_hip:
|
123
123
|
capture_bs += [i * 8 for i in range(21, 33)]
|
124
124
|
|
125
125
|
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
@@ -264,11 +264,15 @@ class CudaGraphRunner:
|
|
264
264
|
def model_capture_mode(self):
    """Temporarily set `capture_mode = True` on the model and the KV pool.

    Generator body of a context manager (presumably wrapped with
    `contextlib.contextmanager` on the preceding, not visible, line). The
    flag is set on `self.model_runner.model` and
    `self.model_runner.token_to_kv_pool` only where the object already has a
    `capture_mode` attribute. Control then yields, and the flag is reset to
    `False`. The reset is in a `finally`, so an exception raised while
    capturing does not leave the flags set.
    """
    if hasattr(self.model_runner.model, "capture_mode"):
        self.model_runner.model.capture_mode = True
    if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
        self.model_runner.token_to_kv_pool.capture_mode = True

    try:
        yield
    finally:
        if hasattr(self.model_runner.model, "capture_mode"):
            self.model_runner.model.capture_mode = False
        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
            self.model_runner.token_to_kv_pool.capture_mode = False
|
272
276
|
|
273
277
|
def can_run(self, forward_batch: ForwardBatch):
|
274
278
|
if self.enable_dp_attention:
|
@@ -300,10 +304,11 @@ class CudaGraphRunner:
|
|
300
304
|
def capture(self):
|
301
305
|
with graph_capture() as graph_capture_context:
|
302
306
|
self.stream = graph_capture_context.stream
|
307
|
+
# Reverse the order to enable better memory sharing across cuda graphs.
|
303
308
|
capture_range = (
|
304
|
-
tqdm.tqdm(self.capture_bs)
|
309
|
+
tqdm.tqdm(list(reversed(self.capture_bs)))
|
305
310
|
if get_tensor_model_parallel_rank() == 0
|
306
|
-
else self.capture_bs
|
311
|
+
else reversed(self.capture_bs)
|
307
312
|
)
|
308
313
|
for bs in capture_range:
|
309
314
|
with patch_model(
|
@@ -396,16 +401,10 @@ class CudaGraphRunner:
|
|
396
401
|
|
397
402
|
run_once()
|
398
403
|
|
399
|
-
torch.cuda.synchronize()
|
400
|
-
self.model_runner.tp_group.barrier()
|
401
|
-
|
402
404
|
global global_graph_memory_pool
|
403
405
|
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
|
404
406
|
out = run_once()
|
405
407
|
|
406
|
-
torch.cuda.synchronize()
|
407
|
-
self.model_runner.tp_group.barrier()
|
408
|
-
|
409
408
|
global_graph_memory_pool = graph.pool()
|
410
409
|
return graph, out
|
411
410
|
|
@@ -427,7 +426,7 @@ class CudaGraphRunner:
|
|
427
426
|
self.capture_hidden_mode = hidden_mode_from_spec_info
|
428
427
|
self.capture()
|
429
428
|
|
430
|
-
def replay(self, forward_batch: ForwardBatch):
|
429
|
+
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
|
431
430
|
self.recapture_if_needed(forward_batch)
|
432
431
|
|
433
432
|
raw_bs = forward_batch.batch_size
|