sglang-0.1.17-py3-none-any.whl → sglang-0.1.19-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py (deleted)
@@ -1,445 +0,0 @@
- import importlib
- import importlib.resources
- import logging
- import pkgutil
- from dataclasses import dataclass
- from functools import lru_cache
- from typing import List, Optional, Type
-
- import numpy as np
- import torch
- import torch.nn as nn
- from vllm.config import DeviceConfig, LoadConfig
- from vllm.config import ModelConfig as VllmModelConfig
- from vllm.distributed import initialize_model_parallel
- from vllm.model_executor.model_loader import get_model
- from vllm.model_executor.models import ModelRegistry
-
- from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
- from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
- from sglang.srt.server_args import ServerArgs
- from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
-
-
- logger = logging.getLogger("model_runner")
-
- # for server args in model endpoints
- global_server_args_dict = {}
-
-
- @dataclass
- class InputMetadata:
-     model_runner: "ModelRunner"
-     forward_mode: ForwardMode
-     batch_size: int
-     total_num_tokens: int
-     max_seq_len: int
-     req_pool_indices: torch.Tensor
-     start_loc: torch.Tensor
-     seq_lens: torch.Tensor
-     prefix_lens: torch.Tensor
-     positions: torch.Tensor
-     req_to_token_pool: ReqToTokenPool
-     token_to_kv_pool: TokenToKVPool
-
-     # for extend
-     extend_seq_lens: torch.Tensor = None
-     extend_start_loc: torch.Tensor = None
-     max_extend_len: int = 0
-
-     out_cache_loc: torch.Tensor = None
-     out_cache_cont_start: torch.Tensor = None
-     out_cache_cont_end: torch.Tensor = None
-
-     other_kv_index: torch.Tensor = None
-     return_logprob: bool = False
-     top_logprobs_nums: List[int] = None
-
-     # for flashinfer
-     qo_indptr: torch.Tensor = None
-     kv_indptr: torch.Tensor = None
-     kv_indices: torch.Tensor = None
-     kv_last_page_len: torch.Tensor = None
-     prefill_wrapper = None
-     decode_wrapper = None
-
-     def init_flashinfer_args(self, tp_size):
-         from flashinfer import (
-             BatchDecodeWithPagedKVCacheWrapper,
-             BatchPrefillWithPagedKVCacheWrapper,
-         )
-
-         self.kv_indptr = torch.zeros(
-             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-         )
-         self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
-         self.kv_last_page_len = torch.ones(
-             (self.batch_size,), dtype=torch.int32, device="cuda"
-         )
-         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-         seq_lens_cpu = self.seq_lens.cpu().numpy()
-         self.kv_indices = torch.cat(
-             [
-                 self.req_to_token_pool.req_to_token[
-                     req_pool_indices_cpu[i], : seq_lens_cpu[i]
-                 ]
-                 for i in range(self.batch_size)
-             ],
-             dim=0,
-         ).contiguous()
-
-         workspace_buffer = torch.empty(
-             32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-         )
-         if (
-             self.forward_mode == ForwardMode.PREFILL
-             or self.forward_mode == ForwardMode.EXTEND
-         ):
-             self.qo_indptr = torch.zeros(
-                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-             )
-             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-             self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                 workspace_buffer, "NHD"
-             )
-             args = [
-                 self.qo_indptr,
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 self.model_runner.model_config.num_attention_heads // tp_size,
-                 self.model_runner.model_config.num_key_value_heads // tp_size,
-                 self.model_runner.model_config.head_dim,
-             ]
-
-             self.prefill_wrapper.begin_forward(*args)
-         else:
-             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                 workspace_buffer, "NHD"
-             )
-             self.decode_wrapper.begin_forward(
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 self.model_runner.model_config.num_attention_heads // tp_size,
-                 self.model_runner.model_config.num_key_value_heads // tp_size,
-                 self.model_runner.model_config.head_dim,
-                 1,
-                 "NONE",
-                 "float16",
-             )
-
-     def init_extend_args(self):
-         self.extend_seq_lens = self.seq_lens - self.prefix_lens
-         self.extend_start_loc = torch.zeros_like(self.seq_lens)
-         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-         self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-     @classmethod
-     def create(
-         cls,
-         model_runner,
-         tp_size,
-         forward_mode,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         out_cache_cont_start=None,
-         out_cache_cont_end=None,
-         top_logprobs_nums=None,
-         return_logprob=False,
-     ):
-         batch_size = len(req_pool_indices)
-         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-         start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-         total_num_tokens = int(torch.sum(seq_lens))
-         max_seq_len = int(torch.max(seq_lens))
-
-         if forward_mode == ForwardMode.DECODE:
-             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-             other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                 req_pool_indices[0], seq_lens[0] - 1
-             ].item()
-         else:
-             seq_lens_cpu = seq_lens.cpu().numpy()
-             prefix_lens_cpu = prefix_lens.cpu().numpy()
-             position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-             positions = torch.tensor(
-                 np.concatenate(
-                     [
-                         np.arange(
-                             prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                             seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                         )
-                         for i in range(batch_size)
-                     ],
-                     axis=0,
-                 ),
-                 device="cuda",
-             )
-             other_kv_index = None
-
-         ret = cls(
-             model_runner=model_runner,
-             forward_mode=forward_mode,
-             batch_size=batch_size,
-             total_num_tokens=total_num_tokens,
-             max_seq_len=max_seq_len,
-             req_pool_indices=req_pool_indices,
-             start_loc=start_loc,
-             seq_lens=seq_lens,
-             prefix_lens=prefix_lens,
-             positions=positions,
-             req_to_token_pool=model_runner.req_to_token_pool,
-             token_to_kv_pool=model_runner.token_to_kv_pool,
-             out_cache_loc=out_cache_loc,
-             out_cache_cont_start=out_cache_cont_start,
-             out_cache_cont_end=out_cache_cont_end,
-             other_kv_index=other_kv_index,
-             return_logprob=return_logprob,
-             top_logprobs_nums=top_logprobs_nums,
-         )
-
-         if forward_mode == ForwardMode.EXTEND:
-             ret.init_extend_args()
-
-         if global_server_args_dict.get("enable_flashinfer", False):
-             ret.init_flashinfer_args(tp_size)
-
-         return ret
-
-
- class ModelRunner:
-     def __init__(
-         self,
-         model_config,
-         mem_fraction_static,
-         tp_rank,
-         tp_size,
-         nccl_port,
-         server_args: ServerArgs,
-     ):
-         self.model_config = model_config
-         self.mem_fraction_static = mem_fraction_static
-         self.tp_rank = tp_rank
-         self.tp_size = tp_size
-         self.nccl_port = nccl_port
-         self.server_args = server_args
-
-         global global_server_args_dict
-         global_server_args_dict = {
-             "enable_flashinfer": server_args.enable_flashinfer,
-             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-         }
-
-         # Init torch distributed
-         logger.info(f"[rank={self.tp_rank}] Set cuda device.")
-         torch.cuda.set_device(self.tp_rank)
-         logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
-         torch.distributed.init_process_group(
-             backend="nccl",
-             world_size=self.tp_size,
-             rank=self.tp_rank,
-             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
-         )
-         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-         logger.info(f"[rank={self.tp_rank}] Init torch end.")
-
-         total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
-
-         if self.tp_size > 1:
-             total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
-             if total_local_gpu_memory < total_gpu_memory * 0.9:
-                 raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
-
-         self.load_model()
-         self.init_memory_pool(total_gpu_memory)
-
-         self.is_multimodal_model = is_multimodal_model(self.model_config)
-
-     def load_model(self):
-         logger.info(f"[rank={self.tp_rank}] Load weight begin.")
-
-         device_config = DeviceConfig()
-         load_config = LoadConfig(load_format=self.server_args.load_format)
-         vllm_model_config = VllmModelConfig(
-             model=self.server_args.model_path,
-             quantization=self.server_args.quantization,
-             tokenizer=None,
-             tokenizer_mode=None,
-             trust_remote_code=self.server_args.trust_remote_code,
-             dtype=torch.float16,
-             seed=42,
-             skip_tokenizer_init=True,
-         )
-         if self.model_config.model_overide_args is not None:
-             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
-
-         self.model = get_model(
-             model_config=vllm_model_config,
-             device_config=device_config,
-             load_config=load_config,
-             lora_config=None,
-             vision_language_config=None,
-             parallel_config=None,
-             scheduler_config=None,
-         )
-         logger.info(f"[rank={self.tp_rank}] Load weight end. "
-                     f"Type={type(self.model).__name__}. "
-                     f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
-
-     def profile_max_num_token(self, total_gpu_memory):
-         available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
-         head_dim = self.model_config.head_dim
-         head_num = self.model_config.num_key_value_heads // self.tp_size
-         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
-         rest_memory = available_gpu_memory - total_gpu_memory * (
-             1 - self.mem_fraction_static
-         )
-         max_num_token = int(rest_memory * (1 << 30) // cell_size)
-         return max_num_token
-
-     def init_memory_pool(self, total_gpu_memory):
-         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
-
-         if self.max_total_num_tokens <= 0:
-             raise RuntimeError(
-                 "Not enought memory. " "Please try to increase --mem-fraction-static."
-             )
-
-         self.req_to_token_pool = ReqToTokenPool(
-             int(self.max_total_num_tokens / self.model_config.context_len * 256),
-             self.model_config.context_len + 8,
-         )
-         self.token_to_kv_pool = TokenToKVPool(
-             self.max_total_num_tokens,
-             dtype=torch.float16,
-             head_num=self.model_config.num_key_value_heads // self.tp_size,
-             head_dim=self.model_config.head_dim,
-             layer_num=self.model_config.num_hidden_layers,
-         )
-
-     @torch.inference_mode()
-     def forward_prefill(self, batch: Batch):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.PREFILL,
-             tp_size=self.tp_size,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
-             return_logprob=batch.return_logprob,
-         )
-         return self.model.forward(
-             batch.input_ids, input_metadata.positions, input_metadata
-         )
-
-     @torch.inference_mode()
-     def forward_extend(self, batch: Batch):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
-             return_logprob=batch.return_logprob,
-         )
-         return self.model.forward(
-             batch.input_ids, input_metadata.positions, input_metadata
-         )
-
-     @torch.inference_mode()
-     def forward_decode(self, batch: Batch):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.DECODE,
-             tp_size=self.tp_size,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             out_cache_cont_start=batch.out_cache_cont_start,
-             out_cache_cont_end=batch.out_cache_cont_end,
-             top_logprobs_nums=batch.top_logprobs_nums,
-             return_logprob=batch.return_logprob,
-         )
-         return self.model.forward(
-             batch.input_ids, input_metadata.positions, input_metadata
-         )
-
-     @torch.inference_mode()
-     def forward_extend_multi_modal(self, batch: Batch):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
-             return_logprob=batch.return_logprob,
-         )
-         return self.model.forward(
-             batch.input_ids,
-             input_metadata.positions,
-             input_metadata,
-             batch.pixel_values,
-             batch.image_sizes,
-             batch.image_offsets,
-         )
-
-     def forward(self, batch: Batch, forward_mode: ForwardMode):
-         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
-             return self.forward_extend_multi_modal(batch)
-         elif forward_mode == ForwardMode.DECODE:
-             return self.forward_decode(batch)
-         elif forward_mode == ForwardMode.EXTEND:
-             return self.forward_extend(batch)
-         elif forward_mode == ForwardMode.PREFILL:
-             return self.forward_prefill(batch)
-         else:
-             raise ValueError(f"Invaid forward mode: {forward_mode}")
-
-
- @lru_cache()
- def import_model_classes():
-     model_arch_name_to_cls = {}
-     package_name = "sglang.srt.models"
-     package = importlib.import_module(package_name)
-     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-         if not ispkg:
-             module = importlib.import_module(name)
-             if hasattr(module, "EntryClass"):
-                 entry = module.EntryClass
-                 if isinstance(entry, list): # To support multiple model classes in one module
-                     for cls in entry:
-                         model_arch_name_to_cls[cls.__name__] = cls
-                 else:
-                     model_arch_name_to_cls[entry.__name__] = entry
-     return model_arch_name_to_cls
-
-
- def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-     model_arch_name_to_cls = import_model_classes()
-     if model_arch not in model_arch_name_to_cls:
-         raise ValueError(
-             f"Unsupported architectures: {model_arch}. "
-             f"Supported list: {list(model_arch_name_to_cls.keys())}"
-         )
-     return model_arch_name_to_cls[model_arch]
-
-
- # Monkey patch model loader
- setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
sglang/srt/managers/router/radix_cache.py (deleted)
@@ -1,267 +0,0 @@
- import heapq
- import time
- from collections import defaultdict
-
- import torch
-
-
- class TreeNode:
-     def __init__(self):
-         self.children = defaultdict(TreeNode)
-         self.parent = None
-         self.key = None
-         self.value = None
-         self.lock_ref = 0
-         self.last_access_time = time.time()
-
-     def __lt__(self, other: "TreeNode"):
-         return self.last_access_time < other.last_access_time
-
-
- def _key_match(key0, key1):
-     i = 0
-     for k0, k1 in zip(key0, key1):
-         if k0 != k1:
-             break
-         i += 1
-     return i
-
-
- class RadixCache:
-     def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
-         self.req_to_token_pool = req_to_token_pool
-         self.token_to_kv_pool = token_to_kv_pool
-         self.disable = disable
-         self.reset()
-
-     ##### Public API #####
-
-     def reset(self):
-         self.root_node = TreeNode()
-         self.root_node.key = []
-         self.root_node.value = []
-         self.root_node.lock_ref = 1
-         self.evictable_size_ = 0
-
-     def match_prefix(self, key):
-         if self.disable:
-             return [], self.root_node
-
-         value = []
-         last_node = [self.root_node]
-         self._match_prefix_helper(self.root_node, key, value, last_node)
-         if value:
-             value = torch.concat(value)
-         else:
-             value = torch.tensor([], dtype=torch.int64)
-         return value, last_node[0]
-
-     def insert(self, key, value=None):
-         if self.disable:
-             return 0
-
-         if value is None:
-             value = [x for x in key]
-         return self._insert_helper(self.root_node, key, value)
-
-     def cache_req(
-         self,
-         token_ids,
-         last_uncached_pos,
-         req_pool_idx,
-         del_in_memory_pool=True,
-         old_last_node=None,
-     ):
-         # Insert the request into radix cache
-         indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-         new_prefix_len = self.insert(token_ids, indices.clone())
-
-         if self.disable:
-             if del_in_memory_pool:
-                 self.token_to_kv_pool.dec_refs(indices)
-             else:
-                 return torch.tensor([], dtype=torch.int64), self.root_node
-
-         # Radix Cache takes one ref in memory pool
-         self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
-
-         if del_in_memory_pool:
-             self.req_to_token_pool.free(req_pool_idx)
-         else:
-             cached_indices, new_last_node = self.match_prefix(token_ids)
-             assert len(cached_indices) == len(token_ids)
-
-             self.req_to_token_pool.req_to_token[
-                 req_pool_idx, last_uncached_pos : len(cached_indices)
-             ] = cached_indices[last_uncached_pos:]
-             self.dec_lock_ref(old_last_node)
-             self.inc_lock_ref(new_last_node)
-             return cached_indices, new_last_node
-
-     def pretty_print(self):
-         self._print_helper(self.root_node, 0)
-         print(f"#tokens: {self.total_size()}")
-
-     def total_size(self):
-         return self._total_size_helper(self.root_node)
-
-     def evict(self, num_tokens, evict_callback):
-         if self.disable:
-             return
-
-         leaves = self._collect_leaves()
-         heapq.heapify(leaves)
-
-         num_evicted = 0
-         while num_evicted < num_tokens and len(leaves):
-             x = heapq.heappop(leaves)
-
-             if x == self.root_node:
-                 break
-             if x.lock_ref > 0:
-                 continue
-
-             num_evicted += evict_callback(x.value)
-             self._delete_leaf(x)
-
-             if len(x.parent.children) == 0:
-                 heapq.heappush(leaves, x.parent)
-
-     def inc_lock_ref(self, node: TreeNode):
-         delta = 0
-         while node != self.root_node:
-             if node.lock_ref == 0:
-                 self.evictable_size_ -= len(node.value)
-                 delta -= len(node.value)
-             node.lock_ref += 1
-             node = node.parent
-         return delta
-
-     def dec_lock_ref(self, node: TreeNode):
-         delta = 0
-         while node != self.root_node:
-             if node.lock_ref == 1:
-                 self.evictable_size_ += len(node.value)
-                 delta += len(node.value)
-             node.lock_ref -= 1
-             node = node.parent
-         return delta
-
-     def evictable_size(self):
-         return self.evictable_size_
-
-     ##### Internal Helper Functions #####
-
-     def _match_prefix_helper(self, node, key, value, last_node):
-         node.last_access_time = time.time()
-         if len(key) == 0:
-             return
-
-         if key[0] in node.children.keys():
-             child = node.children[key[0]]
-             prefix_len = _key_match(child.key, key)
-             if prefix_len < len(child.key):
-                 new_node = self._split_node(child.key, child, prefix_len)
-                 value.append(new_node.value)
-                 last_node[0] = new_node
-             else:
-                 value.append(child.value)
-                 last_node[0] = child
-                 self._match_prefix_helper(child, key[prefix_len:], value, last_node)
-
-     def _split_node(self, key, child: TreeNode, split_len):
-         # new_node -> child
-         new_node = TreeNode()
-         new_node.children = {key[split_len:][0]: child}
-         new_node.parent = child.parent
-         new_node.lock_ref = child.lock_ref
-         new_node.key = child.key[:split_len]
-         new_node.value = child.value[:split_len]
-         child.parent = new_node
-         child.key = child.key[split_len:]
-         child.value = child.value[split_len:]
-         new_node.parent.children[key[:split_len][0]] = new_node
-         return new_node
-
-     def _insert_helper(self, node, key, value):
-         node.last_access_time = time.time()
-         if len(key) == 0:
-             return 0
-
-         if key[0] in node.children.keys():
-             child = node.children[key[0]]
-             prefix_len = _key_match(child.key, key)
-
-             if prefix_len == len(child.key):
-                 if prefix_len == len(key):
-                     return prefix_len
-                 else:
-                     key = key[prefix_len:]
-                     value = value[prefix_len:]
-                     return prefix_len + self._insert_helper(child, key, value)
-
-             new_node = self._split_node(child.key, child, prefix_len)
-             return prefix_len + self._insert_helper(
-                 new_node, key[prefix_len:], value[prefix_len:]
-             )
-
-         if len(key):
-             new_node = TreeNode()
-             new_node.parent = node
-             new_node.key = key
-             new_node.value = value
-             node.children[key[0]] = new_node
-             self.evictable_size_ += len(value)
-         return 0
-
-     def _print_helper(self, node: TreeNode, indent):
-         for _, child in node.children.items():
-             print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
-             self._print_helper(child, indent=indent + 2)
-
-     def _delete_leaf(self, node):
-         for k, v in node.parent.children.items():
-             if v == node:
-                 break
-         del node.parent.children[k]
-         self.evictable_size_ -= len(node.key)
-
-     def _total_size_helper(self, node):
-         x = len(node.value)
-         for child in node.children.values():
-             x += self._total_size_helper(child)
-         return x
-
-     def _collect_leaves(self):
-         ret_list = []
-
-         def dfs_(cur_node):
-             if len(cur_node.children) == 0:
-                 ret_list.append(cur_node)
-
-             for x in cur_node.children.values():
-                 dfs_(x)
-
-         dfs_(self.root_node)
-         return ret_list
-
-
- if __name__ == "__main__":
-     tree = RadixCache(None, None, False)
-
-     tree.insert("Hello")
-     tree.insert("Hello")
-     tree.insert("Hello_L.A.!")
-     # tree.insert("Hello_world! Happy")
-     # tree.insert("I love you!")
-     tree.pretty_print()
-
-     # print(tree.match_prefix("I love you! aha"))
-
-     # def evict_callback(x):
-     #     print("evict", x)
-     #     return len(x)
-
-     # tree.evict(5, evict_callback)
-     # tree.evict(10, evict_callback)
-     # tree.pretty_print()