sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,993 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Copyright 2023-2024 SGLang Team
|
|
5
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
you may not use this file except in compliance with the License.
|
|
7
|
+
You may obtain a copy of the License at
|
|
8
|
+
|
|
9
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
|
|
11
|
+
Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
See the License for the specific language governing permissions and
|
|
15
|
+
limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
"""
|
|
19
|
+
The radix tree data structure for managing the hybrid (full and Mamba) KV cache.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import heapq
|
|
23
|
+
import time
|
|
24
|
+
from collections import defaultdict
|
|
25
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple
|
|
26
|
+
|
|
27
|
+
import torch
|
|
28
|
+
|
|
29
|
+
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
|
30
|
+
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
|
31
|
+
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
|
|
32
|
+
from sglang.srt.mem_cache.radix_cache import (
|
|
33
|
+
RadixKey,
|
|
34
|
+
_key_match_page_size1,
|
|
35
|
+
get_child_key,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from sglang.srt.managers.schedule_batch import Req
|
|
40
|
+
|
|
41
|
+
import logging
|
|
42
|
+
|
|
43
|
+
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TreeNode:
    """A node of the hybrid (full-attention + Mamba) radix tree.

    Each node stores a segment of token keys together with the device
    indices of the full KV cache (``value``) and, optionally, a Mamba
    state slot (``mamba_value``).  Nodes are also threaded through two
    intrusive doubly-linked LRU lists — one for the full KV cache and
    one for the Mamba states.
    """

    # Class-wide source of monotonically increasing node ids.
    counter = 0

    def __init__(self, id: Optional[int] = None):
        self.children = defaultdict(TreeNode)
        self.parent: TreeNode = None
        self.key: RadixKey = None
        self.value: Optional[torch.Tensor] = None
        self.mamba_value: Optional[torch.Tensor] = None
        # Lock invariant: whenever mamba_lock_ref is held, full_lock_ref
        # is held as well (the converse is not required), hence
        # full_lock_ref >= mamba_lock_ref at all times.  Taking full_lock
        # also requires locking the parent chain; mamba_lock only pins
        # this node itself.
        self.full_lock_ref = 0
        self.mamba_lock_ref = 0
        # Only consulted by sanity checks — the actual LRU order is
        # maintained by the linked lists below.
        self.last_access_time = time.monotonic()

        self.hit_count = 0
        # Host-side indices of the offloaded KV cache, if any.
        self.host_value = None

        # Intrusive LRU links.  Invariants:
        #   1. prev points at a more recently used node
        #   2. next points at a less recently used node
        self.prev = None
        self.next = None
        self.mamba_prev = None
        self.mamba_next = None

        self.id = TreeNode.counter if id is None else id
        TreeNode.counter += 1

    @property
    def evicted(self):
        """True once the device KV indices have been released."""
        return self.value is None

    @property
    def backuped(self):
        """True when a host copy of the KV cache exists."""
        return self.host_value is not None

    def __lt__(self, other: "TreeNode"):
        # Heap ordering used by sanity checks: older access first.
        return self.last_access_time < other.last_access_time
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class LRUList:
|
|
94
|
+
def __init__(self, mamba: bool = False):
    """Doubly-linked LRU list over TreeNodes.

    Depending on ``mamba``, the list threads through either the
    (mamba_prev, mamba_next, mamba_lock_ref) fields or the
    (prev, next, full_lock_ref) fields of each node, so one
    implementation serves both cache flavors via getattr/setattr.
    """
    self.mamba = mamba
    if mamba:
        self.prv = "mamba_prev"
        self.nxt = "mamba_next"
        self.lock_ref = "mamba_lock_ref"
    else:
        self.prv = "prev"
        self.nxt = "next"
        self.lock_ref = "full_lock_ref"
    # Dummy sentinels: head marks the MRU end, tail the LRU end.
    self.head = TreeNode()
    self.tail = TreeNode()
    setattr(self.head, self.nxt, self.tail)  # head.next = tail
    setattr(self.tail, self.prv, self.head)  # tail.prev = head
    # id -> node index used for O(1) membership checks.
    self.cache = {}
|
|
110
|
+
|
|
111
|
+
def _add_node(self, node):
    """Insert ``node`` right after the head sentinel (MRU position)."""
    self._add_node_after(self.head, node)
|
|
114
|
+
|
|
115
|
+
def _add_node_after(self, old_node, new_node):
    """Splice ``new_node`` into the list directly after ``old_node``."""
    # Capture the current successor before rewiring.
    successor = getattr(old_node, self.nxt)       # old_node.next
    setattr(new_node, self.prv, old_node)         # new_node.prev = old_node
    setattr(new_node, self.nxt, successor)        # new_node.next = old_node.next
    setattr(successor, self.prv, new_node)        # old_node.next.prev = new_node
    setattr(old_node, self.nxt, new_node)         # old_node.next = new_node
|
|
125
|
+
|
|
126
|
+
def _remove_node(self, node):
    """Unlink ``node`` from the list (node's own links are left stale)."""
    before = getattr(node, self.prv)
    after = getattr(node, self.nxt)
    setattr(before, self.nxt, after)  # node.prev.next = node.next
    setattr(after, self.prv, before)  # node.next.prev = node.prev
|
|
134
|
+
|
|
135
|
+
def _get_lru(self) -> Optional[TreeNode]:
    """Return the least recently used node, or None if the list is empty."""
    if not self.cache:
        return None
    # The node just before the tail sentinel is the LRU entry.
    return getattr(self.tail, self.prv)
|
|
142
|
+
|
|
143
|
+
def reset_node_mru(self, node):
    """Move a node that is already tracked back to the MRU position."""
    assert node.id in self.cache, f"Resetting node {node.id=} not in lru list"
    # Tombstones (mamba_value is None) must never live in the mamba list.
    is_tombstone = self.mamba and node.mamba_value is None
    assert (
        not is_tombstone
    ), f"Resetting mamba tombstone node in mamba lru list: {node.id=}"
    self._remove_node(node)
    self._add_node(node)
|
|
153
|
+
|
|
154
|
+
def reset_node_and_parents_mru(self, node, root_node):
    """Move ``node`` and its ancestors (up to, but excluding,
    ``root_node``) to the MRU end, keeping each child ahead of its
    parent in the list.
    """
    # Anchor after which the next moved node is inserted; starts at
    # the head so the deepest node ends up most recently used.
    prev_node = self.head
    while node != root_node:
        # Mamba tombstones are not tracked in the mamba list; skip them.
        if not self.mamba or node.mamba_value is not None:
            assert (
                node.id in self.cache
            ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru"
            self._remove_node(node)
            self._add_node_after(prev_node, node)
            prev_node = node
        node = node.parent
|
|
169
|
+
|
|
170
|
+
def insert_mru(self, node):
    """Insert a node that is not yet tracked, as the MRU entry."""
    assert (
        not self.mamba or node.mamba_value is not None
    ), f"Inserting mamba tombstone node in mamba lru list: {node.id=}"
    assert (
        node.id not in self.cache
    ), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}"
    self.cache[node.id] = node
    self._add_node(node)
|
|
182
|
+
|
|
183
|
+
def remove_node(self, node: TreeNode):
    """Drop a tracked node from both the id index and the linked list."""
    assert node.id in self.cache, f"Removing node {node.id=} not in lru list"
    assert (
        not self.mamba or node.mamba_value is not None
    ), f"Removing mamba tombstone node from mamba lru list: {node.id=}"
    del self.cache[node.id]
    self._remove_node(node)
|
|
193
|
+
|
|
194
|
+
def get_lru_no_lock(self) -> Optional[TreeNode]:
    """Return the least recently used node that is not locked, or None."""
    # Walk from the tail sentinel; check_id=False because sentinels
    # are not registered in the cache index.
    return self.get_prev_no_lock(self.tail, check_id=False)
|
|
199
|
+
|
|
200
|
+
def get_leaf_lru_no_lock(self) -> Optional[TreeNode]:
    """Return the least recently used unlocked *leaf* node, or None."""
    # Sentinels are not in the cache index, hence check_id=False.
    return self.get_prev_leaf_no_lock(self.tail, check_id=False)
|
|
205
|
+
|
|
206
|
+
def get_prev_no_lock(
    self, node: TreeNode, check_id: bool = True
) -> Optional[TreeNode]:
    """Walk from ``node`` toward the MRU end and return the first
    unlocked node, or None if only the head sentinel remains.
    """
    if check_id:
        assert (
            node.id in self.cache
        ), f"Getting prev of node {node.id=} not in lru list"
    candidate = getattr(node, self.prv)  # candidate = node.prev
    while getattr(candidate, self.lock_ref) > 0:
        candidate = getattr(candidate, self.prv)  # keep walking toward head
    # Reaching the head means every remaining node is locked.
    return None if candidate == self.head else candidate
|
|
223
|
+
|
|
224
|
+
def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True):
    """Like :meth:`get_prev_no_lock`, but additionally skip nodes that
    still have children (i.e. non-leaves).
    """
    if check_id:
        assert (
            node.id in self.cache
        ), f"Getting prev of node {node.id=} not in lru list"
    candidate = getattr(node, self.prv)  # candidate = node.prev
    while getattr(candidate, self.lock_ref) > 0 or len(candidate.children) > 0:
        candidate = getattr(candidate, self.prv)  # keep walking toward head
    # Reaching the head means no unlocked leaf remains.
    return None if candidate == self.head else candidate
|
|
239
|
+
|
|
240
|
+
def in_list(self, node: Optional[TreeNode]):
    """Return whether ``node`` is currently tracked by this lru list."""
    return False if not node else node.id in self.cache
|
|
247
|
+
|
|
248
|
+
# Note: this is expensive, only use for debug
|
|
249
|
+
def sanity_check_evictable_size(self):
    """Walk the unlocked portion of the list and sum node sizes.

    Note: expensive — for debugging only.
    """
    total = 0
    node = self.get_lru_no_lock()
    while self.in_list(node):
        # Size is measured on whichever value this list flavor tracks.
        total += len(node.mamba_value) if self.mamba else len(node.value)
        node = self.get_prev_no_lock(node)
    return total
|
|
261
|
+
|
|
262
|
+
# Note: this is expensive, only use for debug or idle check
|
|
263
|
+
def sanity_check(self, tree_cache: "MambaRadixCache"):
    """
    Check if the lru list is valid by rebuilding the lru list from the tree, heapifying it, and
    checking if the lru list is valid.

    Strategy: collect the relevant tree nodes, heap-order them by
    `last_access_time` (TreeNode comparison — presumably `__lt__` on
    last_access_time; defined outside this chunk), then pop them one by
    one and verify they appear in exactly the same order as walking the
    LRU list from its least-recently-used end. Finally cross-check the
    O(1) evictable-size counter against an O(n) recomputation.
    """
    try:
        # A mamba LRU list only tracks non-tombstone nodes (mamba_value
        # present); the full list tracks every node except the root.
        if self.mamba:
            nodes = tree_cache._collect_nontombstone_nodes()
        else:
            nodes = tree_cache._collect_all_nodes()
        total_nodes = len(nodes)
        total_lru = len(self.cache)
        # heapify based on last_access_time
        heapq.heapify(nodes)
        # the root node is not in the lru list, hence the +1 for the
        # non-mamba case (where _collect_all_nodes includes the root)
        assert len(nodes) == (
            total_lru + (0 if self.mamba else 1)
        ), f"len(nodes): {len(nodes)}, total_lru: {total_lru}"

        # Walk the LRU list in lockstep with the heap pops; every popped
        # node must match the current LRU cursor and be fully unlocked
        # (this check is only meant to run when the engine is idle).
        x_lru = self._get_lru()
        while len(nodes):
            x = heapq.heappop(nodes)
            if x == tree_cache.root_node:
                # root node is not in the lru list
                continue
            assert (
                x == x_lru
            ), f"Incorrect LRU list, {self.mamba=}, x: {x.id=} != x_lru: {x_lru.id=}"
            assert (
                x_lru.full_lock_ref == 0
            ), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.id=}"
            assert (
                x_lru.mamba_lock_ref == 0
            ), f"x_lru should not be locked when idle, {x_lru.mamba_lock_ref=}, {x_lru.id=}"
            # Advance towards the more-recently-used end.
            x_lru = getattr(x, self.prv)

        # Cross-check the incrementally maintained evictable-size counter
        # against the expensive recomputation over the LRU list.
        if self.mamba:
            evictable_size = tree_cache.mamba_evictable_size()
            lru_list_evictable_size = tree_cache.mamba_lru_list_evictable_size()
        else:
            evictable_size = tree_cache.full_evictable_size()
            lru_list_evictable_size = tree_cache.full_lru_list_evictable_size()

        assert (
            evictable_size == lru_list_evictable_size
        ), f"{self.mamba=}, total nodes: {total_nodes}, total lru: {total_lru}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}"
    except Exception as e:
        # Re-raise with a ping so failures in production logs are routed
        # to the component owner.
        msg = f"Mamba Radix tree sanity check failed, ping @yizhang2077: {e}"
        logger.error(msg)
        raise Exception(msg)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class MambaRadixCache(BasePrefixCache):
|
|
316
|
+
def __init__(
    self,
    req_to_token_pool: HybridReqToTokenPool,
    token_to_kv_pool_allocator: TokenToKVPoolAllocator,
    page_size: int,
    disable: bool = False,
):
    """Build a mamba-aware radix cache on top of the given memory pools.

    Args:
        req_to_token_pool: request-to-token mapping pool (also owns the
            mamba state pool).
        token_to_kv_pool_allocator: allocator for full-attention kv slots.
        page_size: must be 1; paging is not supported here yet.
        disable: when True every public API degrades to a no-op cache.
    """
    assert isinstance(token_to_kv_pool_allocator, TokenToKVPoolAllocator)
    self.req_to_token_pool = req_to_token_pool
    self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

    assert page_size == 1, "Only support page_size=1 in mamba radix cache now."
    self.page_size = page_size
    self.disable = disable

    # Fall back to CPU when no allocator provides a device.
    self.device = (
        self.token_to_kv_pool_allocator.device
        if self.token_to_kv_pool_allocator
        else torch.device("cpu")
    )

    self.key_match_fn = _key_match_page_size1
    self.get_child_key_fn = get_child_key
    self.reset()
|
|
339
|
+
|
|
340
|
+
##### Public API #####
|
|
341
|
+
|
|
342
|
+
def reset(self) -> None:
    """Reinitialize the tree to a single permanently locked root node."""
    root = TreeNode()
    root.key = []
    root.value = []
    # The root holds one permanent lock so it can never be evicted.
    root.full_lock_ref = 1
    root.mamba_lock_ref = 1
    self.root_node = root

    # Incremental accounting of evictable vs. locked (protected) sizes.
    self.full_evictable_size_ = 0
    self.mamba_evictable_size_ = 0
    self.full_protected_size_ = 0
    self.mamba_protected_size_ = 0

    # LRU lists are used to maintain the order of eviction of the nodes in the tree.
    self.full_lru_list = LRUList(mamba=False)
    self.mamba_lru_list = LRUList(mamba=True)
|
|
355
|
+
|
|
356
|
+
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
    """Find the matching prefix from the radix tree.

    Args:
        key: A RadixKey contains token IDs to find a matching prefix.

    Keyword Args:
        cow_mamba: when True, copy the matched node's mamba state into the
            request's own slot (copy-on-write) so the tree's state is not
            mutated by the request.
        req: the request whose mamba slot receives the copy; only read
            when ``cow_mamba`` is set and a mamba state was matched.

    Returns:
        A tuple of a tensor of matching prefix token IDs and
        the last node that contains the prefix values. Note that
        this API can modify the internal state of the Radix tree.
        The last node create a new child if the prefix is shorter
        than the last node's value.
    """
    cow_mamba: bool = kwargs.get("cow_mamba", False)
    req: Req = kwargs.get("req", None)

    # Disabled cache or empty key: report an empty match rooted at root.
    if self.disable or len(key) == 0:
        return MatchResult(
            device_indices=torch.empty(
                (0,),
                dtype=torch.int64,
                device=self.device,
            ),
            last_device_node=self.root_node,
            last_host_node=self.root_node,
        )

    value, last_node = self._match_prefix_helper(key)

    # copy mamba state to req local space if cow is true
    if cow_mamba and last_node.mamba_value is not None:
        # req_pool_idx must still be unset at this point (request has not
        # been admitted to the req-to-token pool yet)
        assert req.req_pool_idx is None

        # for reqs without mamba cache: allocate a fresh slot first
        if req.mamba_pool_idx is None:
            dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
            # alloc failed: evict one mamba entry and try again, locking
            # last_node first so the eviction cannot free the very state
            # we are about to copy
            if dst_index is None:
                self.inc_lock_ref(last_node)
                self.evict_mamba(1)
                dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
                self.dec_lock_ref(last_node)
            assert dst_index is not None, "Can not alloc mamba cache"
            src_index = last_node.mamba_value
            self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
            req.mamba_pool_idx = dst_index[0]
        else:
            # Request already owns a mamba slot: overwrite it in place.
            src_index = last_node.mamba_value
            dst_index = req.mamba_pool_idx.unsqueeze(0)
            self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)

    # Flatten the per-node value segments into a single index tensor.
    if value:
        value = torch.cat(value)
    else:
        value = torch.empty((0,), dtype=torch.int64, device=self.device)

    return MatchResult(
        device_indices=value,
        last_device_node=last_node,
        last_host_node=last_node,
    )
|
|
415
|
+
|
|
416
|
+
def insert(self, key: RadixKey, value=None, mamba_value=None) -> Tuple[int, bool]:
    """Insert a key (and its kv indices / mamba state) into the radix tree.

    Args:
        key: token-ID key to insert.
        value: kv-cache indices for the key; when None, a tensor is built
            from the token ids themselves.
        mamba_value: mamba-state index to attach to the final node.

    Returns:
        A tuple of (length of the prefix already present in the tree,
        whether a mamba value already existed on the final node).
    """
    if self.disable:
        # Bug fix: the annotated return type is Tuple[int, bool] and
        # callers unpack two values (``new_prefix_len, mamba_exist = ...``);
        # the previous bare ``return 0`` would raise a TypeError whenever
        # the cache is disabled and this path was reached.
        return 0, False

    if value is None:
        value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
    return self._insert_helper(self.root_node, key, value, mamba_value)
|
|
423
|
+
|
|
424
|
+
def cache_finished_req(self, req: Req) -> None:
    """Cache request when it finishes.

    Moves the request's kv indices and cloned mamba state into the radix
    tree, frees the portion of kv cache that overlapped an existing
    prefix, then releases the request's pool slot and its lock on the
    previously matched node.
    """
    if self.disable:
        # No tree: just return every kv slot the request occupied.
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx,
            : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
        ]
        self.token_to_kv_pool_allocator.free(kv_indices)
        self.req_to_token_pool.free(req.req_pool_idx)
        return

    # Drop the last output token: it has no kv entry yet.
    token_ids = (req.origin_input_ids + req.output_ids)[:-1]
    kv_indices = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, : len(token_ids)
    ]

    # page_size == 1, so the aligned length is just the full length.
    page_aligned_len = len(kv_indices)
    page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

    # Radix Cache takes one ref in memory pool
    # insert the token_ids and kv_indices into the radix tree
    # Note: the insert function already frees the overlapped kv_indices
    mamba_value = (
        self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
        .unsqueeze(-1)
        .clone()
    )

    new_prefix_len, mamba_exist = self.insert(
        RadixKey(token_ids[:page_aligned_len], req.extra_key),
        page_aligned_kv_indices,
        mamba_value,
    )

    # Free the kv slots that duplicate an already-cached prefix.
    self.token_to_kv_pool_allocator.free(
        kv_indices[len(req.prefix_indices) : new_prefix_len]
    )

    # If the tree already had a mamba state for this key, the request's
    # copy is redundant and the pool slot can be reclaimed.
    self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist)
    self.dec_lock_ref(req.last_node)
|
|
464
|
+
|
|
465
|
+
def cache_unfinished_req(self, req: Req, chunked=False) -> None:
    """Cache request when it is unfinished.

    Inserts the tokens filled so far into the tree (with a forked copy of
    the request's mamba state), re-matches to obtain the canonical prefix
    indices, rewrites the request's token mapping to reuse them, and moves
    the request's node lock from the old matched node to the new one.
    """
    if self.disable:
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(req.fill_ids)
        ]
        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = kv_indices
        return

    token_ids = req.fill_ids
    kv_indices = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, : len(token_ids)
    ]
    # page_size == 1, so no rounding is needed.
    page_aligned_len = len(kv_indices)
    page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
    page_aligned_token_ids = token_ids[:page_aligned_len]

    mamba_value = self.req_to_token_pool.get_mamba_indices(
        req.req_pool_idx
    ).unsqueeze(-1)
    # radix tree mamba value is forked from req space
    mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)

    # if alloc mamba cache failed, do evict and alloc again
    if mamba_value_forked is None:
        self.evict_mamba(1)
        mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(
            mamba_value
        )
    assert mamba_value_forked is not None, "Can not alloc mamba cache"
    new_prefix_len, mamba_exist = self.insert(
        RadixKey(page_aligned_token_ids, req.extra_key),
        page_aligned_kv_indices,
        mamba_value_forked,
    )
    # Free the kv slots that duplicate an already-cached prefix.
    self.token_to_kv_pool_allocator.free(
        kv_indices[len(req.prefix_indices) : new_prefix_len]
    )
    # there is a mamba cache in radix cache, release it
    if mamba_exist:
        self.req_to_token_pool.mamba_pool.free(mamba_value_forked)

    # The prefix indices could be updated, reuse it
    new_indices, new_last_node, _, _ = self.match_prefix(
        RadixKey(page_aligned_token_ids, req.extra_key)
    )

    # Fresh insert: the matched node must carry exactly the state we forked.
    if not mamba_exist:
        assert torch.equal(new_last_node.mamba_value, mamba_value_forked)

    assert len(req.prefix_indices) <= len(
        new_indices
    ), f"{req.prefix_indices=}, {new_indices=}"
    assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"

    # Point the request's token mapping at the canonical (cached) indices.
    self.req_to_token_pool.write(
        (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
        new_indices[len(req.prefix_indices) :],
    )

    # Transfer the protection lock from the old node to the new one.
    self.dec_lock_ref(req.last_node)
    self.inc_lock_ref(new_last_node)

    # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
    req.prefix_indices = new_indices
    req.last_node = new_last_node
|
|
532
|
+
|
|
533
|
+
def pretty_print(self) -> None:
    """Print the whole radix tree followed by its aggregate counters."""
    self._print_helper(self.root_node, 0)
    total_size, total_mamba_size = self._total_size_helper()
    print(f"#full_tokens: {total_size}, #mamba_num: {total_mamba_size}")
|
|
537
|
+
|
|
538
|
+
def total_size(self) -> Tuple[int, int]:
    """Return (total cached full tokens, total cached mamba entries)."""
    return self._total_size_helper()
|
|
540
|
+
|
|
541
|
+
def _evict_leaf_node(
    self, x: TreeNode, is_evict_mamba: bool
) -> Tuple[int, int, TreeNode, TreeNode]:
    """Evict one unlocked leaf: free its kv slots and mamba state, unlink
    it from both LRU lists and the tree, then clean up any tombstone
    ancestors that became childless leaves.

    Returns (full tokens evicted, mamba entries evicted, last deleted
    ancestor, next LRU candidate to continue eviction from).
    """
    assert (
        x.full_lock_ref == 0 and x.mamba_lock_ref == 0
    ), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}"

    # Invariant: leaves are never tombstones, so mamba_value must exist.
    assert x.mamba_value is not None, f"leaf node mamba value is not None, {x.id=}"
    # 1. a leaf node, free full tokens and mamba
    self.token_to_kv_pool_allocator.free(x.value)
    full_num_evicted = len(x.value)
    self.req_to_token_pool.mamba_pool.free(x.mamba_value)
    mamba_num_evicted = len(x.mamba_value)

    # 2. get the next node, update the lru lists
    # The "next" cursor comes from whichever LRU list drives the current
    # eviction loop; it must be fetched BEFORE x is removed.
    if is_evict_mamba:
        x_next = self.mamba_lru_list.get_prev_no_lock(x)
    else:
        x_next = self.full_lru_list.get_prev_leaf_no_lock(x)
    self.full_lru_list.remove_node(x)
    self.mamba_lru_list.remove_node(x)

    # 3. delete the leaf node
    self._delete_leaf(x)

    # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone
    x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x)
    full_num_evicted += leaf_full_num_evicted
    return full_num_evicted, mamba_num_evicted, x, x_next
|
|
570
|
+
|
|
571
|
+
def evict_mamba(self, mamba_num: int) -> None:
    """Evict at least *mamba_num* mamba entries from the tree.

    Walks the mamba LRU list from its least-recently-used end. Internal
    nodes are tombstoned (mamba state freed, full tokens kept); leaves
    are fully evicted via _evict_leaf_node.
    """
    if self.disable or mamba_num <= 0:
        return
    # get the least recently used node that is not locked, doesn't have to be a leaf
    x = self.mamba_lru_list.get_lru_no_lock()
    mamba_num_evicted = 0
    # evict lru leaf nodes until mamba_num_tokens is reached
    while mamba_num_evicted < mamba_num and (self.mamba_lru_list.in_list(x)):
        assert x.mamba_value is not None, f"node has no mamba value, {x.id=}"
        assert (
            len(x.mamba_value) == 1
        ), f"node has abnormal mamba length, {x.id=}, {len(x.mamba_value)=}"
        assert x != self.root_node, f"root node is not evictable, {x.id=}"
        assert x.mamba_lock_ref == 0, f"node is in use by mamba kv indices, {x.id=}"

        if len(x.children) > 0:
            # 1. an internal node, free mamba tokens.
            self.req_to_token_pool.mamba_pool.free(x.mamba_value)
            mamba_num_evicted += len(x.mamba_value)

            # 2. get the next node, update the lru lists
            # (fetch the cursor before removal, same as in _evict_leaf_node)
            x_next = self.mamba_lru_list.get_prev_no_lock(x)
            self.mamba_lru_list.remove_node(x)

            # 3. tombstone the node
            self._tombstone_internal_node(x)
        else:
            # Leaf: evict full tokens and mamba state together.
            _, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x, True)
            mamba_num_evicted += mamba_evicted_delta

        x = x_next
|
|
602
|
+
|
|
603
|
+
def evict(self, full_num_tokens: int) -> None:
    """Evict at least *full_num_tokens* full-attention tokens.

    Repeatedly evicts the least-recently-used unlocked leaf from the full
    LRU list until enough tokens have been reclaimed.
    """
    if self.disable or full_num_tokens <= 0:
        return

    full_num_evicted = 0
    # get the least recently used leaf node that is not locked
    x = self.full_lru_list.get_leaf_lru_no_lock()

    while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x):
        assert (
            x != self.root_node
        ), f"root node should not exist in full lru list, {x.id=}"
        full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x, False)
        full_num_evicted += full_num_evicted_delta

        # if parent has no more children, it is a leaf. It is possible that this node is lru, so
        # we need to get the first leaf node in the lru list
        if len(x.parent.children) == 0:
            x_next = self.full_lru_list.get_leaf_lru_no_lock()

        x = x_next
|
|
624
|
+
|
|
625
|
+
def inc_lock_ref(self, node: TreeNode) -> Optional[int]:
    """Lock *node* and its ancestry against eviction.

    Increments the mamba lock on *node* itself (when it carries a mamba
    value) and the full-token lock on every node from *node* up to, but
    excluding, the root. On each 0 -> 1 transition the corresponding
    size moves from the evictable to the protected accounting.
    """
    if self.disable:
        return None

    # Protect this node's mamba state, if it has one.
    if node.mamba_value is not None:
        if node.mamba_lock_ref == 0:
            delta = len(node.mamba_value)
            self.mamba_evictable_size_ -= delta
            self.mamba_protected_size_ += delta
        node.mamba_lock_ref += 1

    # Protect the full-token path up to (excluding) the root.
    while node != self.root_node:
        assert (
            node.full_lock_ref >= 0
        ), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
        if node.full_lock_ref == 0:
            delta = len(node.value)
            self.full_evictable_size_ -= delta
            self.full_protected_size_ += delta
        node.full_lock_ref += 1
        node = node.parent
    return None
|
|
652
|
+
|
|
653
|
+
def dec_lock_ref(self, node: TreeNode):
    """Release one lock on *node* and its ancestry.

    Mirror image of inc_lock_ref: decrements the mamba lock on *node*
    (when it carries a mamba value) and the full-token lock on every node
    from *node* up to, but excluding, the root. On each 1 -> 0 transition
    the corresponding size moves back from protected to evictable.
    """
    if self.disable:
        return

    if node.mamba_value is not None:
        assert (
            node.mamba_lock_ref > 0
        ), f"dec_lock_ref on node with {node.mamba_lock_ref=}, {node.id=}"
        # Last mamba lock released: state becomes evictable again.
        if node.mamba_lock_ref == 1:
            delta = len(node.mamba_value)
            self.mamba_evictable_size_ += delta
            self.mamba_protected_size_ -= delta
        node.mamba_lock_ref -= 1

    while node != self.root_node:
        assert (
            node.full_lock_ref > 0
        ), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
        # Last full lock released: tokens become evictable again.
        if node.full_lock_ref == 1:
            delta = len(node.value)
            self.full_evictable_size_ += delta
            self.full_protected_size_ -= delta
        node.full_lock_ref -= 1
        node = node.parent
|
|
680
|
+
|
|
681
|
+
def sanity_check(self):
    """Run the debug consistency check on both LRU lists."""
    for lru_list in (self.full_lru_list, self.mamba_lru_list):
        lru_list.sanity_check(self)
|
|
684
|
+
|
|
685
|
+
def evictable_size(self) -> Tuple[int, int]:
    """Unsupported here; use full_evictable_size() / mamba_evictable_size()."""
    raise NotImplementedError
|
|
688
|
+
|
|
689
|
+
def full_evictable_size(self) -> int:
    """Number of full-attention tokens currently eligible for eviction."""
    return self.full_evictable_size_
|
|
691
|
+
|
|
692
|
+
def mamba_evictable_size(self) -> int:
    """Number of mamba entries currently eligible for eviction."""
    return self.mamba_evictable_size_
|
|
694
|
+
|
|
695
|
+
# Note: this is expensive, only use for debug
|
|
696
|
+
def full_lru_list_evictable_size(self) -> int:
    """Debug-only: recompute evictable size by walking the full LRU list."""
    return self.full_lru_list.sanity_check_evictable_size()
|
|
698
|
+
|
|
699
|
+
# Note: this is expensive, only use for debug
|
|
700
|
+
def mamba_lru_list_evictable_size(self) -> int:
    """Debug-only: recompute evictable size by walking the mamba LRU list."""
    return self.mamba_lru_list.sanity_check_evictable_size()
|
|
702
|
+
|
|
703
|
+
def protected_size(self) -> Tuple[int, int]:
    """Unsupported here; use full_protected_size() / mamba_protected_size()."""
    raise NotImplementedError
|
|
706
|
+
|
|
707
|
+
def full_protected_size(self) -> int:
    """Number of full-attention tokens currently locked (non-evictable)."""
    return self.full_protected_size_
|
|
710
|
+
|
|
711
|
+
def mamba_protected_size(self) -> int:
    """Number of mamba entries currently locked (non-evictable)."""
    return self.mamba_protected_size_
|
|
714
|
+
|
|
715
|
+
def all_values_flatten(self) -> torch.Tensor:
    """Concatenate the kv-index tensors of every non-root node.

    Collects values in a pre-order walk (each child's value before its
    own subtree) and flattens them into a single tensor.
    """
    collected = []

    def _walk(cur: TreeNode) -> None:
        for child in cur.children.values():
            collected.append(child.value)
            _walk(child)

    _walk(self.root_node)
    return torch.cat(collected)
|
|
725
|
+
|
|
726
|
+
##### Internal Helper Functions #####
|
|
727
|
+
|
|
728
|
+
def _match_prefix_helper(
    self, key: RadixKey
) -> Tuple[List[torch.Tensor], TreeNode]:
    """
    Mamba prefix matching helper. It factors in the sliding window size such that
    the matched node is guaranteed to either 1. connected to root without mamba tombstone,
    or 2. the number of matching tokens from the matched node to the last mamba tombstone
    node is greater than or equal to the sliding window size.

    Returns the list of per-node value tensors for the usable prefix and
    the deepest matched node that still has a mamba state (the "best"
    node); tokens matched past that node are discarded from the result.
    """
    node = self.root_node
    child_key = self.get_child_key_fn(key)

    value = []
    best_value_len = 0
    best_last_node = node
    while len(key) > 0 and child_key in node.children.keys():
        child = node.children[child_key]
        # update best_value_len and best_last_node if needed
        # (only nodes carrying a mamba state are valid resume points)
        if node.mamba_value is not None:
            best_value_len = len(value)
            best_last_node = node

        prefix_len = self.key_match_fn(child.key, key)
        if prefix_len < len(child.key):
            # Partial match inside the child: split it and stop there.
            new_node = self._split_node(child.key, child, prefix_len)
            value.append(new_node.value)
            node = new_node
            break
        else:
            # Full match of the child's key: descend and continue.
            value.append(child.value)
            node = child
            key = key[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)
    # handle best_value_len and best_last_node, for the case that last node is fully matched
    if node.mamba_value is not None:
        best_value_len = len(value)
        best_last_node = node

    # update time for matched nodes, and make nodes closer to root to be least recently used
    # this allows mamba to evict nodes closer to root first
    self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
    self.mamba_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)

    # This last_access_time is for sanity check, can be deleted after validation in production
    cur_time = time.monotonic()
    while node:
        node.last_access_time = cur_time
        cur_time -= 0.0001
        node = node.parent

    # Truncate to the prefix ending at the best (mamba-carrying) node.
    return value[:best_value_len], best_last_node
|
|
781
|
+
|
|
782
|
+
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
    """Split *child* at *split_len*, inserting a new parent node that owns
    the first ``split_len`` tokens. Returns the new parent.

    Layout after the split: grandparent -> new_node -> child.
    """
    # new_node -> child
    new_node = TreeNode()
    new_node.children = {self.get_child_key_fn(key[split_len:]): child}
    new_node.parent = child.parent
    new_node.mamba_value = None  # mamba cache can not be split
    # The new parent inherits the full-token lock count (it now holds part
    # of the locked path), but starts with no mamba lock since it has no
    # mamba state.
    new_node.full_lock_ref = child.full_lock_ref
    new_node.mamba_lock_ref = 0
    new_node.key = child.key[:split_len]
    new_node.value = child.value[:split_len]

    # child time should be later than parent's time for mamba tombstone
    child.last_access_time = time.monotonic()

    # Remove the child from the LRU lists before relinking; it is
    # re-inserted below so that parent/child MRU ordering stays correct.
    self.full_lru_list.remove_node(child)
    if child.mamba_value is not None:
        self.mamba_lru_list.remove_node(child)
    child.parent = new_node
    child.key = child.key[split_len:]
    child.value = child.value[split_len:]
    new_node.parent.children[self.get_child_key_fn(key)] = new_node

    # insert the new node and child into the lru lists, insert
    # parent first so that parent is after child in the lru list
    self.full_lru_list.insert_mru(new_node)
    self.full_lru_list.insert_mru(child)
    if child.mamba_value is not None:
        self.mamba_lru_list.insert_mru(child)
    return new_node
|
|
811
|
+
|
|
812
|
+
def _insert_helper(
    self,
    node: TreeNode,
    key: RadixKey,
    value,
    mamba_value,
) -> Tuple[int, bool]:
    """Insert *key*/*value* below *node*, attaching *mamba_value* to the
    final node.

    Returns (length of the prefix already present, whether that final
    node already carried a mamba value).
    """
    # Update the last access time from root to leaf, so that
    # mamba will tombstone the node closer to root first
    assert mamba_value is not None, "Mamba value should not be None here."
    node.last_access_time = time.monotonic()
    if node != self.root_node:
        self.full_lru_list.reset_node_mru(node)
        if node.mamba_value is not None:
            self.mamba_lru_list.reset_node_mru(node)
    if len(key) == 0:
        return 0, True

    child_key = self.get_child_key_fn(key)

    total_prefix_length = 0
    # Descend while the next child matches; split on partial matches.
    while len(key) > 0 and child_key in node.children.keys():
        node = node.children[child_key]
        node.last_access_time = time.monotonic()
        self.full_lru_list.reset_node_mru(node)
        if node.mamba_value is not None:
            self.mamba_lru_list.reset_node_mru(node)
        prefix_len = self.key_match_fn(node.key, key)
        total_prefix_length += prefix_len
        key = key[prefix_len:]
        value = value[prefix_len:]

        if prefix_len < len(node.key):
            # Partial match: split so the matched prefix becomes a node.
            new_node = self._split_node(node.key, node, prefix_len)
            node = new_node

        if len(key):
            child_key = self.get_child_key_fn(key)

    mamba_value_exist = False
    if len(key):
        # Unmatched suffix remains: create a fresh leaf for it.
        new_node = TreeNode()
        new_node.parent = node
        new_node.key = key
        new_node.value = value
        new_node.mamba_value = mamba_value
        self.full_lru_list.insert_mru(new_node)
        self.full_evictable_size_ += len(value)
        self.mamba_evictable_size_ += len(mamba_value)
        self.mamba_lru_list.insert_mru(new_node)
        node.children[child_key] = new_node
    elif node.mamba_value is None:  # add for mamba tombstone
        # Fully matched a tombstone: revive it with the new mamba state.
        node.mamba_value = mamba_value
        self.mamba_evictable_size_ += len(mamba_value)
        self.mamba_lru_list.insert_mru(node)
    else:
        # Fully matched a node that already has a mamba state; caller is
        # responsible for freeing the redundant mamba_value.
        mamba_value_exist = True
        self.mamba_lru_list.reset_node_mru(node)

    return total_prefix_length, mamba_value_exist
|
|
872
|
+
|
|
873
|
+
def _iteratively_delete_tombstone_leaf(
    self, node: TreeNode
) -> Tuple[TreeNode, int]:
    """Walk upwards from *node*, deleting ancestors that have become
    childless tombstones (no mamba value, no children).

    Maintains the invariant that leaves are never tombstones. Returns the
    last deleted ancestor (or *node* itself) and the number of full
    tokens freed along the way.
    """
    full_num_evicted = 0
    while node.parent.mamba_value is None and len(node.parent.children) == 0:
        # root node is not evictable
        if node.parent == self.root_node:
            break
        # if locked, means node is in use, skip
        if node.parent.full_lock_ref > 0:
            break
        # Tombstones hold no mamba state, so they can never be mamba-locked.
        assert (
            node.parent.mamba_lock_ref == 0
        ), f"tombstone mamba_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.mamba_lock_ref=}, {node.parent.id=}"
        # delete tombstone node evicts full tokens
        self.token_to_kv_pool_allocator.free(node.parent.value)
        full_num_evicted += len(node.parent.value)
        self.full_lru_list.remove_node(node.parent)
        self._delete_tombstone_leaf(node.parent)
        node = node.parent

    return node, full_num_evicted
|
|
895
|
+
|
|
896
|
+
def _delete_leaf(self, node: TreeNode) -> None:
    """Unlink a non-tombstone leaf from its parent and update evictable
    size counters. The caller must already have freed its pool resources
    and removed it from the LRU lists.
    """
    assert (
        node.mamba_value is not None
    ), f"Invariant violated: leaf node is a tombstone, {node.id=}"
    assert len(node.children) == 0, f"leaf node has children, {node.id=}"
    # Find the parent's key for this node, then remove the entry.
    # NOTE(review): relies on the invariant that node is always present in
    # parent.children — if it were not, `k` would be stale/undefined.
    for k, v in node.parent.children.items():
        if v == node:
            break
    del node.parent.children[k]
    self.full_evictable_size_ -= len(node.key)
    self.mamba_evictable_size_ -= len(node.mamba_value)
|
|
907
|
+
|
|
908
|
+
def _tombstone_internal_node(self, node: TreeNode) -> None:
|
|
909
|
+
assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}"
|
|
910
|
+
self.mamba_evictable_size_ -= len(node.mamba_value)
|
|
911
|
+
node.mamba_value = None
|
|
912
|
+
|
|
913
|
+
def _delete_tombstone_leaf(self, node: TreeNode) -> None:
    """Unlink a childless tombstone node from its parent and update the
    full-token evictable counter (tombstones hold no mamba state).
    """
    assert (
        node.mamba_value is None
    ), f"Deleting a unexpected non-tombstone leaf node, {node.id=}"
    assert len(node.children) == 0, f"leaf node has children, {node.id=}"
    # Find the parent's key for this node, then remove the entry.
    # NOTE(review): relies on the invariant that node is always present in
    # parent.children — if it were not, `k` would be stale/undefined.
    for k, v in node.parent.children.items():
        if v == node:
            break
    del node.parent.children[k]
    self.full_evictable_size_ -= len(node.key)
|
|
923
|
+
|
|
924
|
+
def _collect_leaves(self) -> List[TreeNode]:
|
|
925
|
+
ret_list = []
|
|
926
|
+
stack = [self.root_node]
|
|
927
|
+
|
|
928
|
+
while stack:
|
|
929
|
+
cur_node = stack.pop()
|
|
930
|
+
if len(cur_node.children) == 0:
|
|
931
|
+
ret_list.append(cur_node)
|
|
932
|
+
else:
|
|
933
|
+
stack.extend(cur_node.children.values())
|
|
934
|
+
|
|
935
|
+
return ret_list
|
|
936
|
+
|
|
937
|
+
def _collect_nontombstone_nodes(self) -> List[TreeNode]:
|
|
938
|
+
ret_list = []
|
|
939
|
+
stack = [self.root_node]
|
|
940
|
+
|
|
941
|
+
while stack:
|
|
942
|
+
cur_node = stack.pop()
|
|
943
|
+
if cur_node.mamba_value is not None:
|
|
944
|
+
ret_list.append(cur_node)
|
|
945
|
+
stack.extend(cur_node.children.values())
|
|
946
|
+
|
|
947
|
+
return ret_list
|
|
948
|
+
|
|
949
|
+
def _collect_all_nodes(self) -> List[TreeNode]:
|
|
950
|
+
ret_list = []
|
|
951
|
+
stack = [self.root_node]
|
|
952
|
+
while stack:
|
|
953
|
+
cur_node = stack.pop()
|
|
954
|
+
ret_list.append(cur_node)
|
|
955
|
+
stack.extend(cur_node.children.values())
|
|
956
|
+
return ret_list
|
|
957
|
+
|
|
958
|
+
def _print_helper(self, node: TreeNode, indent: int) -> None:
    """Print the subtree rooted at *node*, one line per node, indented by
    depth, showing id, key length, lock refs, LRU membership, and mamba
    value. Also verifies each child's dict key against the derived key.
    """
    pending = [(node, indent)]
    while pending:
        cur, depth = pending.pop()
        print(
            " " * depth,
            f"[{cur.id}]",
            len(cur.key),
            f"fr={cur.full_lock_ref}",
            f"mr={cur.mamba_lock_ref}",
            f"fll={self.full_lru_list.in_list(cur)}",
            f"mll={self.mamba_lru_list.in_list(cur)}",
            f"mv={cur.mamba_value}",
        )
        for key, child in cur.children.items():
            pending.append((child, depth + 2))

            # Child keys must stay consistent with the derived child key.
            assert key == self.get_child_key_fn(
                child.key
            ), f"{key=}, {self.get_child_key_fn(child.key)=}"
|
|
979
|
+
|
|
980
|
+
def _total_size_helper(self) -> Tuple[int, int]:
|
|
981
|
+
total_size = 0
|
|
982
|
+
total_mamba_size = 0
|
|
983
|
+
stack = [self.root_node]
|
|
984
|
+
while stack:
|
|
985
|
+
current_node = stack.pop()
|
|
986
|
+
total_size += len(current_node.value)
|
|
987
|
+
if current_node.mamba_value is not None:
|
|
988
|
+
total_mamba_size += len(current_node.mamba_value)
|
|
989
|
+
for child in current_node.children.values():
|
|
990
|
+
if child.evicted:
|
|
991
|
+
continue
|
|
992
|
+
stack.append(child)
|
|
993
|
+
return total_size, total_mamba_size
|