sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,283 @@
1
+ """
2
+ Copyright 2025 SGLang Team
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ """
15
+
16
+ """
17
+ Page-aligned memory pool.
18
+ """
19
+
20
+ import torch
21
+ import triton
22
+ import triton.language as tl
23
+
24
+ from sglang.srt.mem_cache.memory_pool import KVCache
25
+ from sglang.srt.utils import get_bool_env_var, next_power_of_2
26
+
27
+
28
@triton.jit
def alloc_extend_kernel(
    pre_lens_ptr,
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
    max_num_extend_tokens: tl.constexpr,
):
    """Allocate page-aligned KV slots for the extend (prefill) phase.

    One program per request. For request `pid`, writes the token-slot indices
    for tokens [pre_len, seq_len) into `out_indices`, consuming new pages from
    `free_page_ptr`. The last program stores `(total_new_pages << 32) |
    total_extend_tokens` into `ret_values` so the host can check capacity.
    """
    pid = tl.program_id(0)

    # Load lens of requests 0..pid (inclusive) to compute this program's
    # exclusive prefix sums for output offset and free-page offset.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
    extend_lens = seq_lens - pre_lens

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = tl.load(pre_lens_ptr + pid)
    extend_len = seq_len - pre_len

    # Start of this request's slice in `out_indices` (exclusive prefix sum).
    sum_extend_lens = tl.sum(extend_lens)
    output_start_loc = sum_extend_lens - extend_len

    # New pages needed per request (ceil-div after minus ceil-div before).
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    # Start of this request's slice in the free-page list (exclusive prefix sum).
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value: the last program sees the full prefix sums, so it reports
    # both totals packed into one int64.
    if pid == tl.num_programs(0) - 1:
        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
            tl.int64
        )
        tl.store(ret_values, merged_value)

    # Part 1: fill the old partial page (slots right after `last_loc` up to the
    # page boundary, or up to seq_len if the request ends inside this page).
    last_loc = tl.load(last_loc_ptr + pid)
    num_part1 = (
        min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
    )
    offset_one_page = tl.arange(0, page_size)
    tl.store(
        out_indices + output_start_loc + offset_one_page,
        last_loc + 1 + offset_one_page,
        mask=offset_one_page < num_part1,
    )
    if pre_len + num_part1 == seq_len:
        return

    # Part 2: fill the new full pages taken from the free-page list.
    num_part2 = (
        seq_len // page_size * page_size
        - (pre_len + page_size - 1) // page_size * page_size
    )

    offset_many_page = tl.arange(0, max_num_extend_tokens)
    page_start = tl.load(
        free_page_ptr + new_page_start_loc + offset_many_page // page_size,
        mask=offset_many_page < num_part2,
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + offset_many_page,
        page_start * page_size + offset_many_page % page_size,
        mask=offset_many_page < num_part2,
    )
    if pre_len + num_part1 + num_part2 == seq_len:
        return

    # Part 3: fill the trailing new partial page (the last allocated page,
    # only partially occupied).
    num_part3 = seq_len - seq_len // page_size * page_size
    start_loc = tl.load(
        free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
        start_loc * page_size + offset_one_page,
        mask=offset_one_page < num_part3,
    )
114
+
115
+
116
@triton.jit
def alloc_decode_kernel(
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    """Allocate one KV slot per request for the decode phase.

    One program per request; each request grows by exactly one token
    (pre_len = seq_len - 1). Writes the new token's slot index into
    `out_indices[pid]`, taking a fresh page when the new token crosses a
    page boundary. The last program stores the total number of newly
    consumed pages into `ret_values`.
    """
    pid = tl.program_id(0)

    # Load lens of requests 0..pid (inclusive) for the prefix sum over new pages.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = seq_len - 1

    # New pages needed per request; each entry is 0 or 1 in decode.
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    # Exclusive prefix sum: offset of this request's page in the free list.
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value: last program reports the total pages consumed.
    if pid == tl.num_programs(0) - 1:
        tl.store(ret_values, sum_num_new_pages)

    if num_page_start_loc_self == 0:
        # Still room in the current page: next slot after the last one.
        last_loc = tl.load(last_loc_ptr + pid)
        tl.store(out_indices + pid, last_loc + 1)
    else:
        # Crossed a page boundary: first slot of a freshly taken page.
        page = tl.load(free_page_ptr + new_page_start_loc)
        tl.store(out_indices + pid, page * page_size)
155
+
156
+
157
class PagedTokenToKVPoolAllocator:
    """
    An allocator managing the indices to kv cache data.

    This class has the same interface as `TokenToKVPoolAllocator` but the output
    of one request is always page-aligned.

    TODO: fuse last_loc into the kernel.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: "KVCache",
    ):
        """
        Args:
            size: total number of token slots managed by this allocator.
            page_size: number of token slots per page.
            dtype: dtype of the underlying kv cache entries.
            device: device string the index tensors live on.
            kvcache: the kv cache this allocator hands out indices for.
        """
        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = page_size
        self.num_pages = size // page_size

        self.free_pages = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()
        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")

        self._kvcache = kvcache
        # Scalar scratch tensor the triton kernels write their results into.
        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)

    def available_size(self):
        """Return the number of currently free token slots."""
        return len(self.free_pages) * self.page_size

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate page-aligned slots for the extend (prefill) phase.

        Returns an int64 tensor of `extend_num_tokens` slot indices, or None
        if there are not enough free pages (no state is changed in that case
        beyond the kernel's scratch writes).
        """
        if self.debug_mode:
            # `last_loc` must point at the slot just before the next one in
            # the same page as `prefix_lens` indicates.
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        bs = len(prefix_lens)
        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int64, device=self.device
        )
        alloc_extend_kernel[(bs,)](
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
            next_power_of_2(extend_num_tokens),
        )

        # The kernel packs (num_new_pages << 32) | num_extend_tokens.
        merged_value = self.ret_values.item()
        num_new_pages = merged_value >> 32
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for the decode phase.

        Returns an int64 tensor with one slot index per request, or None if
        there are not enough free pages.
        """
        if self.debug_mode:
            # Each request grows by one token, so the pre-grow position is
            # seq_len - 1, hence the +2 offset check.
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        bs = len(seq_lens)
        out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
        alloc_decode_kernel[(bs,)](
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
        )

        num_new_pages = self.ret_values.item()
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def free(self, free_index: torch.Tensor):
        """Return token slots to the pool (their pages become free again).

        Inside a free group (see `free_group_begin`) the indices are only
        recorded and released in one batch at `free_group_end`.
        """
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            # Map token slots back to their pages and return the pages.
            free_page_indices = torch.unique(free_index // self.page_size)
            self.free_pages = torch.cat((free_page_indices, self.free_pages))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self):
        """Start batching subsequent `free` calls into one release."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Release everything freed since `free_group_begin` in one batch."""
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))

    def clear(self):
        # The padded slot 0 is used for writing dummy outputs from padded tokens,
        # so page 0 is never handed out; free pages are 1..num_pages.
        self.free_pages = torch.arange(
            1, self.num_pages + 1, dtype=torch.int64, device=self.device
        )
        # Bug fix: reset the flag the rest of the class actually reads.
        # The previous code assigned a never-read `is_in_free_group` attribute,
        # so a pending free group stayed active across clear().
        self.is_not_in_free_group = True
        self.free_group = []
@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache.
22
22
  import heapq
23
23
  import time
24
24
  from collections import defaultdict
25
- from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
25
+ from functools import partial
26
+ from typing import TYPE_CHECKING, List, Optional, Tuple
26
27
 
27
28
  import torch
28
29
 
@@ -67,7 +68,7 @@ class TreeNode:
67
68
  return self.last_access_time < other.last_access_time
68
69
 
69
70
 
70
- def _key_match(key0: List, key1: List):
71
+ def _key_match_page_size1(key0: List, key1: List):
71
72
  i = 0
72
73
  for k0, k1 in zip(key0, key1):
73
74
  if k0 != k1:
@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List):
76
77
  return i
77
78
 
78
79
 
80
+ def _key_match_paged(key0: List, key1: List, page_size: int):
81
+ min_len = min(len(key0), len(key1))
82
+
83
+ i = 0
84
+ while i < min_len:
85
+ if key0[i : i + page_size] != key1[i : i + page_size]:
86
+ break
87
+ i += page_size
88
+
89
+ return i
90
+
91
+
79
92
  class RadixCache(BasePrefixCache):
80
93
  def __init__(
81
94
  self,
82
95
  req_to_token_pool: ReqToTokenPool,
83
96
  token_to_kv_pool_allocator: TokenToKVPoolAllocator,
97
+ page_size: int,
84
98
  disable: bool = False,
85
99
  ):
86
100
  self.req_to_token_pool = req_to_token_pool
87
101
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
102
+ self.page_size = page_size
88
103
  self.disable = disable
104
+
105
+ if self.token_to_kv_pool_allocator:
106
+ self.device = self.token_to_kv_pool_allocator.device
107
+ else:
108
+ self.device = torch.device("cpu")
109
+
110
+ if self.page_size == 1:
111
+ self.key_match_fn = _key_match_page_size1
112
+ self.get_child_key_fn = lambda key: key[0]
113
+ else:
114
+ self.key_match_fn = partial(_key_match_paged, page_size=page_size)
115
+ self.get_child_key_fn = lambda key: tuple(key[:page_size])
89
116
  self.reset()
90
117
 
91
118
  ##### Public API #####
@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache):
109
136
  The last node create a new child if the prefix is shorter
110
137
  than the last node's value.
111
138
  """
112
- if self.disable:
113
- return [], self.root_node
139
+ if self.disable or len(key) == 0:
140
+ return (
141
+ torch.empty(
142
+ (0,),
143
+ dtype=torch.int32,
144
+ device=self.device,
145
+ ),
146
+ self.root_node,
147
+ )
148
+
149
+ if self.page_size != 1:
150
+ page_aligned_len = len(key) // self.page_size * self.page_size
151
+ key = key[:page_aligned_len]
114
152
 
115
153
  value, last_node = self._match_prefix_helper(self.root_node, key)
116
154
  if value:
117
155
  value = torch.concat(value)
118
156
  else:
119
- value = torch.tensor([], dtype=torch.int32)
157
+ value = torch.empty((0,), dtype=torch.int32, device=self.device)
120
158
  return value, last_node
121
159
 
122
160
  def insert(self, key: List, value=None):
@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache):
127
165
  value = [x for x in key]
128
166
  return self._insert_helper(self.root_node, key, value)
129
167
 
130
- def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
168
+ def cache_finished_req(self, req: Req):
131
169
  """Cache request when it finishes."""
132
170
  if self.disable:
133
- if token_ids is None:
134
- token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
135
- else:
136
- token_ids_len = len(token_ids)
137
-
138
171
  kv_indices = self.req_to_token_pool.req_to_token[
139
- req.req_pool_idx, :token_ids_len
172
+ req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
140
173
  ]
141
174
  self.token_to_kv_pool_allocator.free(kv_indices)
142
175
  self.req_to_token_pool.free(req.req_pool_idx)
143
176
  return
144
177
 
145
- if token_ids is None:
146
- token_ids = (req.origin_input_ids + req.output_ids)[:-1]
178
+ token_ids = (req.origin_input_ids + req.output_ids)[:-1]
147
179
  kv_indices = self.req_to_token_pool.req_to_token[
148
180
  req.req_pool_idx, : len(token_ids)
149
181
  ]
150
182
 
183
+ if self.page_size != 1:
184
+ page_aligned_len = len(kv_indices) // self.page_size * self.page_size
185
+ page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
186
+ self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
187
+ else:
188
+ page_aligned_len = len(kv_indices)
189
+ page_aligned_kv_indices = kv_indices.clone()
190
+
151
191
  # Radix Cache takes one ref in memory pool
152
- new_prefix_len = self.insert(token_ids, kv_indices.clone())
192
+ new_prefix_len = self.insert(
193
+ token_ids[:page_aligned_len], page_aligned_kv_indices
194
+ )
153
195
  self.token_to_kv_pool_allocator.free(
154
196
  kv_indices[len(req.prefix_indices) : new_prefix_len]
155
197
  )
@@ -158,27 +200,32 @@ class RadixCache(BasePrefixCache):
158
200
  self.req_to_token_pool.free(req.req_pool_idx)
159
201
  self.dec_lock_ref(req.last_node)
160
202
 
161
- def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
203
+ def cache_unfinished_req(self, req: Req):
162
204
  """Cache request when it is unfinished."""
163
205
  if self.disable:
164
206
  return
165
207
 
166
- if token_ids is None:
167
- token_ids = req.fill_ids
168
-
208
+ token_ids = req.fill_ids
169
209
  kv_indices = self.req_to_token_pool.req_to_token[
170
210
  req.req_pool_idx, : len(token_ids)
171
211
  ]
172
212
 
213
+ if self.page_size != 1:
214
+ page_aligned_len = len(kv_indices) // self.page_size * self.page_size
215
+ page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
216
+ else:
217
+ page_aligned_len = len(kv_indices)
218
+ page_aligned_kv_indices = kv_indices.clone()
219
+ page_aligned_token_ids = token_ids[:page_aligned_len]
220
+
173
221
  # Radix Cache takes one ref in memory pool
174
- new_prefix_len = self.insert(token_ids, kv_indices.clone())
222
+ new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
175
223
  self.token_to_kv_pool_allocator.free(
176
224
  kv_indices[len(req.prefix_indices) : new_prefix_len]
177
225
  )
178
226
 
179
227
  # The prefix indices could be updated, reuse it
180
- new_indices, new_last_node = self.match_prefix(token_ids)
181
- assert len(new_indices) == len(token_ids)
228
+ new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
182
229
  self.req_to_token_pool.write(
183
230
  (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
184
231
  new_indices[len(req.prefix_indices) :],
@@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache):
186
233
 
187
234
  self.dec_lock_ref(req.last_node)
188
235
  self.inc_lock_ref(new_last_node)
189
- req.prefix_indices = new_indices
236
+
237
+ # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
238
+ if self.page_size != 1:
239
+ req.prefix_indices = torch.cat(
240
+ [new_indices, kv_indices[len(new_indices) :]]
241
+ )
242
+ else:
243
+ req.prefix_indices = new_indices
190
244
  req.last_node = new_last_node
191
245
 
192
246
  def pretty_print(self):
@@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache):
196
250
  def total_size(self):
197
251
  return self._total_size_helper()
198
252
 
199
- def evict(self, num_tokens: int, evict_callback: Callable):
253
+ def evict(self, num_tokens: int):
200
254
  if self.disable:
201
255
  return
202
256
 
@@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache):
212
266
  if x.lock_ref > 0:
213
267
  continue
214
268
 
215
- evict_callback(x.value)
269
+ self.token_to_kv_pool_allocator.free(x.value)
216
270
  num_evicted += len(x.value)
217
271
  self._delete_leaf(x)
218
272
 
@@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache):
254
308
  # protected size refers to the size of the cache that is locked
255
309
  return self.protected_size_
256
310
 
311
+ def all_values_flatten(self):
312
+ values = []
313
+
314
+ def _dfs_helper(node: TreeNode):
315
+ for _, child in node.children.items():
316
+ values.append(child.value)
317
+ _dfs_helper(child)
318
+
319
+ _dfs_helper(self.root_node)
320
+ return torch.concat(values)
321
+
257
322
  ##### Internal Helper Functions #####
258
323
 
259
324
  def _match_prefix_helper(self, node: TreeNode, key: List):
260
325
  node.last_access_time = time.time()
326
+
327
+ child_key = self.get_child_key_fn(key)
328
+
261
329
  value = []
262
- while len(key) > 0 and key[0] in node.children.keys():
263
- child = node.children[key[0]]
330
+ while len(key) > 0 and child_key in node.children.keys():
331
+ child = node.children[child_key]
264
332
  child.last_access_time = time.time()
265
- prefix_len = _key_match(child.key, key)
333
+ prefix_len = self.key_match_fn(child.key, key)
266
334
  if prefix_len < len(child.key):
267
335
  new_node = self._split_node(child.key, child, prefix_len)
268
336
  value.append(new_node.value)
@@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache):
272
340
  value.append(child.value)
273
341
  node = child
274
342
  key = key[prefix_len:]
343
+
344
+ if len(key):
345
+ child_key = self.get_child_key_fn(key)
346
+
275
347
  return value, node
276
348
 
277
349
  def _split_node(self, key, child: TreeNode, split_len: int):
278
350
  # new_node -> child
279
351
  new_node = TreeNode()
280
- new_node.children = {key[split_len]: child}
352
+ new_node.children = {self.get_child_key_fn(key[split_len:]): child}
281
353
  new_node.parent = child.parent
282
354
  new_node.lock_ref = child.lock_ref
283
355
  new_node.key = child.key[:split_len]
@@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache):
285
357
  child.parent = new_node
286
358
  child.key = child.key[split_len:]
287
359
  child.value = child.value[split_len:]
288
- new_node.parent.children[key[0]] = new_node
360
+ new_node.parent.children[self.get_child_key_fn(key)] = new_node
289
361
  return new_node
290
362
 
291
363
  def _insert_helper(self, node: TreeNode, key: List, value):
@@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache):
293
365
  if len(key) == 0:
294
366
  return 0
295
367
 
368
+ child_key = self.get_child_key_fn(key)
369
+
296
370
  total_prefix_length = 0
297
- while len(key) > 0 and key[0] in node.children.keys():
298
- node = node.children[key[0]]
371
+ while len(key) > 0 and child_key in node.children.keys():
372
+ node = node.children[child_key]
299
373
  node.last_access_time = time.time()
300
- prefix_len = _key_match(node.key, key)
374
+ prefix_len = self.key_match_fn(node.key, key)
301
375
  total_prefix_length += prefix_len
302
376
  key = key[prefix_len:]
303
377
  value = value[prefix_len:]
@@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache):
306
380
  new_node = self._split_node(node.key, node, prefix_len)
307
381
  node = new_node
308
382
 
383
+ if len(key):
384
+ child_key = self.get_child_key_fn(key)
385
+
309
386
  if len(key):
310
387
  new_node = TreeNode()
311
388
  new_node.parent = node
312
389
  new_node.key = key
313
390
  new_node.value = value
314
- node.children[key[0]] = new_node
391
+ node.children[child_key] = new_node
315
392
  self.evictable_size_ += len(value)
316
393
  return total_prefix_length
317
394
 
@@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache):
326
403
  current_node.key[:10],
327
404
  f"r={current_node.lock_ref}",
328
405
  )
329
- for _, child in current_node.children.items():
406
+ for key, child in current_node.children.items():
330
407
  stack.append((child, current_indent + 2))
331
408
 
409
+ assert key == self.get_child_key_fn(
410
+ child.key
411
+ ), f"{key=}, {self.get_child_key_fn(child.key)=}"
412
+
332
413
  def _delete_leaf(self, node):
333
414
  for k, v in node.parent.children.items():
334
415
  if v == node:
@@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache):
363
444
 
364
445
 
365
446
  if __name__ == "__main__":
366
- tree = RadixCache(None, None, False)
447
+ tree = RadixCache(None, None, page_size=1, disable=False)
367
448
 
368
449
  tree.insert("Hello")
369
450
  tree.insert("Hello")