sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py

@@ -1,35 +1,35 @@
 import importlib
-import logging
+import importlib.resources
 import inspect
+import logging
+import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from pathlib import Path
-import importlib.resources
+from typing import List
 
 import numpy as np
 import torch
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model
-from sglang.utils import get_available_gpu_memory
+from vllm.distributed import initialize_model_parallel
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader import _set_default_torch_dtype
-from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 
-import importlib
-import pkgutil
+from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
+from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
 
-import sglang
 
-QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
+QUANTIZATION_CONFIG_MAPPING = {
+    "awq": AWQConfig,
+    "gptq": GPTQConfig,
+    "marlin": MarlinConfig,
+}
 
 logger = logging.getLogger("model_runner")
 
-
 # for server args in model endpoints
-global_server_args_dict: dict = None
+global_server_args_dict = {}
 
 
 @lru_cache()
@@ -37,7 +37,7 @@ def import_model_classes():
     model_arch_name_to_cls = {}
     package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
-    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
@@ -87,6 +87,7 @@ class InputMetadata:
 
     other_kv_index: torch.Tensor = None
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
 
     # for flashinfer
     qo_indptr: torch.Tensor = None
@@ -106,18 +107,20 @@ class InputMetadata:
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
         self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    self.req_pool_indices[i].item(), : self.seq_lens[i].item()
+                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
 
         workspace_buffer = torch.empty(
             32 * 1024 * 1024, dtype=torch.int8, device="cuda"
@@ -140,13 +143,9 @@ class InputMetadata:
                 self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
+                self.model_runner.model_config.head_dim,
             ]
 
-            # flashinfer >= 0.0.3
-            # FIXME: Drop this when flashinfer updates to 0.0.4
-            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
-                args.append(self.model_runner.model_config.head_dim)
-
             self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
@@ -183,6 +182,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
+        top_logprobs_nums=None,
         return_logprob=False,
     ):
         batch_size = len(req_pool_indices)
@@ -197,15 +197,15 @@ class InputMetadata:
                 req_pool_indices[0], seq_lens[0] - 1
             ].item()
         else:
-            seq_lens_np = seq_lens.cpu().numpy()
-            prefix_lens_np = prefix_lens.cpu().numpy()
-            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
             positions = torch.tensor(
                 np.concatenate(
                     [
                         np.arange(
-                            prefix_lens_np[i] + position_ids_offsets_np[i],
-                            seq_lens_np[i] + position_ids_offsets_np[i],
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                         )
                         for i in range(batch_size)
                     ],
@@ -231,8 +231,9 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
         )
 
         if forward_mode == ForwardMode.EXTEND:
@@ -276,9 +277,6 @@ class ModelRunner:
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
 
-        # A small all_reduce for warmup.
-        if self.tp_size > 1:
-            torch.distributed.all_reduce(torch.zeros(1).cuda())
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
 
         total_gpu_memory = get_available_gpu_memory(
@@ -297,31 +295,33 @@ class ModelRunner:
         logger.info(f"Rank {self.tp_rank}: load weight begin.")
 
         # Load weights
-        linear_method = None
-        with _set_default_torch_dtype(torch.float16):
+        quant_config = None
+
+        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = quant_cfg.get(
+                "checkpoint_format"
+            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                quant_method = "marlin"
+
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {quant_method}")
+
+            quant_config = quant_config_class.from_config(quant_cfg)
+            logger.info(f"quant_config: {quant_config}")
+
+        with set_default_torch_dtype(torch.float16):
             with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    hf_quant_method = hf_quant_config["quant_method"]
-
-                    # compat: autogptq uses is_marlin_format within quant config
-                    if (hf_quant_method == "gptq"
-                        and "is_marlin_format" in hf_quant_config
-                        and hf_quant_config["is_marlin_format"]):
-                        hf_quant_method = "marlin"
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
                 model = model_class(
-                    config=self.model_config.hf_config, linear_method=linear_method
+                    config=self.model_config.hf_config, quant_config=quant_config
                 )
                 model.load_weights(
                     self.model_config.path,
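
Note: weight loading now resolves a quantization config up front instead of deriving a linear_method inside the load context: the HF quantization_config is read, quant_method is lower-cased, GPTQ checkpoints serialized in marlin format (detected via either autogptq flag) are remapped to "marlin", and the matching vLLM config class is instantiated and passed to the model as quant_config. A standalone sketch of just the resolution step (resolve_quant_method is an illustrative helper, not part of the package):

def resolve_quant_method(quant_cfg):
    quant_method = quant_cfg.get("quant_method", "").lower()
    # autogptq >= 0.8.0 writes checkpoint_format: "marlin";
    # autogptq <= 0.7.1 writes is_marlin_format: true.
    is_format_marlin = (
        quant_cfg.get("checkpoint_format") == "marlin"
        or quant_cfg.get("is_marlin_format", False)
    )
    if quant_method == "gptq" and is_format_marlin:
        return "marlin"
    return quant_method


assert resolve_quant_method({"quant_method": "AWQ"}) == "awq"
assert resolve_quant_method({"quant_method": "gptq", "is_marlin_format": True}) == "marlin"
assert resolve_quant_method({"quant_method": "gptq", "checkpoint_format": "marlin"}) == "marlin"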
@@ -367,148 +367,92 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_prefill(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.PREFILL,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            return_logprob=return_logprob,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
-    def forward_extend(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            return_logprob=return_logprob,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
        )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
-    def forward_decode(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start,
-        out_cache_cont_end,
-        return_logprob,
-    ):
+    def forward_decode(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.DECODE,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            return_logprob=return_logprob,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            out_cache_cont_start=batch.out_cache_cont_start,
+            out_cache_cont_end=batch.out_cache_cont_end,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
-    def forward_extend_multi_modal(
-        self,
-        input_ids,
-        pixel_values,
-        image_sizes,
-        image_offsets,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
+    def forward_extend_multi_modal(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
             tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            return_logprob=return_logprob,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
         )
         return self.model.forward(
-            input_ids,
+            batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            pixel_values,
-            image_sizes,
-            image_offsets,
+            batch.pixel_values,
+            batch.image_sizes,
+            batch.image_offsets,
         )
 
-    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
+    def forward(self, batch: Batch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
-            kwargs = {
-                "input_ids": batch.input_ids,
-                "pixel_values": batch.pixel_values,
-                "image_sizes": batch.image_sizes,
-                "image_offsets": batch.image_offsets,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-            return self.forward_extend_multi_modal(**kwargs)
-        else:
-            kwargs = {
-                "input_ids": batch.input_ids,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-
-            if forward_mode == ForwardMode.DECODE:
-                kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
-                kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
-                return self.forward_decode(**kwargs)
+            return self.forward_extend_multi_modal(batch)
+        elif forward_mode == ForwardMode.DECODE:
+            return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
-            return self.forward_extend(**kwargs)
+            return self.forward_extend(batch)
         elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(**kwargs)
+            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
sglang/srt/managers/router/radix_cache.py

@@ -1,8 +1,6 @@
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple
 
 import torch
 
@@ -11,34 +9,38 @@ class TreeNode:
     def __init__(self):
         self.children = defaultdict(TreeNode)
         self.parent = None
+        self.key = None
         self.value = None
-        self.ref_counter = 0
+        self.lock_ref = 0
         self.last_access_time = time.time()
 
-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
 
 
-def match(key, seq):
+def _key_match(key0, key1):
     i = 0
-    for k, w in zip(key, seq):
-        if k != w:
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i
 
 
 class RadixCache:
-    def __init__(self, disable=False):
-        self.reset()
+    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
+        self.reset()
 
     ##### Public API #####
 
     def reset(self):
         self.root_node = TreeNode()
+        self.root_node.key = []
         self.root_node.value = []
-        self.root_node.ref_counter = 1
+        self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
     def match_prefix(self, key):
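
Note: the tree now stores an explicit key on every node and keys children by the first token of each edge, so a lookup is children[key[0]] followed by _key_match against the child's stored key; _key_match simply returns the length of the common prefix of two token sequences. A standalone copy with a small check:

def _key_match(key0, key1):
    i = 0
    for k0, k1 in zip(key0, key1):
        if k0 != k1:
            break
        i += 1
    return i


assert _key_match([1, 2, 3, 4], [1, 2, 7]) == 2  # shared prefix of length 2
assert _key_match("Hello", "Help") == 3          # works on any sequence type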
@@ -50,6 +52,8 @@ class RadixCache:
         self._match_prefix_helper(self.root_node, key, value, last_node)
         if value:
             value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int64)
         return value, last_node[0]
 
     def insert(self, key, value=None):
@@ -60,6 +64,34 @@ class RadixCache:
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
+    def cache_req(
+        self,
+        token_ids,
+        last_uncached_pos,
+        req_pool_idx,
+        del_in_memory_pool=True,
+        old_last_node=None,
+    ):
+        # Insert the request into radix cache
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        new_prefix_len = self.insert(token_ids, indices.clone())
+
+        # Radix Cache takes one ref in memory pool
+        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+
+        if del_in_memory_pool:
+            self.req_to_token_pool.free(req_pool_idx)
+        else:
+            cached_indices, new_last_node = self.match_prefix(token_ids)
+            assert len(cached_indices) == len(token_ids)
+
+            self.req_to_token_pool.req_to_token[
+                req_pool_idx, last_uncached_pos : len(cached_indices)
+            ] = cached_indices[last_uncached_pos:]
+            self.dec_lock_ref(old_last_node)
+            self.inc_lock_ref(new_last_node)
+            return cached_indices, new_last_node
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
@@ -69,7 +101,7 @@ class RadixCache:
 
     def evict(self, num_tokens, evict_callback):
         if self.disable:
-            raise RuntimeError()
+            return
 
         leaves = self._collect_leaves()
         heapq.heapify(leaves)
@@ -80,7 +112,7 @@ class RadixCache:
 
             if x == self.root_node:
                 break
-            if x.ref_counter > 0:
+            if x.lock_ref > 0:
                 continue
 
             num_evicted += evict_callback(x.value)
@@ -89,23 +121,23 @@ class RadixCache:
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)
 
-    def inc_ref_counter(self, node):
+    def inc_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 0:
+            if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
                 delta -= len(node.value)
-            node.ref_counter += 1
+            node.lock_ref += 1
             node = node.parent
         return delta
 
-    def dec_ref_counter(self, node):
+    def dec_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 1:
+            if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
                 delta += len(node.value)
-            node.ref_counter -= 1
+            node.lock_ref -= 1
             node = node.parent
         return delta
 
@@ -113,42 +145,48 @@ class RadixCache:
         return self.evictable_size_
 
     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
-
-        for c_key, child in node.children.items():
-            prefix_len = match(c_key, key)
-            if prefix_len != 0:
-                if prefix_len < len(c_key):
-                    new_node = self._split_node(c_key, child, prefix_len)
-                    value.append(new_node.value)
-                    last_node[0] = new_node
-                else:
-                    value.append(child.value)
-                    last_node[0] = child
-                    self._match_prefix_helper(child, key[prefix_len:], value, last_node)
-                break
-
-    def _split_node(self, key, child, split_len):
+        if len(key) == 0:
+            return
+
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
+            if prefix_len < len(child.key):
+                new_node = self._split_node(child.key, child, prefix_len)
+                value.append(new_node.value)
+                last_node[0] = new_node
+            else:
+                value.append(child.value)
+                last_node[0] = child
+                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+
+    def _split_node(self, key, child: TreeNode, split_len):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len:]: child}
+        new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
-        new_node.ref_counter = child.ref_counter
+        new_node.lock_ref = child.lock_ref
+        new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
+        child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[:split_len]] = new_node
-        del new_node.parent.children[key]
+        new_node.parent.children[key[:split_len][0]] = new_node
         return new_node
 
     def _insert_helper(self, node, key, value):
         node.last_access_time = time.time()
+        if len(key) == 0:
+            return 0
 
-        for c_key, child in node.children.items():
-            prefix_len = match(c_key, key)
+        if key[0] in node.children.keys():
+            child = node.children[key[0]]
+            prefix_len = _key_match(child.key, key)
 
-            if prefix_len == len(c_key):
+            if prefix_len == len(child.key):
                 if prefix_len == len(key):
                     return prefix_len
                 else:
@@ -156,23 +194,23 @@ class RadixCache:
                     value = value[prefix_len:]
                     return prefix_len + self._insert_helper(child, key, value)
 
-            if prefix_len:
-                new_node = self._split_node(c_key, child, prefix_len)
-                return prefix_len + self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+            new_node = self._split_node(child.key, child, prefix_len)
+            return prefix_len + self._insert_helper(
+                new_node, key[prefix_len:], value[prefix_len:]
+            )
 
         if len(key):
             new_node = TreeNode()
             new_node.parent = node
+            new_node.key = key
             new_node.value = value
-            node.children[key] = new_node
+            node.children[key[0]] = new_node
             self.evictable_size_ += len(value)
         return 0
 
-    def _print_helper(self, node, indent):
-        for key, child in node.children.items():
-            print(" " * indent, len(key), key[:10], f"r={child.ref_counter}")
+    def _print_helper(self, node: TreeNode, indent):
+        for _, child in node.children.items():
+            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)
 
     def _delete_leaf(self, node):
@@ -180,7 +218,7 @@ class RadixCache:
             if v == node:
                 break
         del node.parent.children[k]
-        self.evictable_size_ -= len(k)
+        self.evictable_size_ -= len(node.key)
 
     def _total_size_helper(self, node):
         x = len(node.value)
@@ -203,7 +241,7 @@
 
 
 if __name__ == "__main__":
-    tree = RadixCache(disable=False)
+    tree = RadixCache(None, None, False)
 
     tree.insert("Hello")
     tree.insert("Hello")