sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ import heapq
|
|
23
23
|
import time
|
24
24
|
from collections import defaultdict
|
25
25
|
from functools import partial
|
26
|
-
from typing import TYPE_CHECKING, List, Optional
|
26
|
+
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
|
27
27
|
|
28
28
|
import torch
|
29
29
|
|
@@ -34,12 +34,37 @@ from sglang.srt.disaggregation.kv_events import (
|
|
34
34
|
)
|
35
35
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
36
36
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
37
|
+
from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
|
37
38
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
38
39
|
|
39
40
|
if TYPE_CHECKING:
|
40
41
|
from sglang.srt.managers.schedule_batch import Req
|
41
42
|
|
42
43
|
|
44
|
+
class RadixKey:
    """Key used by the radix tree: a token-id sequence plus an optional tag.

    ``token_ids`` holds the sequence itself; ``extra_key`` is an optional
    string (e.g. lora_id, cache_salt) that is carried along unchanged when
    the key is indexed or sliced.
    """

    def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
        # extra key (e.g. lora_id, cache_salt)
        self.extra_key = extra_key
        # token ids sequence (stored by reference, not copied)
        self.token_ids = token_ids

    def __len__(self) -> int:
        """Return the number of entries in ``token_ids``."""
        return len(self.token_ids)

    def __iter__(self) -> Iterator[int]:
        """Iterate over the entries of ``token_ids``."""
        return iter(self.token_ids)

    def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
        """Return a new ``RadixKey`` for ``idx`` that keeps the same ``extra_key``.

        A slice yields the sliced sequence; an integer index yields a
        one-element key rather than a bare token.
        """
        selected = self.token_ids[idx]
        if not isinstance(idx, slice):
            # Integer index: wrap the single entry so the result is still a sequence.
            selected = [selected]
        return RadixKey(selected, self.extra_key)

    def __repr__(self) -> str:
        """Return a debug string showing at most the first 10 token ids."""
        preview = self.token_ids[:10]
        return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})"
|
66
|
+
|
67
|
+
|
43
68
|
class TreeNode:
|
44
69
|
|
45
70
|
counter = 0
|
@@ -47,7 +72,7 @@ class TreeNode:
|
|
47
72
|
def __init__(self, id: Optional[int] = None):
|
48
73
|
self.children = defaultdict(TreeNode)
|
49
74
|
self.parent: TreeNode = None
|
50
|
-
self.key:
|
75
|
+
self.key: RadixKey = None
|
51
76
|
self.value: Optional[torch.Tensor] = None
|
52
77
|
self.lock_ref = 0
|
53
78
|
self.last_access_time = time.monotonic()
|
@@ -93,27 +118,57 @@ class TreeNode:
|
|
93
118
|
return self.last_access_time < other.last_access_time
|
94
119
|
|
95
120
|
|
96
|
-
def
|
121
|
+
def _check_extra_key(key0: RadixKey, key1: RadixKey):
    """Ensure two keys carry the same ``extra_key`` before they are compared.

    Args:
        key0: First key; only its ``extra_key`` attribute is read.
        key1: Second key; only its ``extra_key`` attribute is read.

    Raises:
        ValueError: If the two ``extra_key`` values are not equal.
    """
    if key0.extra_key == key1.extra_key:
        return
    raise ValueError(
        f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
    )
|
126
|
+
|
127
|
+
|
128
|
+
def _key_match_page_size1(key0: RadixKey, key1: RadixKey):
    """Return the length of the common leading run of two keys' token ids.

    The keys must share the same ``extra_key``; ``_check_extra_key`` raises
    ``ValueError`` otherwise. Comparison is element by element and stops at
    the first mismatch or when the shorter sequence is exhausted.
    """
    _check_extra_key(key0, key1)
    matched = 0
    for left, right in zip(key0.token_ids, key1.token_ids):
        if left != right:
            return matched
        matched += 1
    return matched
|
103
136
|
|
104
137
|
|
105
|
-
def _key_match_paged(key0:
|
138
|
+
def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int):
    """Return how many leading tokens of two keys match, compared page by page.

    The keys must share the same ``extra_key``; ``_check_extra_key`` raises
    ``ValueError`` otherwise. Token ids are compared in consecutive slices of
    ``page_size`` entries, and the count advances by a whole ``page_size``
    for every equal slice. As a consequence, if the last compared slices are
    shorter than a page and still equal, the returned value can exceed the
    length of the shorter key.
    """
    _check_extra_key(key0, key1)
    tokens0, tokens1 = key0.token_ids, key1.token_ids
    limit = min(len(key0), len(key1))

    matched = 0
    while matched < limit:
        page_end = matched + page_size
        if tokens0[matched:page_end] != tokens1[matched:page_end]:
            break
        matched = page_end

    return matched
|
115
149
|
|
116
150
|
|
151
|
+
def get_child_key(key: RadixKey, page_size: int = 1):
    """Build the value used to index a node's children for ``key``.

    Args:
        key: Key whose leading token ids (and ``extra_key``) are read. It
            must contain at least one token when ``page_size`` is 1.
        page_size: Number of leading token ids that form the child key.

    Returns:
        With ``page_size == 1`` the first token id; otherwise a tuple of the
        first ``page_size`` token ids. When ``key.extra_key`` is not ``None``
        the result is wrapped as ``(extra_key, <that value>)``.
    """
    tokens = key.token_ids
    head = tokens[0] if page_size == 1 else tuple(tokens[:page_size])
    tag = key.extra_key
    if tag is not None:
        return (tag, head)
    return head
|
160
|
+
|
161
|
+
|
162
|
+
def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
    """Convert a token-id sequence into its list of overlapping bigrams.

    EAGLE uses bigram keys in the radix tree since the draft sequence is the
    one-token-shifted version of the target, e.g.
    ``[1, 2, 3, 4] -> [(1, 2), (2, 3), (3, 4)]``.

    The conversion is idempotent: a sequence whose first element is already
    a tuple is treated as converted and returned as-is (same object).

    Args:
        tokens: Plain token ids, or an already-converted list of bigrams.

    Returns:
        The list of adjacent pairs; empty when fewer than two plain tokens
        are given.
    """
    # The "already converted" check must come before any length check: a
    # converted key holding a single bigram has length 1 and would otherwise
    # be mistaken for a too-short token list and dropped.
    if len(tokens) > 0 and isinstance(tokens[0], tuple):
        return tokens
    # Pairing the sequence with its one-shifted copy yields [] for 0 or 1 tokens.
    return list(zip(tokens, tokens[1:]))
|
170
|
+
|
171
|
+
|
117
172
|
class RadixCache(BasePrefixCache):
|
118
173
|
def __init__(
|
119
174
|
self,
|
@@ -122,6 +177,8 @@ class RadixCache(BasePrefixCache):
|
|
122
177
|
page_size: int,
|
123
178
|
disable: bool = False,
|
124
179
|
enable_kv_cache_events: bool = False,
|
180
|
+
eviction_policy: str = "lru",
|
181
|
+
is_eagle: bool = False,
|
125
182
|
):
|
126
183
|
self.req_to_token_pool = req_to_token_pool
|
127
184
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
@@ -129,6 +186,7 @@ class RadixCache(BasePrefixCache):
|
|
129
186
|
self.disable = disable
|
130
187
|
self.enable_kv_cache_events = enable_kv_cache_events
|
131
188
|
self.kv_event_queue = []
|
189
|
+
self.is_eagle = is_eagle
|
132
190
|
|
133
191
|
if self.token_to_kv_pool_allocator:
|
134
192
|
self.device = self.token_to_kv_pool_allocator.device
|
@@ -137,17 +195,31 @@ class RadixCache(BasePrefixCache):
|
|
137
195
|
|
138
196
|
if self.page_size == 1:
|
139
197
|
self.key_match_fn = _key_match_page_size1
|
140
|
-
self.get_child_key_fn =
|
198
|
+
self.get_child_key_fn = get_child_key
|
141
199
|
else:
|
142
200
|
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
143
|
-
self.get_child_key_fn =
|
201
|
+
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
202
|
+
|
203
|
+
if is_eagle:
|
204
|
+
self.key_convert_fn = _convert_to_bigram_key
|
205
|
+
else:
|
206
|
+
self.key_convert_fn = lambda key: key
|
207
|
+
|
208
|
+
if eviction_policy.lower() == "lru":
|
209
|
+
self.eviction_strategy: EvictionStrategy = LRUStrategy()
|
210
|
+
elif eviction_policy.lower() == "lfu":
|
211
|
+
self.eviction_strategy: EvictionStrategy = LFUStrategy()
|
212
|
+
else:
|
213
|
+
raise ValueError(
|
214
|
+
f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
|
215
|
+
)
|
144
216
|
self.reset()
|
145
217
|
|
146
218
|
##### Public API #####
|
147
219
|
|
148
220
|
def reset(self):
|
149
221
|
self.root_node = TreeNode()
|
150
|
-
self.root_node.key = []
|
222
|
+
self.root_node.key = RadixKey(token_ids=[], extra_key=None)
|
151
223
|
self.root_node.value = []
|
152
224
|
self.root_node.host_value = []
|
153
225
|
self.root_node.lock_ref = 1
|
@@ -155,18 +227,47 @@ class RadixCache(BasePrefixCache):
|
|
155
227
|
self.protected_size_ = 0
|
156
228
|
self._record_all_cleared_event()
|
157
229
|
|
158
|
-
def match_prefix(self, key:
|
159
|
-
"""Find the
|
230
|
+
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
231
|
+
"""Find the longest cached prefix of ``key`` in the radix tree.
|
232
|
+
|
233
|
+
The logical namespace for prefix matching is determined by both the
|
234
|
+
token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
|
235
|
+
Entries that share identical leading token ids but have *different*
|
236
|
+
``extra_key`` values are intentionally kept disjoint and never share
|
237
|
+
prefix nodes. This is useful to:
|
238
|
+
|
239
|
+
* Isolate KV cache lines for different LoRA / adapter IDs.
|
240
|
+
* Separate requests that intentionally should not share state (e.g.,
|
241
|
+
different sampling salt, cache version, or retrieval augmentation
|
242
|
+
context) by supplying a distinct ``extra_key``.
|
243
|
+
|
160
244
|
Args:
|
161
|
-
key:
|
245
|
+
key (RadixKey): The lookup key containing a list of token ids and an
|
246
|
+
optional ``extra_key`` namespace tag. If ``page_size > 1`` the
|
247
|
+
length is internally truncated to a multiple of ``page_size``
|
248
|
+
before matching. Passing an empty key returns an empty result
|
249
|
+
with the root as the last node.
|
250
|
+
**kwargs: Reserved for future extensions (ignored currently).
|
251
|
+
|
162
252
|
Returns:
|
163
|
-
|
164
|
-
the
|
165
|
-
|
166
|
-
|
167
|
-
|
253
|
+
MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
|
254
|
+
the concatenated KV cache indices corresponding to the longest
|
255
|
+
cached prefix (may be length 0). ``last_device_node`` and
|
256
|
+
``last_host_node`` (currently the same) are the tree node objects
|
257
|
+
representing the terminal node of the matched prefix. This method
|
258
|
+
may mutate internal structure by splitting an existing node if the
|
259
|
+
match ends inside a stored segment.
|
260
|
+
|
261
|
+
Internal updates:
|
262
|
+
* Refreshes access metadata (timestamps) used by the
|
263
|
+
configured eviction strategy.
|
264
|
+
* If the lookup ends inside a stored segment the node is split once
|
265
|
+
to expose a precise boundary; this structural refinement improves
|
266
|
+
subsequent match efficiency and does not duplicate data.
|
168
267
|
"""
|
169
|
-
|
268
|
+
key.token_ids = self.key_convert_fn(key.token_ids)
|
269
|
+
|
270
|
+
def empty_match_result():
|
170
271
|
return MatchResult(
|
171
272
|
device_indices=torch.empty(
|
172
273
|
(0,),
|
@@ -177,10 +278,16 @@ class RadixCache(BasePrefixCache):
|
|
177
278
|
last_host_node=self.root_node,
|
178
279
|
)
|
179
280
|
|
281
|
+
if self.disable or len(key) == 0:
|
282
|
+
return empty_match_result()
|
283
|
+
|
180
284
|
if self.page_size != 1:
|
181
285
|
page_aligned_len = len(key) // self.page_size * self.page_size
|
182
286
|
key = key[:page_aligned_len]
|
183
287
|
|
288
|
+
if len(key) == 0:
|
289
|
+
return empty_match_result()
|
290
|
+
|
184
291
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
185
292
|
if value:
|
186
293
|
value = torch.cat(value)
|
@@ -192,12 +299,19 @@ class RadixCache(BasePrefixCache):
|
|
192
299
|
last_host_node=last_node,
|
193
300
|
)
|
194
301
|
|
195
|
-
def insert(self, key:
|
302
|
+
def insert(self, key: RadixKey, value=None, chunked=False):
    """Insert ``key`` (with its associated ``value``) under the root node.

    Args:
        key: Key to insert. Its ``token_ids`` attribute is overwritten in
            place with the output of ``self.key_convert_fn``.
        value: Values stored alongside the key. When ``None``, a
            ``torch.int64`` tensor built from the converted token ids is used.
        chunked: Accepted but not read in this method body.

    Returns:
        ``0`` when the cache is disabled; otherwise whatever
        ``self._insert_helper`` returns (callers in this file treat it as a
        prefix length -- NOTE(review): confirm against ``_insert_helper``).
    """
    if self.disable:
        return 0

    # Convert first: the default value and the EAGLE truncation below both
    # depend on the converted key.
    converted_ids = self.key_convert_fn(key.token_ids)
    key.token_ids = converted_ids

    if value is None:
        value = torch.tensor(converted_ids, dtype=torch.int64)

    if self.is_eagle:
        # Make sure the value len equal to the EAGLE bigram key len
        value = value[: len(key)]

    return self._insert_helper(self.root_node, key, value)
|
202
316
|
|
203
317
|
def cache_finished_req(self, req: Req):
|
@@ -211,27 +325,39 @@ class RadixCache(BasePrefixCache):
|
|
211
325
|
return
|
212
326
|
|
213
327
|
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
328
|
+
all_token_len = len(token_ids)
|
329
|
+
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
214
330
|
kv_indices = self.req_to_token_pool.req_to_token[
|
215
|
-
req.req_pool_idx, :
|
331
|
+
req.req_pool_idx, :all_token_len
|
216
332
|
]
|
217
333
|
|
218
334
|
if self.page_size != 1:
|
219
|
-
page_aligned_len =
|
335
|
+
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
220
336
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
221
337
|
dtype=torch.int64, copy=True
|
222
338
|
)
|
223
339
|
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
224
340
|
else:
|
225
|
-
page_aligned_len =
|
341
|
+
page_aligned_len = actual_kv_len
|
226
342
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
343
|
+
if self.is_eagle:
|
344
|
+
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
345
|
+
|
346
|
+
page_aligned_token_len = (
|
347
|
+
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
348
|
+
)
|
349
|
+
|
350
|
+
old_prefix_len = len(req.prefix_indices)
|
351
|
+
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
352
|
+
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
353
|
+
old_prefix_len -= 1
|
227
354
|
|
228
355
|
# Radix Cache takes one ref in memory pool
|
229
356
|
new_prefix_len = self.insert(
|
230
|
-
token_ids[:
|
231
|
-
|
232
|
-
self.token_to_kv_pool_allocator.free(
|
233
|
-
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
357
|
+
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
358
|
+
page_aligned_kv_indices,
|
234
359
|
)
|
360
|
+
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
235
361
|
|
236
362
|
# Remove req slot release the cache lock
|
237
363
|
self.req_to_token_pool.free(req.req_pool_idx)
|
@@ -243,45 +369,73 @@ class RadixCache(BasePrefixCache):
|
|
243
369
|
return
|
244
370
|
|
245
371
|
token_ids = req.fill_ids
|
372
|
+
all_token_len = len(token_ids)
|
373
|
+
# The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
|
374
|
+
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
246
375
|
kv_indices = self.req_to_token_pool.req_to_token[
|
247
|
-
req.req_pool_idx, :
|
376
|
+
req.req_pool_idx, :all_token_len
|
248
377
|
]
|
249
378
|
|
250
379
|
if self.page_size != 1:
|
251
|
-
page_aligned_len =
|
380
|
+
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
252
381
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
253
382
|
dtype=torch.int64, copy=True
|
254
383
|
)
|
255
384
|
else:
|
256
|
-
page_aligned_len =
|
385
|
+
page_aligned_len = actual_kv_len
|
257
386
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
258
|
-
|
387
|
+
|
388
|
+
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
|
389
|
+
page_aligned_token_len = (
|
390
|
+
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
391
|
+
)
|
392
|
+
page_aligned_token_ids = token_ids[:page_aligned_token_len]
|
393
|
+
|
394
|
+
old_prefix_len = len(req.prefix_indices)
|
395
|
+
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
396
|
+
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
397
|
+
old_prefix_len -= 1
|
259
398
|
|
260
399
|
# Radix Cache takes one ref in memory pool
|
261
400
|
new_prefix_len = self.insert(
|
262
|
-
page_aligned_token_ids,
|
263
|
-
|
264
|
-
|
265
|
-
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
401
|
+
RadixKey(page_aligned_token_ids, req.extra_key),
|
402
|
+
page_aligned_kv_indices,
|
403
|
+
chunked=chunked,
|
266
404
|
)
|
405
|
+
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
267
406
|
|
268
407
|
# The prefix indices could be updated, reuse it
|
269
|
-
new_indices, new_last_node, _, _ = self.match_prefix(
|
408
|
+
new_indices, new_last_node, _, _ = self.match_prefix(
|
409
|
+
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
410
|
+
)
|
270
411
|
self.req_to_token_pool.write(
|
271
|
-
(req.req_pool_idx, slice(
|
272
|
-
new_indices[
|
412
|
+
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
413
|
+
new_indices[old_prefix_len:],
|
273
414
|
)
|
274
415
|
|
416
|
+
# The last_matched_prefix_len is not always equal to len(req.prefix_indices)
|
417
|
+
# since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
|
418
|
+
# It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
|
419
|
+
# So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
|
420
|
+
req.last_matched_prefix_len = len(new_indices)
|
421
|
+
|
275
422
|
self.dec_lock_ref(req.last_node)
|
276
423
|
self.inc_lock_ref(new_last_node)
|
277
424
|
|
278
425
|
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
279
426
|
if self.page_size != 1:
|
427
|
+
# Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
|
280
428
|
req.prefix_indices = torch.cat(
|
281
429
|
[new_indices, kv_indices[len(new_indices) :]]
|
282
430
|
)
|
283
431
|
else:
|
284
|
-
|
432
|
+
if self.is_eagle:
|
433
|
+
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
|
434
|
+
req.prefix_indices = torch.cat(
|
435
|
+
[new_indices, kv_indices[actual_kv_len:]]
|
436
|
+
)
|
437
|
+
else:
|
438
|
+
req.prefix_indices = new_indices
|
285
439
|
req.last_node = new_last_node
|
286
440
|
|
287
441
|
def pretty_print(self):
|
@@ -296,11 +450,14 @@ class RadixCache(BasePrefixCache):
|
|
296
450
|
return
|
297
451
|
|
298
452
|
leaves = self._collect_leaves()
|
299
|
-
|
453
|
+
eviction_heap = [
|
454
|
+
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
455
|
+
]
|
456
|
+
heapq.heapify(eviction_heap)
|
300
457
|
|
301
458
|
num_evicted = 0
|
302
|
-
while num_evicted < num_tokens and len(
|
303
|
-
x = heapq.heappop(
|
459
|
+
while num_evicted < num_tokens and len(eviction_heap):
|
460
|
+
_priority, x = heapq.heappop(eviction_heap)
|
304
461
|
|
305
462
|
if x == self.root_node:
|
306
463
|
break
|
@@ -312,7 +469,8 @@ class RadixCache(BasePrefixCache):
|
|
312
469
|
self._delete_leaf(x)
|
313
470
|
|
314
471
|
if len(x.parent.children) == 0:
|
315
|
-
|
472
|
+
new_priority = self.eviction_strategy.get_priority(x.parent)
|
473
|
+
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
316
474
|
|
317
475
|
self._record_remove_event(x)
|
318
476
|
|
@@ -323,9 +481,9 @@ class RadixCache(BasePrefixCache):
|
|
323
481
|
delta = 0
|
324
482
|
while node != self.root_node:
|
325
483
|
if node.lock_ref == 0:
|
326
|
-
self.evictable_size_ -= len(node.
|
327
|
-
self.protected_size_ += len(node.
|
328
|
-
delta -= len(node.
|
484
|
+
self.evictable_size_ -= len(node.key)
|
485
|
+
self.protected_size_ += len(node.key)
|
486
|
+
delta -= len(node.key)
|
329
487
|
node.lock_ref += 1
|
330
488
|
node = node.parent
|
331
489
|
return delta
|
@@ -337,9 +495,9 @@ class RadixCache(BasePrefixCache):
|
|
337
495
|
delta = 0
|
338
496
|
while node != self.root_node:
|
339
497
|
if node.lock_ref == 1:
|
340
|
-
self.evictable_size_ += len(node.
|
341
|
-
self.protected_size_ -= len(node.
|
342
|
-
delta += len(node.
|
498
|
+
self.evictable_size_ += len(node.key)
|
499
|
+
self.protected_size_ -= len(node.key)
|
500
|
+
delta += len(node.key)
|
343
501
|
node.lock_ref -= 1
|
344
502
|
node = node.parent
|
345
503
|
return delta
|
@@ -364,7 +522,7 @@ class RadixCache(BasePrefixCache):
|
|
364
522
|
|
365
523
|
##### Internal Helper Functions #####
|
366
524
|
|
367
|
-
def _match_prefix_helper(self, node: TreeNode, key:
|
525
|
+
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
368
526
|
node.last_access_time = time.monotonic()
|
369
527
|
|
370
528
|
child_key = self.get_child_key_fn(key)
|
@@ -389,7 +547,7 @@ class RadixCache(BasePrefixCache):
|
|
389
547
|
|
390
548
|
return value, node
|
391
549
|
|
392
|
-
def _split_node(self, key, child: TreeNode, split_len: int):
|
550
|
+
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
393
551
|
# new_node -> child
|
394
552
|
self._record_remove_event(child)
|
395
553
|
new_node = TreeNode()
|
@@ -408,7 +566,7 @@ class RadixCache(BasePrefixCache):
|
|
408
566
|
|
409
567
|
return new_node
|
410
568
|
|
411
|
-
def _insert_helper(self, node: TreeNode, key:
|
569
|
+
def _insert_helper(self, node: TreeNode, key: RadixKey, value):
|
412
570
|
node.last_access_time = time.monotonic()
|
413
571
|
if len(key) == 0:
|
414
572
|
return 0
|
@@ -437,7 +595,7 @@ class RadixCache(BasePrefixCache):
|
|
437
595
|
new_node.key = key
|
438
596
|
new_node.value = value
|
439
597
|
node.children[child_key] = new_node
|
440
|
-
self.evictable_size_ += len(
|
598
|
+
self.evictable_size_ += len(key)
|
441
599
|
self._record_store_event(new_node)
|
442
600
|
return total_prefix_length
|
443
601
|
|
@@ -449,7 +607,7 @@ class RadixCache(BasePrefixCache):
|
|
449
607
|
print(
|
450
608
|
" " * current_indent,
|
451
609
|
len(current_node.key),
|
452
|
-
current_node.key[:10],
|
610
|
+
current_node.key.token_ids[:10],
|
453
611
|
f"r={current_node.lock_ref}",
|
454
612
|
)
|
455
613
|
for key, child in current_node.children.items():
|
@@ -501,11 +659,11 @@ class RadixCache(BasePrefixCache):
|
|
501
659
|
last_page_start = (
|
502
660
|
(len(node.parent.key) - 1) // self.page_size
|
503
661
|
) * self.page_size
|
504
|
-
parent_parent_tokens = node.parent.key[last_page_start:]
|
662
|
+
parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
|
505
663
|
parent_block_hash = hash(tuple(parent_parent_tokens))
|
506
664
|
|
507
665
|
for start in range(0, len(node.key), self.page_size):
|
508
|
-
page_tokens = node.key[start : start + self.page_size]
|
666
|
+
page_tokens = node.key.token_ids[start : start + self.page_size]
|
509
667
|
if not page_tokens:
|
510
668
|
continue
|
511
669
|
|
@@ -528,7 +686,7 @@ class RadixCache(BasePrefixCache):
|
|
528
686
|
# One BlockRemoved per chunk.
|
529
687
|
if self.enable_kv_cache_events:
|
530
688
|
for start in range(0, len(node.key), self.page_size):
|
531
|
-
page_tokens = node.key[start : start + self.page_size]
|
689
|
+
page_tokens = node.key.token_ids[start : start + self.page_size]
|
532
690
|
if not page_tokens:
|
533
691
|
continue
|
534
692
|
block_hash = hash(tuple(page_tokens))
|
@@ -554,19 +712,12 @@ class RadixCache(BasePrefixCache):
|
|
554
712
|
if __name__ == "__main__":
|
555
713
|
tree = RadixCache(None, None, page_size=1, disable=False)
|
556
714
|
|
557
|
-
|
558
|
-
tree.insert(
|
559
|
-
tree.insert(
|
560
|
-
|
561
|
-
|
715
|
+
# Example token id sequences (as lists of ints)
|
716
|
+
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
717
|
+
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
718
|
+
tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None))
|
719
|
+
tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None))
|
720
|
+
tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
|
562
721
|
tree.pretty_print()
|
563
722
|
|
564
|
-
|
565
|
-
|
566
|
-
# def evict_callback(x):
|
567
|
-
# print("evict", x)
|
568
|
-
# return len(x)
|
569
|
-
|
570
|
-
# tree.evict(5, evict_callback)
|
571
|
-
# tree.evict(10, evict_callback)
|
572
|
-
# tree.pretty_print()
|
723
|
+
print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
|
@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
|
|
13
13
|
TreeNodeCpp,
|
14
14
|
)
|
15
15
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
16
|
+
from sglang.srt.mem_cache.radix_cache import RadixKey
|
16
17
|
|
17
18
|
if TYPE_CHECKING:
|
18
19
|
from sglang.srt.managers.schedule_batch import Req
|
@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
|
|
93
94
|
raise NotImplementedError("Host cache is not supported yet")
|
94
95
|
self.tree.reset()
|
95
96
|
|
96
|
-
def match_prefix(self, key:
|
97
|
+
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
97
98
|
device_indices_vec, host_indices_length, node_gpu, node_cpu = (
|
98
|
-
self.tree.match_prefix(key)
|
99
|
+
self.tree.match_prefix(key.token_ids)
|
99
100
|
)
|
100
101
|
return MatchResult(
|
101
102
|
device_indices=self._merge_tensor(device_indices_vec),
|
@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
|
|
104
105
|
host_hit_length=host_indices_length,
|
105
106
|
)
|
106
107
|
|
107
|
-
def _insert(self, key:
|
108
|
+
def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
|
108
109
|
"""
|
109
110
|
Insert a key-value pair into the radix tree.
|
110
111
|
Args:
|
111
|
-
key (
|
112
|
+
key (RadixKey): The key to insert, represented as a RadixKey.
|
112
113
|
value (torch.Tensor): The value to associate with the key.
|
113
114
|
Returns:
|
114
115
|
int: Number of device indices that were already present in the tree before the insertion.
|
115
116
|
"""
|
116
|
-
ongoing_write, length = self.tree.writing_through(key, value)
|
117
|
+
ongoing_write, length = self.tree.writing_through(key.token_ids, value)
|
117
118
|
if self.cache_controller is None:
|
118
119
|
assert len(ongoing_write) == 0, "Implementation error"
|
119
120
|
return length
|
@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache):
|
|
160
161
|
# NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned
|
161
162
|
# it will automatically align them, but their lengths should be equal
|
162
163
|
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
163
|
-
new_prefix_len = self._insert(token_ids, kv_indices)
|
164
|
+
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
164
165
|
|
165
166
|
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
166
167
|
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache):
|
|
191
192
|
# NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned
|
192
193
|
# it will automatically align them, but their lengths should be equal
|
193
194
|
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
194
|
-
new_prefix_len = self._insert(token_ids, kv_indices)
|
195
|
+
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
195
196
|
|
196
197
|
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
197
198
|
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
198
199
|
|
199
200
|
# TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
|
200
201
|
# The prefix indices need to be updated to reuse the kv indices in the pool
|
201
|
-
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
|
202
|
+
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
|
203
|
+
RadixKey(token_ids, req.extra_key).token_ids
|
204
|
+
)
|
202
205
|
new_indices = self._merge_tensor(new_indices_vec)
|
203
206
|
assert new_prefix_len <= len(new_indices)
|
204
207
|
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to SGLang project
|
3
|
+
|
4
|
+
"""Storage backend module for SGLang HiCache."""
|
5
|
+
|
6
|
+
from .backend_factory import StorageBackendFactory
|
7
|
+
|
8
|
+
__all__ = [
|
9
|
+
"StorageBackendFactory",
|
10
|
+
]
|