sglang 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. sglang/__init__.py +55 -2
  2. sglang/api.py +3 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +1 -0
  8. sglang/lang/chat_template.py +74 -0
  9. sglang/lang/interpreter.py +40 -16
  10. sglang/lang/tracer.py +6 -4
  11. sglang/launch_server.py +2 -1
  12. sglang/srt/constrained/fsm_cache.py +1 -0
  13. sglang/srt/constrained/jump_forward.py +1 -0
  14. sglang/srt/conversation.py +2 -2
  15. sglang/srt/hf_transformers_utils.py +2 -1
  16. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  17. sglang/srt/layers/extend_attention.py +1 -0
  18. sglang/srt/layers/logits_processor.py +114 -54
  19. sglang/srt/layers/radix_attention.py +2 -1
  20. sglang/srt/layers/token_attention.py +1 -0
  21. sglang/srt/managers/detokenizer_manager.py +5 -1
  22. sglang/srt/managers/io_struct.py +12 -0
  23. sglang/srt/managers/router/infer_batch.py +70 -33
  24. sglang/srt/managers/router/manager.py +7 -2
  25. sglang/srt/managers/router/model_rpc.py +116 -73
  26. sglang/srt/managers/router/model_runner.py +111 -167
  27. sglang/srt/managers/router/radix_cache.py +46 -38
  28. sglang/srt/managers/tokenizer_manager.py +56 -11
  29. sglang/srt/memory_pool.py +5 -14
  30. sglang/srt/model_config.py +7 -0
  31. sglang/srt/models/commandr.py +376 -0
  32. sglang/srt/models/dbrx.py +413 -0
  33. sglang/srt/models/dbrx_config.py +281 -0
  34. sglang/srt/models/gemma.py +22 -20
  35. sglang/srt/models/llama2.py +23 -21
  36. sglang/srt/models/llava.py +12 -10
  37. sglang/srt/models/mixtral.py +27 -25
  38. sglang/srt/models/qwen.py +23 -21
  39. sglang/srt/models/qwen2.py +23 -21
  40. sglang/srt/models/stablelm.py +20 -21
  41. sglang/srt/models/yivl.py +6 -5
  42. sglang/srt/openai_api_adapter.py +356 -0
  43. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  44. sglang/srt/sampling_params.py +2 -0
  45. sglang/srt/server.py +68 -447
  46. sglang/srt/server_args.py +76 -49
  47. sglang/srt/utils.py +88 -32
  48. sglang/srt/weight_utils.py +402 -0
  49. sglang/test/test_programs.py +8 -7
  50. sglang/test/test_utils.py +195 -7
  51. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/METADATA +12 -14
  52. sglang-0.1.15.dist-info/RECORD +69 -0
  53. sglang-0.1.14.dist-info/RECORD +0 -64
  54. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
  55. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/WHEEL +0 -0
  56. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py
@@ -1,35 +1,35 @@
  import importlib
- import logging
+ import importlib.resources
  import inspect
+ import logging
+ import pkgutil
  from dataclasses import dataclass
  from functools import lru_cache
- from pathlib import Path
- import importlib.resources
+ from typing import List

  import numpy as np
  import torch
- from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
- from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
- from sglang.srt.utils import is_multimodal_model
- from sglang.utils import get_available_gpu_memory
  from vllm.model_executor.layers.quantization.awq import AWQConfig
  from vllm.model_executor.layers.quantization.gptq import GPTQConfig
  from vllm.model_executor.layers.quantization.marlin import MarlinConfig
- from vllm.model_executor.model_loader import _set_default_torch_dtype
- from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
-
- import importlib
- import pkgutil
+ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+ from vllm.distributed import initialize_model_parallel

- import sglang
+ from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
+ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+ from sglang.srt.utils import is_multimodal_model
+ from sglang.utils import get_available_gpu_memory

- QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
+ QUANTIZATION_CONFIG_MAPPING = {
+ "awq": AWQConfig,
+ "gptq": GPTQConfig,
+ "marlin": MarlinConfig,
+ }

  logger = logging.getLogger("model_runner")

-
  # for server args in model endpoints
- global_server_args_dict: dict = None
+ global_server_args_dict = {}


  @lru_cache()
@@ -37,7 +37,7 @@ def import_model_classes():
  model_arch_name_to_cls = {}
  package_name = "sglang.srt.models"
  package = importlib.import_module(package_name)
- for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+ for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
  if not ispkg:
  module = importlib.import_module(name)
  if hasattr(module, "EntryClass"):
@@ -87,6 +87,7 @@ class InputMetadata:

  other_kv_index: torch.Tensor = None
  return_logprob: bool = False
+ top_logprobs_nums: List[int] = None

  # for flashinfer
  qo_indptr: torch.Tensor = None
@@ -106,18 +107,20 @@ class InputMetadata:
  (self.batch_size + 1,), dtype=torch.int32, device="cuda"
  )
  self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+ self.kv_last_page_len = torch.ones(
+ (self.batch_size,), dtype=torch.int32, device="cuda"
+ )
+ req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+ seq_lens_cpu = self.seq_lens.tolist()
  self.kv_indices = torch.cat(
  [
  self.req_to_token_pool.req_to_token[
- self.req_pool_indices[i].item(), : self.seq_lens[i].item()
+ req_pool_indices_cpu[i], : seq_lens_cpu[i]
  ]
  for i in range(self.batch_size)
  ],
  dim=0,
  ).contiguous()
- self.kv_last_page_len = torch.ones(
- (self.batch_size,), dtype=torch.int32, device="cuda"
- )

  workspace_buffer = torch.empty(
  32 * 1024 * 1024, dtype=torch.int8, device="cuda"
@@ -140,13 +143,9 @@ class InputMetadata:
  self.kv_last_page_len,
  self.model_runner.model_config.num_attention_heads // tp_size,
  self.model_runner.model_config.num_key_value_heads // tp_size,
+ self.model_runner.model_config.head_dim
  ]

- # flashinfer >= 0.0.3
- # FIXME: Drop this when flashinfer updates to 0.0.4
- if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
- args.append(self.model_runner.model_config.head_dim)
-
  self.prefill_wrapper.begin_forward(*args)
  else:
  self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
@@ -183,6 +182,7 @@ class InputMetadata:
  out_cache_loc,
  out_cache_cont_start=None,
  out_cache_cont_end=None,
+ top_logprobs_nums=None,
  return_logprob=False,
  ):
  batch_size = len(req_pool_indices)
@@ -197,15 +197,15 @@ class InputMetadata:
  req_pool_indices[0], seq_lens[0] - 1
  ].item()
  else:
- seq_lens_np = seq_lens.cpu().numpy()
- prefix_lens_np = prefix_lens.cpu().numpy()
- position_ids_offsets_np = position_ids_offsets.cpu().numpy()
+ seq_lens_cpu = seq_lens.cpu().numpy()
+ prefix_lens_cpu = prefix_lens.cpu().numpy()
+ position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
  positions = torch.tensor(
  np.concatenate(
  [
  np.arange(
- prefix_lens_np[i] + position_ids_offsets_np[i],
- seq_lens_np[i] + position_ids_offsets_np[i],
+ prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+ seq_lens_cpu[i] + position_ids_offsets_cpu[i],
  )
  for i in range(batch_size)
  ],
@@ -231,8 +231,9 @@ class InputMetadata:
  out_cache_loc=out_cache_loc,
  out_cache_cont_start=out_cache_cont_start,
  out_cache_cont_end=out_cache_cont_end,
- return_logprob=return_logprob,
  other_kv_index=other_kv_index,
+ return_logprob=return_logprob,
+ top_logprobs_nums=top_logprobs_nums,
  )

  if forward_mode == ForwardMode.EXTEND:
@@ -276,9 +277,6 @@ class ModelRunner:
  init_method=f"tcp://127.0.0.1:{self.nccl_port}",
  )

- # A small all_reduce for warmup.
- if self.tp_size > 1:
- torch.distributed.all_reduce(torch.zeros(1).cuda())
  initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

  total_gpu_memory = get_available_gpu_memory(
@@ -297,31 +295,33 @@ class ModelRunner:
  logger.info(f"Rank {self.tp_rank}: load weight begin.")

  # Load weights
- linear_method = None
- with _set_default_torch_dtype(torch.float16):
+ quant_config = None
+
+ quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+ if quant_cfg is not None:
+ quant_method = quant_cfg.get("quant_method", "").lower()
+ # compat: autogptq >=0.8.0 use checkpoint_format: str
+ # compat: autogptq <=0.7.1 is_marlin_format: bool
+ is_format_marlin = quant_cfg.get(
+ "checkpoint_format"
+ ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+
+ # Use marlin if the GPTQ model is serialized in marlin format.
+ if quant_method == "gptq" and is_format_marlin:
+ quant_method = "marlin"
+
+ quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+
+ if quant_config_class is None:
+ raise ValueError(f"Unsupported quantization method: {quant_method}")
+
+ quant_config = quant_config_class.from_config(quant_cfg)
+ logger.info(f"quant_config: {quant_config}")
+
+ with set_default_torch_dtype(torch.float16):
  with torch.device("cuda"):
- hf_quant_config = getattr(
- self.model_config.hf_config, "quantization_config", None
- )
- if hf_quant_config is not None:
- hf_quant_method = hf_quant_config["quant_method"]
-
- # compat: autogptq uses is_marlin_format within quant config
- if (hf_quant_method == "gptq"
- and "is_marlin_format" in hf_quant_config
- and hf_quant_config["is_marlin_format"]):
- hf_quant_method = "marlin"
- quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-
- if quant_config_class is None:
- raise ValueError(
- f"Unsupported quantization method: {hf_quant_config['quant_method']}"
- )
- quant_config = quant_config_class.from_config(hf_quant_config)
- logger.info(f"quant_config: {quant_config}")
- linear_method = quant_config.get_linear_method()
  model = model_class(
- config=self.model_config.hf_config, linear_method=linear_method
+ config=self.model_config.hf_config, quant_config=quant_config
  )
  model.load_weights(
  self.model_config.path,
@@ -367,148 +367,92 @@ class ModelRunner:
  )

  @torch.inference_mode()
- def forward_prefill(
- self,
- input_ids,
- req_pool_indices,
- seq_lens,
- prefix_lens,
- position_ids_offsets,
- out_cache_loc,
- return_logprob,
- ):
+ def forward_prefill(self, batch: Batch):
  input_metadata = InputMetadata.create(
  self,
  forward_mode=ForwardMode.PREFILL,
  tp_size=self.tp_size,
- req_pool_indices=req_pool_indices,
- seq_lens=seq_lens,
- prefix_lens=prefix_lens,
- position_ids_offsets=position_ids_offsets,
- out_cache_loc=out_cache_loc,
- return_logprob=return_logprob,
+ req_pool_indices=batch.req_pool_indices,
+ seq_lens=batch.seq_lens,
+ prefix_lens=batch.prefix_lens,
+ position_ids_offsets=batch.position_ids_offsets,
+ out_cache_loc=batch.out_cache_loc,
+ top_logprobs_nums=batch.top_logprobs_nums,
+ return_logprob=batch.return_logprob,
+ )
+ return self.model.forward(
+ batch.input_ids, input_metadata.positions, input_metadata
  )
- return self.model.forward(input_ids, input_metadata.positions, input_metadata)

  @torch.inference_mode()
- def forward_extend(
- self,
- input_ids,
- req_pool_indices,
- seq_lens,
- prefix_lens,
- position_ids_offsets,
- out_cache_loc,
- return_logprob,
- ):
+ def forward_extend(self, batch: Batch):
  input_metadata = InputMetadata.create(
  self,
  forward_mode=ForwardMode.EXTEND,
  tp_size=self.tp_size,
- req_pool_indices=req_pool_indices,
- seq_lens=seq_lens,
- prefix_lens=prefix_lens,
- position_ids_offsets=position_ids_offsets,
- out_cache_loc=out_cache_loc,
- return_logprob=return_logprob,
+ req_pool_indices=batch.req_pool_indices,
+ seq_lens=batch.seq_lens,
+ prefix_lens=batch.prefix_lens,
+ position_ids_offsets=batch.position_ids_offsets,
+ out_cache_loc=batch.out_cache_loc,
+ top_logprobs_nums=batch.top_logprobs_nums,
+ return_logprob=batch.return_logprob,
+ )
+ return self.model.forward(
+ batch.input_ids, input_metadata.positions, input_metadata
  )
- return self.model.forward(input_ids, input_metadata.positions, input_metadata)

  @torch.inference_mode()
- def forward_decode(
- self,
- input_ids,
- req_pool_indices,
- seq_lens,
- prefix_lens,
- position_ids_offsets,
- out_cache_loc,
- out_cache_cont_start,
- out_cache_cont_end,
- return_logprob,
- ):
+ def forward_decode(self, batch: Batch):
  input_metadata = InputMetadata.create(
  self,
  forward_mode=ForwardMode.DECODE,
  tp_size=self.tp_size,
- req_pool_indices=req_pool_indices,
- seq_lens=seq_lens,
- prefix_lens=prefix_lens,
- position_ids_offsets=position_ids_offsets,
- out_cache_loc=out_cache_loc,
- out_cache_cont_start=out_cache_cont_start,
- out_cache_cont_end=out_cache_cont_end,
- return_logprob=return_logprob,
+ req_pool_indices=batch.req_pool_indices,
+ seq_lens=batch.seq_lens,
+ prefix_lens=batch.prefix_lens,
+ position_ids_offsets=batch.position_ids_offsets,
+ out_cache_loc=batch.out_cache_loc,
+ out_cache_cont_start=batch.out_cache_cont_start,
+ out_cache_cont_end=batch.out_cache_cont_end,
+ top_logprobs_nums=batch.top_logprobs_nums,
+ return_logprob=batch.return_logprob,
+ )
+ return self.model.forward(
+ batch.input_ids, input_metadata.positions, input_metadata
  )
- return self.model.forward(input_ids, input_metadata.positions, input_metadata)

  @torch.inference_mode()
- def forward_extend_multi_modal(
- self,
- input_ids,
- pixel_values,
- image_sizes,
- image_offsets,
- req_pool_indices,
- seq_lens,
- prefix_lens,
- position_ids_offsets,
- out_cache_loc,
- return_logprob,
- ):
+ def forward_extend_multi_modal(self, batch: Batch):
  input_metadata = InputMetadata.create(
  self,
  forward_mode=ForwardMode.EXTEND,
  tp_size=self.tp_size,
- req_pool_indices=req_pool_indices,
- seq_lens=seq_lens,
- prefix_lens=prefix_lens,
- position_ids_offsets=position_ids_offsets,
- out_cache_loc=out_cache_loc,
- return_logprob=return_logprob,
+ req_pool_indices=batch.req_pool_indices,
+ seq_lens=batch.seq_lens,
+ prefix_lens=batch.prefix_lens,
+ position_ids_offsets=batch.position_ids_offsets,
+ out_cache_loc=batch.out_cache_loc,
+ top_logprobs_nums=batch.top_logprobs_nums,
+ return_logprob=batch.return_logprob,
  )
  return self.model.forward(
- input_ids,
+ batch.input_ids,
  input_metadata.positions,
  input_metadata,
- pixel_values,
- image_sizes,
- image_offsets,
+ batch.pixel_values,
+ batch.image_sizes,
+ batch.image_offsets,
  )

- def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
+ def forward(self, batch: Batch, forward_mode: ForwardMode):
  if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
- kwargs = {
- "input_ids": batch.input_ids,
- "pixel_values": batch.pixel_values,
- "image_sizes": batch.image_sizes,
- "image_offsets": batch.image_offsets,
- "req_pool_indices": batch.req_pool_indices,
- "seq_lens": batch.seq_lens,
- "prefix_lens": batch.prefix_lens,
- "position_ids_offsets": batch.position_ids_offsets,
- "out_cache_loc": batch.out_cache_loc,
- "return_logprob": return_logprob,
- }
- return self.forward_extend_multi_modal(**kwargs)
- else:
- kwargs = {
- "input_ids": batch.input_ids,
- "req_pool_indices": batch.req_pool_indices,
- "seq_lens": batch.seq_lens,
- "prefix_lens": batch.prefix_lens,
- "position_ids_offsets": batch.position_ids_offsets,
- "out_cache_loc": batch.out_cache_loc,
- "return_logprob": return_logprob,
- }
-
- if forward_mode == ForwardMode.DECODE:
- kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
- kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
- return self.forward_decode(**kwargs)
+ return self.forward_extend_multi_modal(batch)
+ elif forward_mode == ForwardMode.DECODE:
+ return self.forward_decode(batch)
  elif forward_mode == ForwardMode.EXTEND:
- return self.forward_extend(**kwargs)
+ return self.forward_extend(batch)
  elif forward_mode == ForwardMode.PREFILL:
- return self.forward_prefill(**kwargs)
+ return self.forward_prefill(batch)
  else:
  raise ValueError(f"Invaid forward mode: {forward_mode}")
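
Note: the model_runner.py hunks above replace the old linear_method plumbing with a vLLM quant_config resolved up front from the model's Hugging Face quantization_config, and switch the forward_* entry points to take a Batch object. The following is a minimal, illustrative sketch of that quantization selection logic; the function name resolve_quant_config and the plain-dict argument are stand-ins for illustration, not sglang's actual API.

    from vllm.model_executor.layers.quantization.awq import AWQConfig
    from vllm.model_executor.layers.quantization.gptq import GPTQConfig
    from vllm.model_executor.layers.quantization.marlin import MarlinConfig

    QUANTIZATION_CONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}

    def resolve_quant_config(quant_cfg):
        # quant_cfg: the `quantization_config` dict from the HF model config, or None.
        if quant_cfg is None:
            return None  # unquantized checkpoint
        quant_method = quant_cfg.get("quant_method", "").lower()
        # autogptq >= 0.8.0 records checkpoint_format: "marlin";
        # autogptq <= 0.7.1 records is_marlin_format: true.
        is_format_marlin = (
            quant_cfg.get("checkpoint_format") == "marlin"
            or quant_cfg.get("is_marlin_format", False)
        )
        # A GPTQ checkpoint serialized in marlin format is loaded with the marlin config.
        if quant_method == "gptq" and is_format_marlin:
            quant_method = "marlin"
        quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
        if quant_config_class is None:
            raise ValueError(f"Unsupported quantization method: {quant_method}")
        return quant_config_class.from_config(quant_cfg)

The resolved quant_config is then passed to the model constructor (config=..., quant_config=quant_config) in place of the old linear_method argument.
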
sglang/srt/managers/router/radix_cache.py
@@ -1,8 +1,6 @@
  import heapq
  import time
  from collections import defaultdict
- from dataclasses import dataclass
- from typing import Tuple

  import torch

@@ -11,32 +9,34 @@ class TreeNode:
  def __init__(self):
  self.children = defaultdict(TreeNode)
  self.parent = None
+ self.key = None
  self.value = None
  self.ref_counter = 0
  self.last_access_time = time.time()

- def __lt__(self, other):
+ def __lt__(self, other: "TreeNode"):
  return self.last_access_time < other.last_access_time


- def match(key, seq):
+ def _key_match(key0, key1):
  i = 0
- for k, w in zip(key, seq):
- if k != w:
+ for k0, k1 in zip(key0, key1):
+ if k0 != k1:
  break
  i += 1
  return i


  class RadixCache:
- def __init__(self, disable=False):
- self.reset()
+ def __init__(self, disable: bool = False):
  self.disable = disable
+ self.reset()

  ##### Public API #####

  def reset(self):
  self.root_node = TreeNode()
+ self.root_node.key = []
  self.root_node.value = []
  self.root_node.ref_counter = 1
  self.evictable_size_ = 0
@@ -69,7 +69,7 @@ class RadixCache:

  def evict(self, num_tokens, evict_callback):
  if self.disable:
- raise RuntimeError()
+ return

  leaves = self._collect_leaves()
  heapq.heapify(leaves)
@@ -113,42 +113,48 @@ class RadixCache:
  return self.evictable_size_

  ##### Internal Helper Functions #####
+
  def _match_prefix_helper(self, node, key, value, last_node):
  node.last_access_time = time.time()
-
- for c_key, child in node.children.items():
- prefix_len = match(c_key, key)
- if prefix_len != 0:
- if prefix_len < len(c_key):
- new_node = self._split_node(c_key, child, prefix_len)
- value.append(new_node.value)
- last_node[0] = new_node
- else:
- value.append(child.value)
- last_node[0] = child
- self._match_prefix_helper(child, key[prefix_len:], value, last_node)
- break
+ if len(key) == 0:
+ return
+
+ if key[0] in node.children.keys():
+ child = node.children[key[0]]
+ prefix_len = _key_match(child.key, key)
+ if prefix_len < len(child.key):
+ new_node = self._split_node(child.key, child, prefix_len)
+ value.append(new_node.value)
+ last_node[0] = new_node
+ else:
+ value.append(child.value)
+ last_node[0] = child
+ self._match_prefix_helper(child, key[prefix_len:], value, last_node)

  def _split_node(self, key, child, split_len):
  # new_node -> child
  new_node = TreeNode()
- new_node.children = {key[split_len:]: child}
+ new_node.children = {key[split_len:][0]: child}
  new_node.parent = child.parent
  new_node.ref_counter = child.ref_counter
+ new_node.key = child.key[:split_len]
  new_node.value = child.value[:split_len]
  child.parent = new_node
+ child.key = child.key[split_len:]
  child.value = child.value[split_len:]
- new_node.parent.children[key[:split_len]] = new_node
- del new_node.parent.children[key]
+ new_node.parent.children[key[:split_len][0]] = new_node
  return new_node

  def _insert_helper(self, node, key, value):
  node.last_access_time = time.time()
+ if len(key) == 0:
+ return 0

- for c_key, child in node.children.items():
- prefix_len = match(c_key, key)
+ if key[0] in node.children.keys():
+ child = node.children[key[0]]
+ prefix_len = _key_match(child.key, key)

- if prefix_len == len(c_key):
+ if prefix_len == len(child.key):
  if prefix_len == len(key):
  return prefix_len
  else:
@@ -156,23 +162,25 @@ class RadixCache:
  value = value[prefix_len:]
  return prefix_len + self._insert_helper(child, key, value)

- if prefix_len:
- new_node = self._split_node(c_key, child, prefix_len)
- return prefix_len + self._insert_helper(
- new_node, key[prefix_len:], value[prefix_len:]
- )
+ new_node = self._split_node(child.key, child, prefix_len)
+ return prefix_len + self._insert_helper(
+ new_node, key[prefix_len:], value[prefix_len:]
+ )

  if len(key):
  new_node = TreeNode()
  new_node.parent = node
+ new_node.key = key
  new_node.value = value
- node.children[key] = new_node
+ node.children[key[0]] = new_node
  self.evictable_size_ += len(value)
  return 0

  def _print_helper(self, node, indent):
- for key, child in node.children.items():
- print(" " * indent, len(key), key[:10], f"r={child.ref_counter}")
+ for _, child in node.children.items():
+ print(
+ " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
+ )
  self._print_helper(child, indent=indent + 2)

  def _delete_leaf(self, node):
@@ -180,7 +188,7 @@ class RadixCache:
  if v == node:
  break
  del node.parent.children[k]
- self.evictable_size_ -= len(k)
+ self.evictable_size_ -= len(node.key)

  def _total_size_helper(self, node):
  x = len(node.value)
@@ -203,7 +211,7 @@ class RadixCache:


  if __name__ == "__main__":
- tree = RadixCache(disable=False)
+ tree = RadixCache()

  tree.insert("Hello")
  tree.insert("Hello")
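
The radix_cache.py changes above store each edge's full token sequence on the node itself (node.key) and index children by the first token of that edge, so lookups no longer scan every child key. Below is a minimal sketch of that lookup pattern, assuming token-list keys; match_prefix_len is a hypothetical helper written for illustration, not part of the RadixCache API.

    import time


    class TreeNode:
        # Minimal node mirroring the new layout: children keyed by the first token of each edge.
        def __init__(self):
            self.children = {}   # first token of edge -> child TreeNode
            self.parent = None
            self.key = None      # full token sequence on the edge into this node
            self.value = None
            self.last_access_time = time.time()


    def _key_match(key0, key1):
        # Length of the common prefix of two token sequences.
        i = 0
        for k0, k1 in zip(key0, key1):
            if k0 != k1:
                break
            i += 1
        return i


    def match_prefix_len(node, key):
        # How many leading tokens of `key` are already cached below `node`.
        if len(key) == 0 or key[0] not in node.children:
            return 0
        child = node.children[key[0]]
        prefix_len = _key_match(child.key, key)
        if prefix_len < len(child.key):
            # The match stops inside this edge; an insert would split the node here.
            return prefix_len
        return prefix_len + match_prefix_len(child, key[prefix_len:])

Splitting (RadixCache._split_node in the diff) preserves this invariant: the new intermediate node takes the first split_len tokens of the edge, and the old child keeps the remainder, re-keyed under its new first token.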