sglang 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +55 -2
- sglang/api.py +3 -5
- sglang/backend/anthropic.py +33 -13
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +1 -0
- sglang/lang/chat_template.py +74 -0
- sglang/lang/interpreter.py +40 -16
- sglang/lang/ir.py +1 -1
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +2 -1
- sglang/srt/constrained/fsm_cache.py +15 -3
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/router/infer_batch.py +70 -33
- sglang/srt/managers/router/manager.py +7 -2
- sglang/srt/managers/router/model_rpc.py +116 -73
- sglang/srt/managers/router/model_runner.py +121 -155
- sglang/srt/managers/router/radix_cache.py +46 -38
- sglang/srt/managers/tokenizer_manager.py +56 -11
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +7 -0
- sglang/srt/models/commandr.py +376 -0
- sglang/srt/models/dbrx.py +413 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +22 -20
- sglang/srt/models/llama2.py +23 -21
- sglang/srt/models/llava.py +12 -10
- sglang/srt/models/mixtral.py +27 -25
- sglang/srt/models/qwen.py +23 -21
- sglang/srt/models/qwen2.py +23 -21
- sglang/srt/models/stablelm.py +292 -0
- sglang/srt/models/yivl.py +6 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +68 -439
- sglang/srt/server_args.py +76 -49
- sglang/srt/utils.py +88 -32
- sglang/srt/weight_utils.py +402 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +196 -8
- {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/METADATA +13 -15
- sglang-0.1.15.dist-info/RECORD +69 -0
- {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/WHEEL +1 -1
- sglang-0.1.13.dist-info/RECORD +0 -63
- {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
- {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py

@@ -1,38 +1,47 @@
 import importlib
+import importlib.resources
+import inspect
 import logging
+import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from
+from typing import List
 
 import numpy as np
 import torch
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.distributed import initialize_model_parallel
+
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
 from sglang.utils import get_available_gpu_memory
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.model_loader import _set_default_torch_dtype
-from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
-
-import sglang
 
-
+QUANTIZATION_CONFIG_MAPPING = {
+    "awq": AWQConfig,
+    "gptq": GPTQConfig,
+    "marlin": MarlinConfig,
+}
 
 logger = logging.getLogger("model_runner")
 
-
 # for server args in model endpoints
-global_server_args_dict
+global_server_args_dict = {}
 
 
 @lru_cache()
 def import_model_classes():
     model_arch_name_to_cls = {}
-
-
-
-
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
     return model_arch_name_to_cls
 
 
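The rewritten `import_model_classes` switches from scanning the models directory on disk to `pkgutil`-based discovery over `package.__path__`, which also works when the package is imported from a zipped wheel. A minimal standalone sketch of the same registry pattern (the helper name `discover_entry_classes` is ours, not sglang's):

```python
import importlib
import pkgutil


def discover_entry_classes(package_name: str) -> dict:
    """Map class name -> class for every module in a package that
    exposes an ``EntryClass`` attribute."""
    registry = {}
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                registry[module.EntryClass.__name__] = module.EntryClass
    return registry


# e.g. discover_entry_classes("sglang.srt.models") rebuilds the mapping above
```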
@@ -78,6 +87,7 @@ class InputMetadata:
 
     other_kv_index: torch.Tensor = None
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
 
     # for flashinfer
     qo_indptr: torch.Tensor = None
@@ -97,18 +107,20 @@ class InputMetadata:
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
         self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+        seq_lens_cpu = self.seq_lens.tolist()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-
+                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
 
         workspace_buffer = torch.empty(
             32 * 1024 * 1024, dtype=torch.int8, device="cuda"
@@ -124,14 +136,17 @@ class InputMetadata:
             self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"
             )
-
+            args = [
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
-
+                self.model_runner.model_config.head_dim
+            ]
+
+            self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"
@@ -167,6 +182,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
+        top_logprobs_nums=None,
         return_logprob=False,
     ):
         batch_size = len(req_pool_indices)
@@ -181,15 +197,15 @@ class InputMetadata:
                 req_pool_indices[0], seq_lens[0] - 1
             ].item()
         else:
-
-
-
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
             positions = torch.tensor(
                 np.concatenate(
                     [
                         np.arange(
-
-
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                         )
                         for i in range(batch_size)
                     ],
@@ -215,8 +231,9 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
         )
 
         if forward_mode == ForwardMode.EXTEND:
@@ -260,9 +277,6 @@ class ModelRunner:
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
 
-        # A small all_reduce for warmup.
-        if self.tp_size > 1:
-            torch.distributed.all_reduce(torch.zeros(1).cuda())
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
 
         total_gpu_memory = get_available_gpu_memory(
@@ -281,25 +295,33 @@ class ModelRunner:
         logger.info(f"Rank {self.tp_rank}: load weight begin.")
 
         # Load weights
-
-
+        quant_config = None
+
+        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = quant_cfg.get(
+                "checkpoint_format"
+            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                quant_method = "marlin"
+
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {quant_method}")
+
+            quant_config = quant_config_class.from_config(quant_cfg)
+            logger.info(f"quant_config: {quant_config}")
+
+        with set_default_torch_dtype(torch.float16):
             with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(
-                        hf_quant_config["quant_method"]
-                    )
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
                 model = model_class(
-                    config=self.model_config.hf_config,
+                    config=self.model_config.hf_config, quant_config=quant_config
                 )
                 model.load_weights(
                     self.model_config.path,
@@ -345,148 +367,92 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_prefill(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.PREFILL,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
-    def forward_extend(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
        )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
-    def forward_decode(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start,
-        out_cache_cont_end,
-        return_logprob,
-    ):
+    def forward_decode(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.DECODE,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            out_cache_cont_start=batch.out_cache_cont_start,
+            out_cache_cont_end=batch.out_cache_cont_end,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
-    def forward_extend_multi_modal(
-        self,
-        input_ids,
-        pixel_values,
-        image_sizes,
-        image_offsets,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend_multi_modal(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
         )
         return self.model.forward(
-            input_ids,
+            batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            pixel_values,
-            image_sizes,
-            image_offsets,
+            batch.pixel_values,
+            batch.image_sizes,
+            batch.image_offsets,
         )
 
-    def forward(self, batch: Batch, forward_mode: ForwardMode
+    def forward(self, batch: Batch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
-
-
-
-                "image_sizes": batch.image_sizes,
-                "image_offsets": batch.image_offsets,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-            return self.forward_extend_multi_modal(**kwargs)
-        else:
-            kwargs = {
-                "input_ids": batch.input_ids,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-
-            if forward_mode == ForwardMode.DECODE:
-                kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
-                kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
-                return self.forward_decode(**kwargs)
+            return self.forward_extend_multi_modal(batch)
+        elif forward_mode == ForwardMode.DECODE:
+            return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
-            return self.forward_extend(
+            return self.forward_extend(batch)
         elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(
+            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
sglang/srt/managers/router/radix_cache.py

@@ -1,8 +1,6 @@
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple
 
 import torch
 
@@ -11,32 +9,34 @@ class TreeNode:
     def __init__(self):
         self.children = defaultdict(TreeNode)
         self.parent = None
+        self.key = None
         self.value = None
         self.ref_counter = 0
         self.last_access_time = time.time()
 
-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
 
 
-def
+def _key_match(key0, key1):
     i = 0
-    for
-        if
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i
 
 
 class RadixCache:
-    def __init__(self, disable=False):
-        self.reset()
+    def __init__(self, disable: bool = False):
         self.disable = disable
+        self.reset()
 
     ##### Public API #####
 
     def reset(self):
         self.root_node = TreeNode()
+        self.root_node.key = []
         self.root_node.value = []
         self.root_node.ref_counter = 1
         self.evictable_size_ = 0
@@ -69,7 +69,7 @@ class RadixCache:
 
     def evict(self, num_tokens, evict_callback):
         if self.disable:
-
+            return
 
         leaves = self._collect_leaves()
         heapq.heapify(leaves)
@@ -113,42 +113,48 @@ class RadixCache:
         return self.evictable_size_
 
     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
 
     def _split_node(self, key, child, split_len):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len:]: child}
+        new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
         new_node.ref_counter = child.ref_counter
+        new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
+        child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[:split_len]] = new_node
-        del new_node.parent.children[key]
+        new_node.parent.children[key[:split_len][0]] = new_node
         return new_node
 
     def _insert_helper(self, node, key, value):
         node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0
 
-
-
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
 
-            if prefix_len == len(
+            if prefix_len == len(child.key):
                 if prefix_len == len(key):
                     return prefix_len
                 else:
@@ -156,23 +162,25 @@ class RadixCache:
                     value = value[prefix_len:]
                     return prefix_len + self._insert_helper(child, key, value)
 
-
-
-
-
-            )
+            new_node = self._split_node(child.key, child, prefix_len)
+            return prefix_len + self._insert_helper(
+                new_node, key[prefix_len:], value[prefix_len:]
+            )
 
         if len(key):
             new_node = TreeNode()
             new_node.parent = node
+            new_node.key = key
             new_node.value = value
-            node.children[key] = new_node
+            node.children[key[0]] = new_node
             self.evictable_size_ += len(value)
         return 0
 
     def _print_helper(self, node, indent):
-        for
-            print(
+        for _, child in node.children.items():
+            print(
+                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
+            )
             self._print_helper(child, indent=indent + 2)
 
     def _delete_leaf(self, node):
@@ -180,7 +188,7 @@ class RadixCache:
             if v == node:
                 break
         del node.parent.children[k]
-        self.evictable_size_ -= len(
+        self.evictable_size_ -= len(node.key)
 
     def _total_size_helper(self, node):
         x = len(node.value)
@@ -203,7 +211,7 @@ class RadixCache:
 
 
 if __name__ == "__main__":
-    tree = RadixCache(
+    tree = RadixCache()
 
     tree.insert("Hello")
     tree.insert("Hello")