sglang 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +55 -2
- sglang/api.py +3 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +1 -0
- sglang/lang/chat_template.py +74 -0
- sglang/lang/interpreter.py +40 -16
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +2 -1
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/router/infer_batch.py +70 -33
- sglang/srt/managers/router/manager.py +7 -2
- sglang/srt/managers/router/model_rpc.py +116 -73
- sglang/srt/managers/router/model_runner.py +111 -167
- sglang/srt/managers/router/radix_cache.py +46 -38
- sglang/srt/managers/tokenizer_manager.py +56 -11
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +7 -0
- sglang/srt/models/commandr.py +376 -0
- sglang/srt/models/dbrx.py +413 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +22 -20
- sglang/srt/models/llama2.py +23 -21
- sglang/srt/models/llava.py +12 -10
- sglang/srt/models/mixtral.py +27 -25
- sglang/srt/models/qwen.py +23 -21
- sglang/srt/models/qwen2.py +23 -21
- sglang/srt/models/stablelm.py +20 -21
- sglang/srt/models/yivl.py +6 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +68 -447
- sglang/srt/server_args.py +76 -49
- sglang/srt/utils.py +88 -32
- sglang/srt/weight_utils.py +402 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/METADATA +12 -14
- sglang-0.1.15.dist-info/RECORD +69 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py

@@ -1,35 +1,35 @@
 import importlib
-import …
+import importlib.resources
 import inspect
+import logging
+import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from …
-import importlib.resources
+from typing import List

 import numpy as np
 import torch
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model
-from sglang.utils import get_available_gpu_memory
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader import …
-from vllm.…
-
-import importlib
-import pkgutil
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.distributed import initialize_model_parallel

-import …
+from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
+from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.utils import is_multimodal_model
+from sglang.utils import get_available_gpu_memory

-
+QUANTIZATION_CONFIG_MAPPING = {
+    "awq": AWQConfig,
+    "gptq": GPTQConfig,
+    "marlin": MarlinConfig,
+}

 logger = logging.getLogger("model_runner")

-
 # for server args in model endpoints
-global_server_args_dict …
+global_server_args_dict = {}


 @lru_cache()
@@ -37,7 +37,7 @@ def import_model_classes():
     model_arch_name_to_cls = {}
     package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
-    for …
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
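The loop above discovers model implementations by walking the sglang.srt.models package and registering any submodule that exposes an EntryClass attribute. Below is a minimal standalone sketch of that discovery pattern; the helper name discover_entry_classes and the use of the class's __name__ as the registry key are illustrative assumptions, not a copy of the package's exact code.

    import importlib
    import pkgutil


    def discover_entry_classes(package_name: str) -> dict:
        """Collect EntryClass objects from every non-package submodule."""
        arch_name_to_cls = {}
        package = importlib.import_module(package_name)
        # pkgutil.iter_modules yields (finder, qualified_name, ispkg) for each submodule.
        for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
            if ispkg:
                continue
            module = importlib.import_module(name)
            entry = getattr(module, "EntryClass", None)
            if entry is not None:
                arch_name_to_cls[entry.__name__] = entry
        return arch_name_to_cls


    # Example (requires sglang to be installed):
    # discover_entry_classes("sglang.srt.models")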
@@ -87,6 +87,7 @@ class InputMetadata:

     other_kv_index: torch.Tensor = None
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None

     # for flashinfer
     qo_indptr: torch.Tensor = None
@@ -106,18 +107,20 @@ class InputMetadata:
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
         self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+        seq_lens_cpu = self.seq_lens.tolist()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    …
+                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )

         workspace_buffer = torch.empty(
             32 * 1024 * 1024, dtype=torch.int8, device="cuda"
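The reworked block above pulls req_pool_indices and seq_lens to the CPU once before the per-request slicing, instead of indexing GPU tensors element by element inside the list comprehension. A toy sketch of the resulting gather, using made-up shapes and CPU tensors purely for illustration:

    import torch

    # req_to_token maps each request slot to the cache locations of its tokens;
    # seq_lens says how many of those locations are currently valid.
    req_to_token = torch.arange(24).reshape(4, 6)   # 4 request slots, up to 6 tokens each
    req_pool_indices = torch.tensor([2, 0])         # two running requests
    seq_lens = torch.tensor([3, 5])

    # Copy the small indexing metadata to Python lists once, as the diff now does.
    req_pool_indices_cpu = req_pool_indices.tolist()
    seq_lens_cpu = seq_lens.tolist()

    kv_indices = torch.cat(
        [req_to_token[req_pool_indices_cpu[i], : seq_lens_cpu[i]] for i in range(2)],
        dim=0,
    )
    kv_indptr = torch.zeros(3, dtype=torch.int64)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)   # ragged offsets: [0, 3, 8]
    print(kv_indices.tolist(), kv_indptr.tolist())  # [12, 13, 14, 0, 1, 2, 3, 4] [0, 3, 8]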
@@ -140,13 +143,9 @@ class InputMetadata:
                 self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
+                self.model_runner.model_config.head_dim
             ]

-            # flashinfer >= 0.0.3
-            # FIXME: Drop this when flashinfer updates to 0.0.4
-            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
-                args.append(self.model_runner.model_config.head_dim)
-
             self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
@@ -183,6 +182,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
+        top_logprobs_nums=None,
         return_logprob=False,
     ):
         batch_size = len(req_pool_indices)
@@ -197,15 +197,15 @@ class InputMetadata:
                 req_pool_indices[0], seq_lens[0] - 1
             ].item()
         else:
-            …
-            …
-            …
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
             positions = torch.tensor(
                 np.concatenate(
                     [
                         np.arange(
-                            …
-                            …
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                         )
                         for i in range(batch_size)
                     ],
@@ -231,8 +231,9 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
         )

         if forward_mode == ForwardMode.EXTEND:
@@ -276,9 +277,6 @@ class ModelRunner:
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )

-        # A small all_reduce for warmup.
-        if self.tp_size > 1:
-            torch.distributed.all_reduce(torch.zeros(1).cuda())
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

         total_gpu_memory = get_available_gpu_memory(
@@ -297,31 +295,33 @@ class ModelRunner:
         logger.info(f"Rank {self.tp_rank}: load weight begin.")

         # Load weights
-        …
-        …
+        quant_config = None
+
+        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = quant_cfg.get(
+                "checkpoint_format"
+            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                quant_method = "marlin"
+
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {quant_method}")
+
+            quant_config = quant_config_class.from_config(quant_cfg)
+            logger.info(f"quant_config: {quant_config}")
+
+        with set_default_torch_dtype(torch.float16):
             with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    hf_quant_method = hf_quant_config["quant_method"]
-
-                    # compat: autogptq uses is_marlin_format within quant config
-                    if (hf_quant_method == "gptq"
-                        and "is_marlin_format" in hf_quant_config
-                        and hf_quant_config["is_marlin_format"]):
-                        hf_quant_method = "marlin"
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
                 model = model_class(
-                    config=self.model_config.hf_config,
+                    config=self.model_config.hf_config, quant_config=quant_config
                 )
                 model.load_weights(
                     self.model_config.path,
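The rewritten loader resolves the quantization method before entering the device context and understands both autogptq layouts (checkpoint_format for autogptq >= 0.8.0, is_marlin_format for <= 0.7.1), falling back to marlin for GPTQ checkpoints serialized in marlin format. A rough standalone sketch of that resolution step; the string placeholders stand in for the vLLM config classes and the example dicts are invented:

    # Placeholder values; the real mapping holds AWQConfig / GPTQConfig / MarlinConfig.
    QUANTIZATION_CONFIG_MAPPING = {"awq": "AWQConfig", "gptq": "GPTQConfig", "marlin": "MarlinConfig"}


    def resolve_quant_method(quant_cfg: dict) -> str:
        quant_method = quant_cfg.get("quant_method", "").lower()
        # autogptq >= 0.8.0 writes checkpoint_format: "marlin";
        # autogptq <= 0.7.1 writes is_marlin_format: true.
        is_format_marlin = (
            quant_cfg.get("checkpoint_format") == "marlin"
            or quant_cfg.get("is_marlin_format", False)
        )
        if quant_method == "gptq" and is_format_marlin:
            quant_method = "marlin"
        if quant_method not in QUANTIZATION_CONFIG_MAPPING:
            raise ValueError(f"Unsupported quantization method: {quant_method}")
        return quant_method


    print(resolve_quant_method({"quant_method": "gptq", "checkpoint_format": "marlin"}))  # marlin
    print(resolve_quant_method({"quant_method": "awq"}))                                  # awq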
@@ -367,148 +367,92 @@ class ModelRunner:
         )

     @torch.inference_mode()
-    def forward_prefill(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.PREFILL,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            …
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

     @torch.inference_mode()
-    def forward_extend(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            …
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
        )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

     @torch.inference_mode()
-    def forward_decode(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start,
-        out_cache_cont_end,
-        return_logprob,
-    ):
+    def forward_decode(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.DECODE,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            …
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            out_cache_cont_start=batch.out_cache_cont_start,
+            out_cache_cont_end=batch.out_cache_cont_end,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

     @torch.inference_mode()
-    def forward_extend_multi_modal(
-        self,
-        input_ids,
-        pixel_values,
-        image_sizes,
-        image_offsets,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend_multi_modal(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            …
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
         )
         return self.model.forward(
-            input_ids,
+            batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            pixel_values,
-            image_sizes,
-            image_offsets,
+            batch.pixel_values,
+            batch.image_sizes,
+            batch.image_offsets,
         )

-    def forward(self, batch: Batch, forward_mode: ForwardMode …
+    def forward(self, batch: Batch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
-            …
-            …
-            …
-                "image_sizes": batch.image_sizes,
-                "image_offsets": batch.image_offsets,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-            return self.forward_extend_multi_modal(**kwargs)
-        else:
-            kwargs = {
-                "input_ids": batch.input_ids,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-
-            if forward_mode == ForwardMode.DECODE:
-                kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
-                kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
-                return self.forward_decode(**kwargs)
+            return self.forward_extend_multi_modal(batch)
+        elif forward_mode == ForwardMode.DECODE:
+            return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
-            return self.forward_extend( …
+            return self.forward_extend(batch)
         elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill( …
+            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
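With all four forward_* entry points now taking the Batch object directly, the dispatcher reduces to a plain mode switch and no longer rebuilds keyword dictionaries per call. The sketch below mirrors that shape; ForwardMode, Batch, and ToyRunner here are simplified stand-ins rather than the real sglang classes:

    from dataclasses import dataclass
    from enum import Enum, auto


    class ForwardMode(Enum):
        PREFILL = auto()
        EXTEND = auto()
        DECODE = auto()


    @dataclass
    class Batch:
        input_ids: list
        # The real Batch also carries req_pool_indices, seq_lens, top_logprobs_nums, etc.


    class ToyRunner:
        def forward_prefill(self, batch: Batch):
            return "prefill", batch.input_ids

        def forward_extend(self, batch: Batch):
            return "extend", batch.input_ids

        def forward_decode(self, batch: Batch):
            return "decode", batch.input_ids

        def forward(self, batch: Batch, forward_mode: ForwardMode):
            # New per-request fields only need to be read inside InputMetadata.create,
            # not threaded through every call site.
            if forward_mode == ForwardMode.DECODE:
                return self.forward_decode(batch)
            elif forward_mode == ForwardMode.EXTEND:
                return self.forward_extend(batch)
            elif forward_mode == ForwardMode.PREFILL:
                return self.forward_prefill(batch)
            raise ValueError(f"Invalid forward mode: {forward_mode}")


    print(ToyRunner().forward(Batch(input_ids=[1, 2, 3]), ForwardMode.DECODE))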
sglang/srt/managers/router/radix_cache.py

@@ -1,8 +1,6 @@
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple

 import torch

@@ -11,32 +9,34 @@ class TreeNode:
     def __init__(self):
         self.children = defaultdict(TreeNode)
         self.parent = None
+        self.key = None
         self.value = None
         self.ref_counter = 0
         self.last_access_time = time.time()

-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time


-def …
+def _key_match(key0, key1):
     i = 0
-    for …
-        if …
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i


 class RadixCache:
-    def __init__(self, disable=False):
-        self.reset()
+    def __init__(self, disable: bool = False):
         self.disable = disable
+        self.reset()

     ##### Public API #####

     def reset(self):
         self.root_node = TreeNode()
+        self.root_node.key = []
         self.root_node.value = []
         self.root_node.ref_counter = 1
         self.evictable_size_ = 0
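TreeNode now records the key slice for its edge, and the new module-level _key_match helper returns the length of the shared prefix of two token sequences; later hunks use it after a constant-time lookup of children by the first token of each edge. A short, self-contained repeat of that helper with arbitrary example tokens:

    def _key_match(key0, key1):
        # Length of the common prefix of two token sequences.
        i = 0
        for k0, k1 in zip(key0, key1):
            if k0 != k1:
                break
            i += 1
        return i


    # Children indexed by the first token of each edge: the lookup needs only key[0],
    # then _key_match decides how much of the edge actually matches.
    children = {7: [7, 8, 9, 10]}        # edge key stored on the child node
    query = [7, 8, 11]
    edge = children[query[0]]
    print(_key_match(edge, query))        # 2 -> the edge would be split after two tokens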
@@ -69,7 +69,7 @@ class RadixCache:

     def evict(self, num_tokens, evict_callback):
         if self.disable:
-            …
+            return

         leaves = self._collect_leaves()
         heapq.heapify(leaves)
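evict now returns immediately when the cache is disabled; otherwise it heapifies the current leaves, and TreeNode.__lt__ (ordering by last_access_time) makes the heap yield the least recently used leaf first. A self-contained illustration of that ordering, with a stand-in Leaf class instead of the real TreeNode:

    import heapq
    import time


    class Leaf:
        def __init__(self, name, last_access_time):
            self.name = name
            self.last_access_time = last_access_time

        def __lt__(self, other):
            # Same comparison TreeNode.__lt__ defines: least recently accessed first.
            return self.last_access_time < other.last_access_time


    now = time.time()
    leaves = [Leaf("a", now - 30), Leaf("b", now - 5), Leaf("c", now - 60)]
    heapq.heapify(leaves)
    while leaves:
        print(heapq.heappop(leaves).name)   # c, a, b -- oldest candidates evicted first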
@@ -113,42 +113,48 @@ class RadixCache:
         return self.evictable_size_

     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)

     def _split_node(self, key, child, split_len):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len:]: child}
+        new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
         new_node.ref_counter = child.ref_counter
+        new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
+        child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[:split_len]] = new_node
-        del new_node.parent.children[key]
+        new_node.parent.children[key[:split_len][0]] = new_node
         return new_node

     def _insert_helper(self, node, key, value):
         node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0

-        …
-        …
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)

-            if prefix_len == len( …
+            if prefix_len == len(child.key):
                 if prefix_len == len(key):
                     return prefix_len
                 else:
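_split_node now keeps key and value in step on both halves of a split edge and re-keys both dictionaries by the first token of each new edge, so the separate del of the parent's old entry is no longer needed; assigning the first-token key simply overwrites it. The sketch below walks one split with plain dicts standing in for TreeNode; the token and value literals are invented for illustration:

    # An edge holding tokens [7, 8, 9, 10] is cut after a matched prefix of length 2.
    child = {"key": [7, 8, 9, 10], "value": ["v7", "v8", "v9", "v10"], "children": {}}
    parent = {"children": {7: child}}
    split_len = 2

    new_node = {
        "key": child["key"][:split_len],                    # [7, 8]
        "value": child["value"][:split_len],
        "children": {child["key"][split_len:][0]: child},   # keyed by first remaining token: 9
    }
    child["key"] = child["key"][split_len:]                  # [9, 10]
    child["value"] = child["value"][split_len:]
    parent["children"][new_node["key"][0]] = new_node        # parent edge re-keyed by token 7

    print(new_node["key"], child["key"], list(new_node["children"]))   # [7, 8] [9, 10] [9]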
@@ -156,23 +162,25 @@ class RadixCache:
                     value = value[prefix_len:]
                     return prefix_len + self._insert_helper(child, key, value)

-            …
-            …
-            …
-            …
-                )
+            new_node = self._split_node(child.key, child, prefix_len)
+            return prefix_len + self._insert_helper(
+                new_node, key[prefix_len:], value[prefix_len:]
+            )

         if len(key):
             new_node = TreeNode()
             new_node.parent = node
+            new_node.key = key
             new_node.value = value
-            node.children[key] = new_node
+            node.children[key[0]] = new_node
             self.evictable_size_ += len(value)
         return 0

     def _print_helper(self, node, indent):
-        for …
-            print( …
+        for _, child in node.children.items():
+            print(
+                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
+            )
             self._print_helper(child, indent=indent + 2)

     def _delete_leaf(self, node):
@@ -180,7 +188,7 @@ class RadixCache:
             if v == node:
                 break
         del node.parent.children[k]
-        self.evictable_size_ -= len( …
+        self.evictable_size_ -= len(node.key)

     def _total_size_helper(self, node):
         x = len(node.value)
@@ -203,7 +211,7 @@ class RadixCache:


 if __name__ == "__main__":
-    tree = RadixCache( …
+    tree = RadixCache()

     tree.insert("Hello")
     tree.insert("Hello")