sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py (new file)
@@ -0,0 +1,359 @@
+"""ModelRunner runs the forward passes of the models."""
+
+import importlib
+import importlib.resources
+import logging
+import pkgutil
+from functools import lru_cache
+from typing import Optional, Type
+
+import torch
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
+from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
+from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_dummy_weight_loader,
+    monkey_patch_vllm_p2p_access_check,
+)
+
+logger = logging.getLogger("srt.model_runner")
+
+
+class ModelRunner:
+    def __init__(
+        self,
+        model_config,
+        mem_fraction_static: float,
+        gpu_id: int,
+        tp_rank: int,
+        tp_size: int,
+        nccl_port: int,
+        server_args: ServerArgs,
+    ):
+        # Parse args
+        self.model_config = model_config
+        self.mem_fraction_static = mem_fraction_static
+        self.gpu_id = gpu_id
+        self.tp_rank = tp_rank
+        self.tp_size = tp_size
+        self.nccl_port = nccl_port
+        self.server_args = server_args
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        monkey_patch_vllm_dummy_weight_loader()
+
+        # Init torch distributed
+        torch.cuda.set_device(self.gpu_id)
+        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+
+        if not server_args.enable_p2p_check:
+            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+        init_distributed_environment(
+            backend="nccl",
+            world_size=self.tp_size,
+            rank=self.tp_rank,
+            local_rank=self.gpu_id,
+            distributed_init_method=nccl_init_method,
+        )
+        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
+        total_gpu_memory = get_available_gpu_memory(
+            self.gpu_id, distributed=self.tp_size > 1
+        )
+
+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError(
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                )
+
+        # Set some global args
+        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+        global_server_args_dict[
+            "attention_reduce_in_fp32"
+        ] = server_args.attention_reduce_in_fp32
+
+        # Load the model and create memory pool
+        self.load_model()
+        self.init_memory_pool(total_gpu_memory)
+        self.init_cublas()
+        self.init_flash_infer()
+
+        # Capture cuda graphs
+        self.init_cuda_graphs()
+
+    def load_model(self):
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        device_config = DeviceConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
+        vllm_model_config = VllmModelConfig(
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.server_args.trust_remote_code,
+            dtype=self.server_args.dtype,
+            seed=42,
+            skip_tokenizer_init=True,
+        )
+        self.dtype = vllm_model_config.dtype
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            multimodal_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+            cache_config=None,
+        )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight end. "
+            f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+    def profile_max_num_token(self, total_gpu_memory):
+        available_gpu_memory = get_available_gpu_memory(
+            self.gpu_id, distributed=self.tp_size > 1
+        )
+        head_dim = self.model_config.head_dim
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
+        rest_memory = available_gpu_memory - total_gpu_memory * (
+            1 - self.mem_fraction_static
+        )
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
+        return max_num_token
+
+    def init_memory_pool(self, total_gpu_memory):
+        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+
+        if self.max_total_num_tokens <= 0:
+            raise RuntimeError(
+                "Not enough memory. Please try to increase --mem-fraction-static."
+            )
+
+        self.req_to_token_pool = ReqToTokenPool(
+            int(self.max_total_num_tokens / self.model_config.context_len * 256),
+            self.model_config.context_len + 8,
+        )
+        self.token_to_kv_pool = TokenToKVPool(
+            self.max_total_num_tokens,
+            dtype=self.dtype,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
+            head_dim=self.model_config.head_dim,
+            layer_num=self.model_config.num_hidden_layers,
+        )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Memory pool end. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+    def init_cublas(self):
+        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
+        dtype = torch.float16
+        device = "cuda"
+        a = torch.ones((16, 16), dtype=dtype, device=device)
+        b = torch.ones((16, 16), dtype=dtype, device=device)
+        c = a @ b
+        return c
+
+    def init_flash_infer(self):
+        if self.server_args.disable_flashinfer:
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+            return
+
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+            BatchPrefillWithRaggedKVCacheWrapper,
+        )
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_config.num_attention_heads // self.tp_size,
+            self.model_config.get_num_kv_heads(self.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+
+        self.flashinfer_workspace_buffers = torch.empty(
+            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
+        )
+        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0], "NHD"
+        )
+        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[1], "NHD"
+        )
+        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0],
+            "NHD",
+            use_tensor_cores=use_tensor_cores,
+        )
+
+    def init_cuda_graphs(self):
+        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+            self.cuda_graph_runner = None
+            return
+
+        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        self.cuda_graph_runner = CudaGraphRunner(
+            self, max_batch_size_to_capture=max(batch_size_list)
+        )
+        self.cuda_graph_runner.capture(batch_size_list)
+
+    @torch.inference_mode()
+    def forward_decode(self, batch: Batch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+            return self.cuda_graph_runner.replay(batch)
+
+        input_metadata = InputMetadata.create(
+            self,
+            forward_mode=ForwardMode.DECODE,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
+        )
+
+    @torch.inference_mode()
+    def forward_extend(self, batch: Batch):
+        input_metadata = InputMetadata.create(
+            self,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
+        )
+
+    @torch.inference_mode()
+    def forward_extend_multi_modal(self, batch: Batch):
+        input_metadata = InputMetadata.create(
+            self,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
+        )
+        return self.model.forward(
+            batch.input_ids,
+            input_metadata.positions,
+            input_metadata,
+            batch.pixel_values,
+            batch.image_sizes,
+            batch.image_offsets,
+        )
+
+    def forward(self, batch: Batch, forward_mode: ForwardMode):
+        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
+            return self.forward_extend_multi_modal(batch)
+        elif forward_mode == ForwardMode.DECODE:
+            return self.forward_decode(batch)
+        elif forward_mode == ForwardMode.EXTEND:
+            return self.forward_extend(batch)
+        else:
+            raise ValueError(f"Invaid forward mode: {forward_mode}")
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                entry = module.EntryClass
+                if isinstance(
+                    entry, list
+                ):  # To support multiple model classes in one module
+                    for tmp in entry:
+                        model_arch_name_to_cls[tmp.__name__] = tmp
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
+
+            # compat: some models such as chatglm has incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(
+                module.EntryClassRemapping, list
+            ):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
+    return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+    model_arch_name_to_cls = import_model_classes()
+
+    if model_arch not in model_arch_name_to_cls:
+        raise ValueError(
+            f"Unsupported architectures: {model_arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
sglang/srt/managers/{router → controller}/radix_cache.py
@@ -1,8 +1,10 @@
+"""
+The radix tree data structure for managing the KV cache.
+"""
+
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple

 import torch

@@ -11,34 +13,38 @@ class TreeNode:
     def __init__(self):
         self.children = defaultdict(TreeNode)
         self.parent = None
+        self.key = None
         self.value = None
-        self.
+        self.lock_ref = 0
         self.last_access_time = time.time()

-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time


-def
+def _key_match(key0, key1):
     i = 0
-    for
-        if
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i


 class RadixCache:
-    def __init__(self, disable=False):
-        self.
+    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
+        self.reset()

     ##### Public API #####

     def reset(self):
         self.root_node = TreeNode()
+        self.root_node.key = []
         self.root_node.value = []
-        self.root_node.
+        self.root_node.lock_ref = 1
         self.evictable_size_ = 0

     def match_prefix(self, key):
@@ -50,16 +56,52 @@ class RadixCache:
         self._match_prefix_helper(self.root_node, key, value, last_node)
         if value:
             value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int64)
         return value, last_node[0]

     def insert(self, key, value=None):
         if self.disable:
-            return
+            return 0

         if value is None:
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

+    def cache_req(
+        self,
+        token_ids,
+        last_uncached_pos,
+        req_pool_idx,
+        del_in_memory_pool=True,
+        old_last_node=None,
+    ):
+        # Insert the request into radix cache
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        new_prefix_len = self.insert(token_ids, indices.clone())
+
+        if self.disable:
+            if del_in_memory_pool:
+                self.token_to_kv_pool.free(indices)
+            else:
+                return torch.tensor([], dtype=torch.int64), self.root_node
+
+        # Radix Cache takes one ref in memory pool
+        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
+
+        if del_in_memory_pool:
+            self.req_to_token_pool.free(req_pool_idx)
+        else:
+            cached_indices, new_last_node = self.match_prefix(token_ids)
+            assert len(cached_indices) == len(token_ids)
+
+            self.req_to_token_pool.req_to_token[
+                req_pool_idx, last_uncached_pos : len(cached_indices)
+            ] = cached_indices[last_uncached_pos:]
+            self.dec_lock_ref(old_last_node)
+            self.inc_lock_ref(new_last_node)
+            return cached_indices, new_last_node
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
@@ -69,7 +111,7 @@ class RadixCache:

     def evict(self, num_tokens, evict_callback):
         if self.disable:
-
+            return

         leaves = self._collect_leaves()
         heapq.heapify(leaves)
@@ -80,32 +122,33 @@ class RadixCache:

             if x == self.root_node:
                 break
-            if x.
+            if x.lock_ref > 0:
                 continue

-
+            evict_callback(x.value)
+            num_evicted += len(x.value)
             self._delete_leaf(x)

             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)

-    def
+    def inc_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.
+            if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
                 delta -= len(node.value)
-            node.
+            node.lock_ref += 1
             node = node.parent
         return delta

-    def
+    def dec_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.
+            if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
                 delta += len(node.value)
-            node.
+            node.lock_ref -= 1
             node = node.parent
         return delta

@@ -113,42 +156,48 @@ class RadixCache:
         return self.evictable_size_

     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+    def _split_node(self, key, child: TreeNode, split_len):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len:]: child}
+        new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
-        new_node.
+        new_node.lock_ref = child.lock_ref
+        new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
+        child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[:split_len]] = new_node
-        del new_node.parent.children[key]
+        new_node.parent.children[key[:split_len][0]] = new_node
         return new_node

     def _insert_helper(self, node, key, value):
         node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0

-
-
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)

-            if prefix_len == len(
+            if prefix_len == len(child.key):
                 if prefix_len == len(key):
                     return prefix_len
                 else:
@@ -156,23 +205,23 @@ class RadixCache:
                     value = value[prefix_len:]
                     return prefix_len + self._insert_helper(child, key, value)

-
-
-
-
-            )
+            new_node = self._split_node(child.key, child, prefix_len)
+            return prefix_len + self._insert_helper(
+                new_node, key[prefix_len:], value[prefix_len:]
+            )

         if len(key):
             new_node = TreeNode()
             new_node.parent = node
+            new_node.key = key
             new_node.value = value
-            node.children[key] = new_node
+            node.children[key[0]] = new_node
             self.evictable_size_ += len(value)
         return 0

-    def _print_helper(self, node, indent):
-        for
-            print(" " * indent, len(key), key[:10], f"r={child.
+    def _print_helper(self, node: TreeNode, indent):
+        for _, child in node.children.items():
+            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)

     def _delete_leaf(self, node):
@@ -180,7 +229,7 @@ class RadixCache:
             if v == node:
                 break
         del node.parent.children[k]
-        self.evictable_size_ -= len(
+        self.evictable_size_ -= len(node.key)

     def _total_size_helper(self, node):
         x = len(node.value)
@@ -203,7 +252,7 @@ class RadixCache:


 if __name__ == "__main__":
-    tree = RadixCache(
+    tree = RadixCache(None, None, False)

     tree.insert("Hello")
     tree.insert("Hello")