sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py
@@ -0,0 +1,359 @@
+"""ModelRunner runs the forward passes of the models."""
+
+import importlib
+import importlib.resources
+import logging
+import pkgutil
+from functools import lru_cache
+from typing import Optional, Type
+
+import torch
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
+from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
+from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_dummy_weight_loader,
+    monkey_patch_vllm_p2p_access_check,
+)
+
+logger = logging.getLogger("srt.model_runner")
+
+
+class ModelRunner:
+    def __init__(
+        self,
+        model_config,
+        mem_fraction_static: float,
+        gpu_id: int,
+        tp_rank: int,
+        tp_size: int,
+        nccl_port: int,
+        server_args: ServerArgs,
+    ):
+        # Parse args
+        self.model_config = model_config
+        self.mem_fraction_static = mem_fraction_static
+        self.gpu_id = gpu_id
+        self.tp_rank = tp_rank
+        self.tp_size = tp_size
+        self.nccl_port = nccl_port
+        self.server_args = server_args
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        monkey_patch_vllm_dummy_weight_loader()
+
+        # Init torch distributed
+        torch.cuda.set_device(self.gpu_id)
+        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+
+        if not server_args.enable_p2p_check:
+            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+        init_distributed_environment(
+            backend="nccl",
+            world_size=self.tp_size,
+            rank=self.tp_rank,
+            local_rank=self.gpu_id,
+            distributed_init_method=nccl_init_method,
+        )
+        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
+        total_gpu_memory = get_available_gpu_memory(
+            self.gpu_id, distributed=self.tp_size > 1
+        )
+
+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError(
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                )
+
+        # Set some global args
+        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+        global_server_args_dict[
+            "attention_reduce_in_fp32"
+        ] = server_args.attention_reduce_in_fp32
+
+        # Load the model and create memory pool
+        self.load_model()
+        self.init_memory_pool(total_gpu_memory)
+        self.init_cublas()
+        self.init_flash_infer()
+
+        # Capture cuda graphs
+        self.init_cuda_graphs()
+
+    def load_model(self):
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        device_config = DeviceConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
+        vllm_model_config = VllmModelConfig(
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.server_args.trust_remote_code,
+            dtype=self.server_args.dtype,
+            seed=42,
+            skip_tokenizer_init=True,
+        )
+        self.dtype = vllm_model_config.dtype
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            multimodal_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+            cache_config=None,
+        )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight end. "
+            f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+    def profile_max_num_token(self, total_gpu_memory):
+        available_gpu_memory = get_available_gpu_memory(
+            self.gpu_id, distributed=self.tp_size > 1
+        )
+        head_dim = self.model_config.head_dim
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
+        rest_memory = available_gpu_memory - total_gpu_memory * (
+            1 - self.mem_fraction_static
+        )
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
+        return max_num_token
+
+    def init_memory_pool(self, total_gpu_memory):
+        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+
+        if self.max_total_num_tokens <= 0:
+            raise RuntimeError(
+                "Not enough memory. Please try to increase --mem-fraction-static."
+            )
+
+        self.req_to_token_pool = ReqToTokenPool(
+            int(self.max_total_num_tokens / self.model_config.context_len * 256),
+            self.model_config.context_len + 8,
+        )
+        self.token_to_kv_pool = TokenToKVPool(
+            self.max_total_num_tokens,
+            dtype=self.dtype,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
+            head_dim=self.model_config.head_dim,
+            layer_num=self.model_config.num_hidden_layers,
+        )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Memory pool end. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+    def init_cublas(self):
+        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
+        dtype = torch.float16
+        device = "cuda"
+        a = torch.ones((16, 16), dtype=dtype, device=device)
+        b = torch.ones((16, 16), dtype=dtype, device=device)
+        c = a @ b
+        return c
+
+    def init_flash_infer(self):
+        if self.server_args.disable_flashinfer:
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+            return
+
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+            BatchPrefillWithRaggedKVCacheWrapper,
+        )
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_config.num_attention_heads // self.tp_size,
+            self.model_config.get_num_kv_heads(self.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+
+        self.flashinfer_workspace_buffers = torch.empty(
+            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
+        )
+        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0], "NHD"
+        )
+        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[1], "NHD"
+        )
+        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0],
+            "NHD",
+            use_tensor_cores=use_tensor_cores,
+        )
+
+    def init_cuda_graphs(self):
+        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+            self.cuda_graph_runner = None
+            return
+
+        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        self.cuda_graph_runner = CudaGraphRunner(
+            self, max_batch_size_to_capture=max(batch_size_list)
+        )
+        self.cuda_graph_runner.capture(batch_size_list)
+
+    @torch.inference_mode()
+    def forward_decode(self, batch: Batch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+            return self.cuda_graph_runner.replay(batch)
+
+        input_metadata = InputMetadata.create(
+            self,
+            forward_mode=ForwardMode.DECODE,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
+        )
+
+    @torch.inference_mode()
+    def forward_extend(self, batch: Batch):
+        input_metadata = InputMetadata.create(
+            self,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
+        )
+
+    @torch.inference_mode()
+    def forward_extend_multi_modal(self, batch: Batch):
+        input_metadata = InputMetadata.create(
+            self,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
+        )
+        return self.model.forward(
+            batch.input_ids,
+            input_metadata.positions,
+            input_metadata,
+            batch.pixel_values,
+            batch.image_sizes,
+            batch.image_offsets,
+        )
+
+    def forward(self, batch: Batch, forward_mode: ForwardMode):
+        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
+            return self.forward_extend_multi_modal(batch)
+        elif forward_mode == ForwardMode.DECODE:
+            return self.forward_decode(batch)
+        elif forward_mode == ForwardMode.EXTEND:
+            return self.forward_extend(batch)
+        else:
+            raise ValueError(f"Invaid forward mode: {forward_mode}")
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                entry = module.EntryClass
+                if isinstance(
+                    entry, list
+                ):  # To support multiple model classes in one module
+                    for tmp in entry:
+                        model_arch_name_to_cls[tmp.__name__] = tmp
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
+
+            # compat: some models such as chatglm has incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(
+                module.EntryClassRemapping, list
+            ):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
+    return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+    model_arch_name_to_cls = import_model_classes()
+
+    if model_arch not in model_arch_name_to_cls:
+        raise ValueError(
+            f"Unsupported architectures: {model_arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
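The KV-cache sizing in profile_max_num_token above decides how large the new TokenToKVPool can be: one cached token costs head_num * head_dim * num_hidden_layers * 2 * element_size bytes (the factor 2 covers the K and V tensors), and whatever memory remains after reserving (1 - mem_fraction_static) of the card is divided by that cell size. A minimal standalone sketch of the same arithmetic; the concrete model and GPU numbers below are illustrative assumptions, not values taken from this diff:

import torch

def estimate_max_kv_tokens(
    available_gib,          # free GPU memory right now, in GiB
    total_gib,              # total GPU memory, in GiB
    mem_fraction_static,    # same meaning as --mem-fraction-static
    num_kv_heads,
    head_dim,
    num_layers,
    dtype=torch.float16,
):
    # Bytes needed to cache one token: K and V (factor 2) for every layer.
    cell_size = num_kv_heads * head_dim * num_layers * 2 * torch._utils._element_size(dtype)
    # Memory left for the KV pool after reserving (1 - mem_fraction_static) of the card.
    rest_gib = available_gib - total_gib * (1 - mem_fraction_static)
    return int(rest_gib * (1 << 30) // cell_size)

# Illustrative values only (roughly a 7B fp16 model on an 80 GB GPU, tp_size=1).
print(estimate_max_kv_tokens(60.0, 80.0, 0.9, num_kv_heads=32, head_dim=128, num_layers=32))

With these assumed numbers the pool would hold roughly 106k tokens; that value becomes max_total_num_tokens, which init_memory_pool then passes to ReqToTokenPool and TokenToKVPool.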
sglang/srt/managers/{router → controller}/radix_cache.py
@@ -1,8 +1,10 @@
+"""
+The radix tree data structure for managing the KV cache.
+"""
+
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple
 
 import torch
 
@@ -11,34 +13,38 @@ class TreeNode:
     def __init__(self):
         self.children = defaultdict(TreeNode)
         self.parent = None
+        self.key = None
         self.value = None
-        self.ref_counter = 0
+        self.lock_ref = 0
         self.last_access_time = time.time()
 
-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
 
 
-def match(key, seq):
+def _key_match(key0, key1):
     i = 0
-    for k, w in zip(key, seq):
-        if k != w:
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i
 
 
 class RadixCache:
-    def __init__(self, disable=False):
-        self.reset()
+    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
+        self.reset()
 
     ##### Public API #####
 
     def reset(self):
        self.root_node = TreeNode()
+        self.root_node.key = []
         self.root_node.value = []
-        self.root_node.ref_counter = 1
+        self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
     def match_prefix(self, key):
@@ -50,16 +56,52 @@ class RadixCache:
         self._match_prefix_helper(self.root_node, key, value, last_node)
         if value:
             value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int64)
         return value, last_node[0]
 
     def insert(self, key, value=None):
         if self.disable:
-            return len(key)
+            return 0
 
         if value is None:
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
+    def cache_req(
+        self,
+        token_ids,
+        last_uncached_pos,
+        req_pool_idx,
+        del_in_memory_pool=True,
+        old_last_node=None,
+    ):
+        # Insert the request into radix cache
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        new_prefix_len = self.insert(token_ids, indices.clone())
+
+        if self.disable:
+            if del_in_memory_pool:
+                self.token_to_kv_pool.free(indices)
+            else:
+                return torch.tensor([], dtype=torch.int64), self.root_node
+
+        # Radix Cache takes one ref in memory pool
+        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
+
+        if del_in_memory_pool:
+            self.req_to_token_pool.free(req_pool_idx)
+        else:
+            cached_indices, new_last_node = self.match_prefix(token_ids)
+            assert len(cached_indices) == len(token_ids)
+
+            self.req_to_token_pool.req_to_token[
+                req_pool_idx, last_uncached_pos : len(cached_indices)
+            ] = cached_indices[last_uncached_pos:]
+            self.dec_lock_ref(old_last_node)
+            self.inc_lock_ref(new_last_node)
+            return cached_indices, new_last_node
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
@@ -69,7 +111,7 @@ class RadixCache:
 
     def evict(self, num_tokens, evict_callback):
         if self.disable:
-            raise RuntimeError()
+            return
 
         leaves = self._collect_leaves()
         heapq.heapify(leaves)
@@ -80,32 +122,33 @@ class RadixCache:
 
             if x == self.root_node:
                 break
-            if x.ref_counter > 0:
+            if x.lock_ref > 0:
                 continue
 
-            num_evicted += evict_callback(x.value)
+            evict_callback(x.value)
+            num_evicted += len(x.value)
             self._delete_leaf(x)
 
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)
 
-    def inc_ref_counter(self, node):
+    def inc_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 0:
+            if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
                 delta -= len(node.value)
-            node.ref_counter += 1
+            node.lock_ref += 1
             node = node.parent
         return delta
 
-    def dec_ref_counter(self, node):
+    def dec_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 1:
+            if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
                 delta += len(node.value)
-            node.ref_counter -= 1
+            node.lock_ref -= 1
             node = node.parent
         return delta
 
@@ -113,42 +156,48 @@ class RadixCache:
         return self.evictable_size_
 
     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
-
-        for c_key, child in node.children.items():
-            prefix_len = match(c_key, key)
-            if prefix_len != 0:
-                if prefix_len < len(c_key):
-                    new_node = self._split_node(c_key, child, prefix_len)
-                    value.append(new_node.value)
-                    last_node[0] = new_node
-                else:
-                    value.append(child.value)
-                    last_node[0] = child
-                    self._match_prefix_helper(child, key[prefix_len:], value, last_node)
-                break
-
-    def _split_node(self, key, child, split_len):
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+    def _split_node(self, key, child: TreeNode, split_len):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len:]: child}
+        new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
-        new_node.ref_counter = child.ref_counter
+        new_node.lock_ref = child.lock_ref
+        new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
+        child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[:split_len]] = new_node
-        del new_node.parent.children[key]
+        new_node.parent.children[key[:split_len][0]] = new_node
         return new_node
 
     def _insert_helper(self, node, key, value):
         node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0
 
-        for c_key, child in node.children.items():
-            prefix_len = match(c_key, key)
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
 
-            if prefix_len == len(c_key):
+            if prefix_len == len(child.key):
                 if prefix_len == len(key):
                     return prefix_len
                 else:
@@ -156,23 +205,23 @@ class RadixCache:
                     value = value[prefix_len:]
                     return prefix_len + self._insert_helper(child, key, value)
 
-            if prefix_len:
-                new_node = self._split_node(c_key, child, prefix_len)
-                return prefix_len + self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+            new_node = self._split_node(child.key, child, prefix_len)
+            return prefix_len + self._insert_helper(
+                new_node, key[prefix_len:], value[prefix_len:]
+            )
 
         if len(key):
             new_node = TreeNode()
             new_node.parent = node
+            new_node.key = key
             new_node.value = value
-            node.children[key] = new_node
+            node.children[key[0]] = new_node
             self.evictable_size_ += len(value)
         return 0
 
-    def _print_helper(self, node, indent):
-        for key, child in node.children.items():
-            print(" " * indent, len(key), key[:10], f"r={child.ref_counter}")
+    def _print_helper(self, node: TreeNode, indent):
+        for _, child in node.children.items():
+            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)
 
     def _delete_leaf(self, node):
@@ -180,7 +229,7 @@ class RadixCache:
             if v == node:
                 break
         del node.parent.children[k]
-        self.evictable_size_ -= len(k)
+        self.evictable_size_ -= len(node.key)
 
     def _total_size_helper(self, node):
         x = len(node.value)
@@ -203,7 +252,7 @@
 
 
 if __name__ == "__main__":
-    tree = RadixCache(disable=False)
+    tree = RadixCache(None, None, False)
 
     tree.insert("Hello")
     tree.insert("Hello")