sglang 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/__init__.py +55 -2
  2. sglang/api.py +3 -5
  3. sglang/backend/anthropic.py +33 -13
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +1 -0
  8. sglang/lang/chat_template.py +74 -0
  9. sglang/lang/interpreter.py +40 -16
  10. sglang/lang/ir.py +1 -1
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server.py +2 -1
  13. sglang/srt/constrained/fsm_cache.py +15 -3
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/hf_transformers_utils.py +2 -1
  17. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  18. sglang/srt/layers/extend_attention.py +1 -0
  19. sglang/srt/layers/logits_processor.py +114 -54
  20. sglang/srt/layers/radix_attention.py +2 -1
  21. sglang/srt/layers/token_attention.py +1 -0
  22. sglang/srt/managers/detokenizer_manager.py +5 -1
  23. sglang/srt/managers/io_struct.py +12 -0
  24. sglang/srt/managers/router/infer_batch.py +70 -33
  25. sglang/srt/managers/router/manager.py +7 -2
  26. sglang/srt/managers/router/model_rpc.py +116 -73
  27. sglang/srt/managers/router/model_runner.py +121 -155
  28. sglang/srt/managers/router/radix_cache.py +46 -38
  29. sglang/srt/managers/tokenizer_manager.py +56 -11
  30. sglang/srt/memory_pool.py +5 -14
  31. sglang/srt/model_config.py +7 -0
  32. sglang/srt/models/commandr.py +376 -0
  33. sglang/srt/models/dbrx.py +413 -0
  34. sglang/srt/models/dbrx_config.py +281 -0
  35. sglang/srt/models/gemma.py +22 -20
  36. sglang/srt/models/llama2.py +23 -21
  37. sglang/srt/models/llava.py +12 -10
  38. sglang/srt/models/mixtral.py +27 -25
  39. sglang/srt/models/qwen.py +23 -21
  40. sglang/srt/models/qwen2.py +23 -21
  41. sglang/srt/models/stablelm.py +292 -0
  42. sglang/srt/models/yivl.py +6 -5
  43. sglang/srt/openai_api_adapter.py +356 -0
  44. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  45. sglang/srt/sampling_params.py +2 -0
  46. sglang/srt/server.py +68 -439
  47. sglang/srt/server_args.py +76 -49
  48. sglang/srt/utils.py +88 -32
  49. sglang/srt/weight_utils.py +402 -0
  50. sglang/test/test_programs.py +8 -7
  51. sglang/test/test_utils.py +196 -8
  52. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/METADATA +13 -15
  53. sglang-0.1.15.dist-info/RECORD +69 -0
  54. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/WHEEL +1 -1
  55. sglang-0.1.13.dist-info/RECORD +0 -63
  56. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
  57. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py

@@ -1,38 +1,47 @@
 import importlib
+import importlib.resources
+import inspect
 import logging
+import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from pathlib import Path
+from typing import List

 import numpy as np
 import torch
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.distributed import initialize_model_parallel
+
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
 from sglang.utils import get_available_gpu_memory
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.model_loader import _set_default_torch_dtype
-from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
-
-import sglang

-QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
+QUANTIZATION_CONFIG_MAPPING = {
+    "awq": AWQConfig,
+    "gptq": GPTQConfig,
+    "marlin": MarlinConfig,
+}

 logger = logging.getLogger("model_runner")

-
 # for server args in model endpoints
-global_server_args_dict: dict = None
+global_server_args_dict = {}


 @lru_cache()
 def import_model_classes():
     model_arch_name_to_cls = {}
-    for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"):
-        module = importlib.import_module(f"sglang.srt.models.{module_path.stem}")
-        if hasattr(module, "EntryClass"):
-            model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
     return model_arch_name_to_cls

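Aside (illustrative, not part of the diff): the rewritten loader walks the sglang.srt.models package with pkgutil and registers every module that exposes an EntryClass attribute, keyed by the class name. A minimal sketch of how such a registry is typically consumed, assuming the architecture name comes from the Hugging Face config (the value below is hypothetical):

from sglang.srt.managers.router.model_runner import import_model_classes

model_arch_name_to_cls = import_model_classes()
arch_name = "LlamaForCausalLM"  # hypothetical value of hf_config.architectures[0]
model_class = model_arch_name_to_cls[arch_name]  # resolves to the EntryClass of the matching module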
@@ -78,6 +87,7 @@ class InputMetadata:

     other_kv_index: torch.Tensor = None
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None

     # for flashinfer
     qo_indptr: torch.Tensor = None
@@ -97,18 +107,20 @@ class InputMetadata:
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
         self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+        seq_lens_cpu = self.seq_lens.tolist()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    self.req_pool_indices[i].item(), : self.seq_lens[i].item()
+                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )

         workspace_buffer = torch.empty(
             32 * 1024 * 1024, dtype=torch.int8, device="cuda"
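Aside (illustrative, not part of the diff): the kv_indices change hoists the device-to-host transfer out of the per-request loop. Every .item() call forces its own GPU synchronization, whereas one .tolist() copies the whole tensor once and then indexes plain Python ints. A small self-contained sketch of the two access patterns (CPU tensor here, so it runs anywhere):

import torch

seq_lens = torch.tensor([3, 7, 5, 9])

# Old pattern: one scalar transfer per element.
lens_per_item = [seq_lens[i].item() for i in range(len(seq_lens))]

# New pattern: a single bulk copy, then plain Python ints.
lens_bulk = seq_lens.tolist()

assert lens_per_item == lens_bulk == [3, 7, 5, 9]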
@@ -124,14 +136,17 @@ class InputMetadata:
             self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"
             )
-            self.prefill_wrapper.begin_forward(
+            args = [
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
-            )
+                self.model_runner.model_config.head_dim
+            ]
+
+            self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"
@@ -167,6 +182,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
+        top_logprobs_nums=None,
         return_logprob=False,
     ):
         batch_size = len(req_pool_indices)
@@ -181,15 +197,15 @@ class InputMetadata:
                 req_pool_indices[0], seq_lens[0] - 1
             ].item()
         else:
-            seq_lens_np = seq_lens.cpu().numpy()
-            prefix_lens_np = prefix_lens.cpu().numpy()
-            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
             positions = torch.tensor(
                 np.concatenate(
                     [
                         np.arange(
-                            prefix_lens_np[i] + position_ids_offsets_np[i],
-                            seq_lens_np[i] + position_ids_offsets_np[i],
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                         )
                         for i in range(batch_size)
                     ],
@@ -215,8 +231,9 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
         )

         if forward_mode == ForwardMode.EXTEND:
@@ -260,9 +277,6 @@ class ModelRunner:
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )

-        # A small all_reduce for warmup.
-        if self.tp_size > 1:
-            torch.distributed.all_reduce(torch.zeros(1).cuda())
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

         total_gpu_memory = get_available_gpu_memory(
@@ -281,25 +295,33 @@ class ModelRunner:
         logger.info(f"Rank {self.tp_rank}: load weight begin.")

         # Load weights
-        linear_method = None
-        with _set_default_torch_dtype(torch.float16):
+        quant_config = None
+
+        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = quant_cfg.get(
+                "checkpoint_format"
+            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                quant_method = "marlin"
+
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {quant_method}")
+
+            quant_config = quant_config_class.from_config(quant_cfg)
+            logger.info(f"quant_config: {quant_config}")
+
+        with set_default_torch_dtype(torch.float16):
             with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(
-                        hf_quant_config["quant_method"]
-                    )
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
                 model = model_class(
-                    config=self.model_config.hf_config, linear_method=linear_method
+                    config=self.model_config.hf_config, quant_config=quant_config
                 )
                 model.load_weights(
                     self.model_config.path,
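Aside (illustrative, not part of the diff): the new weight-loading path picks the vllm config class from QUANTIZATION_CONFIG_MAPPING and reroutes GPTQ checkpoints that are already serialized in marlin format. A standalone sketch of that selection logic, run against hypothetical quantization_config payloads (field values are made up for illustration):

def resolve_quant_method(quant_cfg: dict) -> str:
    # Mirrors the diff: prefer marlin when a GPTQ checkpoint is marlin-serialized.
    quant_method = quant_cfg.get("quant_method", "").lower()
    is_format_marlin = (
        quant_cfg.get("checkpoint_format") == "marlin"   # autogptq >= 0.8.0
        or quant_cfg.get("is_marlin_format", False)      # autogptq <= 0.7.1
    )
    if quant_method == "gptq" and is_format_marlin:
        return "marlin"
    return quant_method


assert resolve_quant_method({"quant_method": "gptq", "checkpoint_format": "marlin"}) == "marlin"
assert resolve_quant_method({"quant_method": "gptq", "is_marlin_format": True}) == "marlin"
assert resolve_quant_method({"quant_method": "awq"}) == "awq"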
@@ -345,148 +367,92 @@ class ModelRunner:
         )

     @torch.inference_mode()
-    def forward_prefill(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.PREFILL,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            return_logprob=return_logprob,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

     @torch.inference_mode()
-    def forward_extend(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            return_logprob=return_logprob,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

     @torch.inference_mode()
-    def forward_decode(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start,
-        out_cache_cont_end,
-        return_logprob,
-    ):
+    def forward_decode(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.DECODE,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            return_logprob=return_logprob,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            out_cache_cont_start=batch.out_cache_cont_start,
+            out_cache_cont_end=batch.out_cache_cont_end,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

     @torch.inference_mode()
-    def forward_extend_multi_modal(
-        self,
-        input_ids,
-        pixel_values,
-        image_sizes,
-        image_offsets,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend_multi_modal(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            return_logprob=return_logprob,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
         )
         return self.model.forward(
-            input_ids,
+            batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            pixel_values,
-            image_sizes,
-            image_offsets,
+            batch.pixel_values,
+            batch.image_sizes,
+            batch.image_offsets,
         )

-    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
+    def forward(self, batch: Batch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
-            kwargs = {
-                "input_ids": batch.input_ids,
-                "pixel_values": batch.pixel_values,
-                "image_sizes": batch.image_sizes,
-                "image_offsets": batch.image_offsets,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-            return self.forward_extend_multi_modal(**kwargs)
-        else:
-            kwargs = {
-                "input_ids": batch.input_ids,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-
-        if forward_mode == ForwardMode.DECODE:
-            kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
-            kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
-            return self.forward_decode(**kwargs)
+            return self.forward_extend_multi_modal(batch)
+        elif forward_mode == ForwardMode.DECODE:
+            return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
-            return self.forward_extend(**kwargs)
+            return self.forward_extend(batch)
         elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(**kwargs)
+            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
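Aside (illustrative, not part of the diff): the forward_* methods now take the whole Batch instead of an unpacked kwargs dict, and forward() dispatches on ForwardMode alone. A toy stand-in (simplified types, not the real sglang classes) showing the new call shape:

from dataclasses import dataclass
from enum import Enum, auto


class ForwardMode(Enum):  # simplified stand-in for the real ForwardMode
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()


@dataclass
class Batch:  # only the single field this sketch needs
    input_ids: list


class ToyRunner:
    def forward_prefill(self, batch: Batch):
        return ("prefill", batch.input_ids)

    def forward_extend(self, batch: Batch):
        return ("extend", batch.input_ids)

    def forward_decode(self, batch: Batch):
        return ("decode", batch.input_ids)

    def forward(self, batch: Batch, forward_mode: ForwardMode):
        # The router hands the Batch straight through; each forward_* method
        # reads the fields it needs (including the new top_logprobs_nums).
        if forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        elif forward_mode == ForwardMode.PREFILL:
            return self.forward_prefill(batch)
        raise ValueError(f"Invalid forward mode: {forward_mode}")


assert ToyRunner().forward(Batch([1, 2, 3]), ForwardMode.DECODE) == ("decode", [1, 2, 3])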
sglang/srt/managers/router/radix_cache.py

@@ -1,8 +1,6 @@
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple

 import torch

@@ -11,32 +9,34 @@ class TreeNode:
     def __init__(self):
         self.children = defaultdict(TreeNode)
         self.parent = None
+        self.key = None
         self.value = None
         self.ref_counter = 0
         self.last_access_time = time.time()

-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time


-def match(key, seq):
+def _key_match(key0, key1):
     i = 0
-    for k, w in zip(key, seq):
-        if k != w:
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i


 class RadixCache:
-    def __init__(self, disable=False):
-        self.reset()
+    def __init__(self, disable: bool = False):
         self.disable = disable
+        self.reset()

     ##### Public API #####

     def reset(self):
         self.root_node = TreeNode()
+        self.root_node.key = []
         self.root_node.value = []
         self.root_node.ref_counter = 1
         self.evictable_size_ = 0
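Aside (illustrative, not part of the diff): children are now keyed by the first token of each edge and the full edge is stored in node.key, so descending the tree becomes a dict lookup instead of scanning every outgoing edge with match(). A small sketch of the helper and the lookup (the _key_match body mirrors the diff; the children dict is a made-up stand-in for node.children):

def _key_match(key0, key1):
    # Length of the common prefix of two token sequences (as in the diff).
    i = 0
    for k0, k1 in zip(key0, key1):
        if k0 != k1:
            break
        i += 1
    return i


# The edge "Hel" and the query "Help!" share a 3-character prefix.
assert _key_match(list("Hel"), list("Help!")) == 3

children = {"H": "node_whose_key_is_Hel"}   # hypothetical node.children
key = list("Help!")
child = children.get(key[0])                # O(1) lookup instead of O(#children)
assert child == "node_whose_key_is_Hel"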
@@ -69,7 +69,7 @@ class RadixCache:

     def evict(self, num_tokens, evict_callback):
         if self.disable:
-            raise RuntimeError()
+            return

         leaves = self._collect_leaves()
         heapq.heapify(leaves)
@@ -113,42 +113,48 @@ class RadixCache:
         return self.evictable_size_

     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
-
-        for c_key, child in node.children.items():
-            prefix_len = match(c_key, key)
-            if prefix_len != 0:
-                if prefix_len < len(c_key):
-                    new_node = self._split_node(c_key, child, prefix_len)
-                    value.append(new_node.value)
-                    last_node[0] = new_node
-                else:
-                    value.append(child.value)
-                    last_node[0] = child
-                    self._match_prefix_helper(child, key[prefix_len:], value, last_node)
-                break
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)

     def _split_node(self, key, child, split_len):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len:]: child}
+        new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
         new_node.ref_counter = child.ref_counter
+        new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
+        child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[:split_len]] = new_node
-        del new_node.parent.children[key]
+        new_node.parent.children[key[:split_len][0]] = new_node
         return new_node

     def _insert_helper(self, node, key, value):
         node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0

-        for c_key, child in node.children.items():
-            prefix_len = match(c_key, key)
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)

-            if prefix_len == len(c_key):
+            if prefix_len == len(child.key):
                 if prefix_len == len(key):
                     return prefix_len
                 else:
@@ -156,23 +162,25 @@ class RadixCache:
                     value = value[prefix_len:]
                     return prefix_len + self._insert_helper(child, key, value)

-            if prefix_len:
-                new_node = self._split_node(c_key, child, prefix_len)
-                return prefix_len + self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+            new_node = self._split_node(child.key, child, prefix_len)
+            return prefix_len + self._insert_helper(
+                new_node, key[prefix_len:], value[prefix_len:]
+            )

         if len(key):
             new_node = TreeNode()
             new_node.parent = node
+            new_node.key = key
             new_node.value = value
-            node.children[key] = new_node
+            node.children[key[0]] = new_node
             self.evictable_size_ += len(value)
         return 0

     def _print_helper(self, node, indent):
-        for key, child in node.children.items():
-            print(" " * indent, len(key), key[:10], f"r={child.ref_counter}")
+        for _, child in node.children.items():
+            print(
+                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
+            )
             self._print_helper(child, indent=indent + 2)

     def _delete_leaf(self, node):
@@ -180,7 +188,7 @@ class RadixCache:
             if v == node:
                 break
         del node.parent.children[k]
-        self.evictable_size_ -= len(k)
+        self.evictable_size_ -= len(node.key)

     def _total_size_helper(self, node):
         x = len(node.value)
@@ -203,7 +211,7 @@


 if __name__ == "__main__":
-    tree = RadixCache(disable=False)
+    tree = RadixCache()

     tree.insert("Hello")
     tree.insert("Hello")