sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +57 -2
- sglang/api.py +8 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +83 -2
- sglang/lang/interpreter.py +92 -35
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +6 -4
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +10 -2
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +27 -3
- sglang/srt/managers/router/infer_batch.py +97 -48
- sglang/srt/managers/router/manager.py +11 -8
- sglang/srt/managers/router/model_rpc.py +169 -90
- sglang/srt/managers/router/model_runner.py +110 -166
- sglang/srt/managers/router/radix_cache.py +89 -51
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +110 -33
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +11 -0
- sglang/srt/models/commandr.py +372 -0
- sglang/srt/models/dbrx.py +412 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +24 -25
- sglang/srt/models/llama2.py +25 -26
- sglang/srt/models/llava.py +8 -10
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +29 -33
- sglang/srt/models/qwen.py +34 -25
- sglang/srt/models/qwen2.py +25 -26
- sglang/srt/models/stablelm.py +26 -26
- sglang/srt/models/yivl.py +3 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +91 -456
- sglang/srt/server_args.py +79 -49
- sglang/srt/utils.py +212 -47
- sglang/srt/weight_utils.py +417 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- sglang/utils.py +77 -26
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0

sglang/srt/managers/router/model_runner.py

@@ -1,35 +1,35 @@
 import importlib
-import …
+import importlib.resources
 import inspect
+import logging
+import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from …
-import importlib.resources
+from typing import List
 
 import numpy as np
 import torch
-from …
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model
-from sglang.utils import get_available_gpu_memory
+from vllm.distributed import initialize_model_parallel
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader import …
-from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 
-import …
-import …
+from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
+from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
 
-import sglang
 
-
+QUANTIZATION_CONFIG_MAPPING = {
+    "awq": AWQConfig,
+    "gptq": GPTQConfig,
+    "marlin": MarlinConfig,
+}
 
 logger = logging.getLogger("model_runner")
 
-
 # for server args in model endpoints
-global_server_args_dict …
+global_server_args_dict = {}
 
 
 @lru_cache()
@@ -37,7 +37,7 @@ def import_model_classes():
     model_arch_name_to_cls = {}
     package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
-    for …
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
@@ -87,6 +87,7 @@ class InputMetadata:
 
     other_kv_index: torch.Tensor = None
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
 
     # for flashinfer
     qo_indptr: torch.Tensor = None
@@ -106,18 +107,20 @@ class InputMetadata:
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
         self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    …
+                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
 
         workspace_buffer = torch.empty(
             32 * 1024 * 1024, dtype=torch.int8, device="cuda"
@@ -140,13 +143,9 @@ class InputMetadata:
                 self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
+                self.model_runner.model_config.head_dim,
             ]
 
-            # flashinfer >= 0.0.3
-            # FIXME: Drop this when flashinfer updates to 0.0.4
-            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
-                args.append(self.model_runner.model_config.head_dim)
-
             self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
@@ -183,6 +182,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
+        top_logprobs_nums=None,
         return_logprob=False,
     ):
         batch_size = len(req_pool_indices)
@@ -197,15 +197,15 @@ class InputMetadata:
                 req_pool_indices[0], seq_lens[0] - 1
             ].item()
         else:
-            …
-            …
-            …
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
             positions = torch.tensor(
                 np.concatenate(
                     [
                         np.arange(
-                            …
-                            …
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                         )
                         for i in range(batch_size)
                     ],
@@ -231,8 +231,9 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
         )
 
         if forward_mode == ForwardMode.EXTEND:
@@ -276,9 +277,6 @@ class ModelRunner:
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
 
-        # A small all_reduce for warmup.
-        if self.tp_size > 1:
-            torch.distributed.all_reduce(torch.zeros(1).cuda())
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
 
         total_gpu_memory = get_available_gpu_memory(
@@ -297,31 +295,33 @@ class ModelRunner:
         logger.info(f"Rank {self.tp_rank}: load weight begin.")
 
         # Load weights
-        …
-        …
+        quant_config = None
+
+        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = quant_cfg.get(
+                "checkpoint_format"
+            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                quant_method = "marlin"
+
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {quant_method}")
+
+            quant_config = quant_config_class.from_config(quant_cfg)
+            logger.info(f"quant_config: {quant_config}")
+
+        with set_default_torch_dtype(torch.float16):
             with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    hf_quant_method = hf_quant_config["quant_method"]
-
-                    # compat: autogptq uses is_marlin_format within quant config
-                    if (hf_quant_method == "gptq"
-                            and "is_marlin_format" in hf_quant_config
-                            and hf_quant_config["is_marlin_format"]):
-                        hf_quant_method = "marlin"
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
                 model = model_class(
-                    config=self.model_config.hf_config,
+                    config=self.model_config.hf_config, quant_config=quant_config
                 )
                 model.load_weights(
                     self.model_config.path,
@@ -367,148 +367,92 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_prefill(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.PREFILL,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            …
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
-    def forward_extend(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            …
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
        )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
-    def forward_decode(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start,
-        out_cache_cont_end,
-        return_logprob,
-    ):
+    def forward_decode(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.DECODE,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            …
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            out_cache_cont_start=batch.out_cache_cont_start,
+            out_cache_cont_end=batch.out_cache_cont_end,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
        )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
-    def forward_extend_multi_modal(
-        self,
-        input_ids,
-        pixel_values,
-        image_sizes,
-        image_offsets,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend_multi_modal(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            …
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
         )
         return self.model.forward(
-            input_ids,
+            batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            pixel_values,
-            image_sizes,
-            image_offsets,
+            batch.pixel_values,
+            batch.image_sizes,
+            batch.image_offsets,
         )
 
-    def forward(self, batch: Batch, forward_mode: ForwardMode…
+    def forward(self, batch: Batch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
-            …
-            …
-            …
-                "image_sizes": batch.image_sizes,
-                "image_offsets": batch.image_offsets,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-            return self.forward_extend_multi_modal(**kwargs)
-        else:
-            kwargs = {
-                "input_ids": batch.input_ids,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-
-            if forward_mode == ForwardMode.DECODE:
-                kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
-                kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
-            return self.forward_decode(**kwargs)
+            return self.forward_extend_multi_modal(batch)
+        elif forward_mode == ForwardMode.DECODE:
+            return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
-            return self.forward_extend(…
+            return self.forward_extend(batch)
         elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(…
+            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
sglang/srt/managers/router/radix_cache.py

@@ -1,8 +1,6 @@
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple
 
 import torch
 
@@ -11,34 +9,38 @@ class TreeNode:
     def __init__(self):
         self.children = defaultdict(TreeNode)
         self.parent = None
+        self.key = None
         self.value = None
-        self.…
+        self.lock_ref = 0
         self.last_access_time = time.time()
 
-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
 
 
-def …
+def _key_match(key0, key1):
     i = 0
-    for …
-        if …
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i
 
 
 class RadixCache:
-    def __init__(self, disable=False):
-        self.…
+    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
+        self.reset()
 
     ##### Public API #####
 
     def reset(self):
         self.root_node = TreeNode()
+        self.root_node.key = []
         self.root_node.value = []
-        self.root_node.…
+        self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
     def match_prefix(self, key):
@@ -50,6 +52,8 @@ class RadixCache:
         self._match_prefix_helper(self.root_node, key, value, last_node)
         if value:
             value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int64)
         return value, last_node[0]
 
     def insert(self, key, value=None):
@@ -60,6 +64,34 @@ class RadixCache:
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
+    def cache_req(
+        self,
+        token_ids,
+        last_uncached_pos,
+        req_pool_idx,
+        del_in_memory_pool=True,
+        old_last_node=None,
+    ):
+        # Insert the request into radix cache
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        new_prefix_len = self.insert(token_ids, indices.clone())
+
+        # Radix Cache takes one ref in memory pool
+        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+
+        if del_in_memory_pool:
+            self.req_to_token_pool.free(req_pool_idx)
+        else:
+            cached_indices, new_last_node = self.match_prefix(token_ids)
+            assert len(cached_indices) == len(token_ids)
+
+            self.req_to_token_pool.req_to_token[
+                req_pool_idx, last_uncached_pos : len(cached_indices)
+            ] = cached_indices[last_uncached_pos:]
+            self.dec_lock_ref(old_last_node)
+            self.inc_lock_ref(new_last_node)
+            return cached_indices, new_last_node
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
@@ -69,7 +101,7 @@ class RadixCache:
 
     def evict(self, num_tokens, evict_callback):
         if self.disable:
-            …
+            return
 
         leaves = self._collect_leaves()
         heapq.heapify(leaves)
@@ -80,7 +112,7 @@ class RadixCache:
 
             if x == self.root_node:
                 break
-            if x.…
+            if x.lock_ref > 0:
                 continue
 
             num_evicted += evict_callback(x.value)
@@ -89,23 +121,23 @@ class RadixCache:
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)
 
-    def …
+    def inc_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.…
+            if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
                 delta -= len(node.value)
-            node.…
+            node.lock_ref += 1
             node = node.parent
         return delta
 
-    def …
+    def dec_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.…
+            if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
                 delta += len(node.value)
-            node.…
+            node.lock_ref -= 1
             node = node.parent
         return delta
 
@@ -113,42 +145,48 @@ class RadixCache:
         return self.evictable_size_
 
     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+    def _split_node(self, key, child: TreeNode, split_len):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len:]: child}
+        new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
-        new_node.…
+        new_node.lock_ref = child.lock_ref
+        new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
+        child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[:split_len]] = new_node
-        del new_node.parent.children[key]
+        new_node.parent.children[key[:split_len][0]] = new_node
         return new_node
 
     def _insert_helper(self, node, key, value):
         node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0
 
-        …
-        …
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
 
-            if prefix_len == len(…
+            if prefix_len == len(child.key):
                 if prefix_len == len(key):
                     return prefix_len
                 else:
@@ -156,23 +194,23 @@ class RadixCache:
                     value = value[prefix_len:]
                     return prefix_len + self._insert_helper(child, key, value)
 
-            …
-            …
-            …
-            …
-            )
+            new_node = self._split_node(child.key, child, prefix_len)
+            return prefix_len + self._insert_helper(
+                new_node, key[prefix_len:], value[prefix_len:]
+            )
 
         if len(key):
             new_node = TreeNode()
             new_node.parent = node
+            new_node.key = key
             new_node.value = value
-            node.children[key] = new_node
+            node.children[key[0]] = new_node
             self.evictable_size_ += len(value)
         return 0
 
-    def _print_helper(self, node, indent):
-        for …
-            print(" " * indent, len(key), key[:10], f"r={child.…
+    def _print_helper(self, node: TreeNode, indent):
+        for _, child in node.children.items():
+            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)
 
     def _delete_leaf(self, node):
@@ -180,7 +218,7 @@ class RadixCache:
             if v == node:
                 break
         del node.parent.children[k]
-        self.evictable_size_ -= len(…
+        self.evictable_size_ -= len(node.key)
 
     def _total_size_helper(self, node):
         x = len(node.value)
@@ -203,7 +241,7 @@ class RadixCache:
 
 
 if __name__ == "__main__":
-    tree = RadixCache(…
+    tree = RadixCache(None, None, False)
 
     tree.insert("Hello")
     tree.insert("Hello")