sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ import heapq
|
|
23
23
|
import time
|
24
24
|
from collections import defaultdict
|
25
25
|
from functools import partial
|
26
|
-
from typing import TYPE_CHECKING, List, Optional
|
26
|
+
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
|
27
27
|
|
28
28
|
import torch
|
29
29
|
|
@@ -34,12 +34,37 @@ from sglang.srt.disaggregation.kv_events import (
|
|
34
34
|
)
|
35
35
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
36
36
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
37
|
+
from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
|
37
38
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
38
39
|
|
39
40
|
if TYPE_CHECKING:
|
40
41
|
from sglang.srt.managers.schedule_batch import Req
|
41
42
|
|
42
43
|
|
44
|
+
class RadixKey:
|
45
|
+
|
46
|
+
def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
|
47
|
+
# token ids sequence
|
48
|
+
self.token_ids = token_ids
|
49
|
+
# extra key (e.g. lora_id, cache_salt)
|
50
|
+
self.extra_key = extra_key
|
51
|
+
|
52
|
+
def __len__(self) -> int:
|
53
|
+
return len(self.token_ids)
|
54
|
+
|
55
|
+
def __iter__(self) -> Iterator[int]:
|
56
|
+
return iter(self.token_ids)
|
57
|
+
|
58
|
+
def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
|
59
|
+
if isinstance(idx, slice):
|
60
|
+
return RadixKey(self.token_ids[idx], self.extra_key)
|
61
|
+
return RadixKey([self.token_ids[idx]], self.extra_key)
|
62
|
+
|
63
|
+
def __repr__(self) -> str:
|
64
|
+
preview = self.token_ids[:10]
|
65
|
+
return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})"
|
66
|
+
|
67
|
+
|
43
68
|
class TreeNode:
|
44
69
|
|
45
70
|
counter = 0
|
@@ -47,14 +72,12 @@ class TreeNode:
|
|
47
72
|
def __init__(self, id: Optional[int] = None):
|
48
73
|
self.children = defaultdict(TreeNode)
|
49
74
|
self.parent: TreeNode = None
|
50
|
-
self.key:
|
75
|
+
self.key: RadixKey = None
|
51
76
|
self.value: Optional[torch.Tensor] = None
|
52
77
|
self.lock_ref = 0
|
53
78
|
self.last_access_time = time.monotonic()
|
54
79
|
|
55
80
|
self.hit_count = 0
|
56
|
-
# indicating the node is loading KV cache from host
|
57
|
-
self.loading = False
|
58
81
|
# indicating the node is locked to protect from eviction
|
59
82
|
# incremented when the node is referenced by a storage operation
|
60
83
|
self.host_ref_counter = 0
|
@@ -95,27 +118,57 @@ class TreeNode:
|
|
95
118
|
return self.last_access_time < other.last_access_time
|
96
119
|
|
97
120
|
|
98
|
-
def
|
121
|
+
def _check_extra_key(key0: RadixKey, key1: RadixKey):
|
122
|
+
if key0.extra_key != key1.extra_key:
|
123
|
+
raise ValueError(
|
124
|
+
f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
|
125
|
+
)
|
126
|
+
|
127
|
+
|
128
|
+
def _key_match_page_size1(key0: RadixKey, key1: RadixKey):
|
129
|
+
_check_extra_key(key0, key1)
|
99
130
|
i = 0
|
100
|
-
for k0, k1 in zip(key0, key1):
|
131
|
+
for k0, k1 in zip(key0.token_ids, key1.token_ids):
|
101
132
|
if k0 != k1:
|
102
133
|
break
|
103
134
|
i += 1
|
104
135
|
return i
|
105
136
|
|
106
137
|
|
107
|
-
def _key_match_paged(key0:
|
138
|
+
def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int):
|
139
|
+
_check_extra_key(key0, key1)
|
108
140
|
min_len = min(len(key0), len(key1))
|
109
141
|
|
110
142
|
i = 0
|
111
143
|
while i < min_len:
|
112
|
-
if key0[i : i + page_size] != key1[i : i + page_size]:
|
144
|
+
if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]:
|
113
145
|
break
|
114
146
|
i += page_size
|
115
147
|
|
116
148
|
return i
|
117
149
|
|
118
150
|
|
151
|
+
def get_child_key(key: RadixKey, page_size: int = 1):
|
152
|
+
if page_size == 1:
|
153
|
+
plain_key = key.token_ids[0]
|
154
|
+
else:
|
155
|
+
plain_key = tuple(key.token_ids[:page_size])
|
156
|
+
if key.extra_key is None:
|
157
|
+
return plain_key
|
158
|
+
else:
|
159
|
+
return (key.extra_key, plain_key)
|
160
|
+
|
161
|
+
|
162
|
+
def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
|
163
|
+
# EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
|
164
|
+
# [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
|
165
|
+
if len(tokens) < 2:
|
166
|
+
return []
|
167
|
+
if isinstance(tokens[0], tuple):
|
168
|
+
return tokens
|
169
|
+
return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
|
170
|
+
|
171
|
+
|
119
172
|
class RadixCache(BasePrefixCache):
|
120
173
|
def __init__(
|
121
174
|
self,
|
@@ -124,6 +177,8 @@ class RadixCache(BasePrefixCache):
|
|
124
177
|
page_size: int,
|
125
178
|
disable: bool = False,
|
126
179
|
enable_kv_cache_events: bool = False,
|
180
|
+
eviction_policy: str = "lru",
|
181
|
+
is_eagle: bool = False,
|
127
182
|
):
|
128
183
|
self.req_to_token_pool = req_to_token_pool
|
129
184
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
@@ -131,6 +186,7 @@ class RadixCache(BasePrefixCache):
|
|
131
186
|
self.disable = disable
|
132
187
|
self.enable_kv_cache_events = enable_kv_cache_events
|
133
188
|
self.kv_event_queue = []
|
189
|
+
self.is_eagle = is_eagle
|
134
190
|
|
135
191
|
if self.token_to_kv_pool_allocator:
|
136
192
|
self.device = self.token_to_kv_pool_allocator.device
|
@@ -139,17 +195,31 @@ class RadixCache(BasePrefixCache):
|
|
139
195
|
|
140
196
|
if self.page_size == 1:
|
141
197
|
self.key_match_fn = _key_match_page_size1
|
142
|
-
self.get_child_key_fn =
|
198
|
+
self.get_child_key_fn = get_child_key
|
143
199
|
else:
|
144
200
|
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
145
|
-
self.get_child_key_fn =
|
201
|
+
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
202
|
+
|
203
|
+
if is_eagle:
|
204
|
+
self.key_convert_fn = _convert_to_bigram_key
|
205
|
+
else:
|
206
|
+
self.key_convert_fn = lambda key: key
|
207
|
+
|
208
|
+
if eviction_policy.lower() == "lru":
|
209
|
+
self.eviction_strategy: EvictionStrategy = LRUStrategy()
|
210
|
+
elif eviction_policy.lower() == "lfu":
|
211
|
+
self.eviction_strategy: EvictionStrategy = LFUStrategy()
|
212
|
+
else:
|
213
|
+
raise ValueError(
|
214
|
+
f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
|
215
|
+
)
|
146
216
|
self.reset()
|
147
217
|
|
148
218
|
##### Public API #####
|
149
219
|
|
150
220
|
def reset(self):
|
151
221
|
self.root_node = TreeNode()
|
152
|
-
self.root_node.key = []
|
222
|
+
self.root_node.key = RadixKey(token_ids=[], extra_key=None)
|
153
223
|
self.root_node.value = []
|
154
224
|
self.root_node.host_value = []
|
155
225
|
self.root_node.lock_ref = 1
|
@@ -157,18 +227,47 @@ class RadixCache(BasePrefixCache):
|
|
157
227
|
self.protected_size_ = 0
|
158
228
|
self._record_all_cleared_event()
|
159
229
|
|
160
|
-
def match_prefix(self, key:
|
161
|
-
"""Find the
|
230
|
+
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
231
|
+
"""Find the longest cached prefix of ``key`` in the radix tree.
|
232
|
+
|
233
|
+
The logical namespace for prefix matching is determined by both the
|
234
|
+
token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
|
235
|
+
Entries that share identical leading token ids but have *different*
|
236
|
+
``extra_key`` values are intentionally kept disjoint and never share
|
237
|
+
prefix nodes. This is useful to:
|
238
|
+
|
239
|
+
* Isolate KV cache lines for different LoRA / adapter IDs.
|
240
|
+
* Separate requests that intentionally should not share state (e.g.,
|
241
|
+
different sampling salt, cache version, or retrieval augmentation
|
242
|
+
context) by supplying a distinct ``extra_key``.
|
243
|
+
|
162
244
|
Args:
|
163
|
-
key:
|
245
|
+
key (RadixKey): The lookup key containing a list of token ids and an
|
246
|
+
optional ``extra_key`` namespace tag. If ``page_size > 1`` the
|
247
|
+
length is internally truncated to a multiple of ``page_size``
|
248
|
+
before matching. Passing an empty key returns an empty result
|
249
|
+
with the root as the last node.
|
250
|
+
**kwargs: Reserved for future extensions (ignored currently).
|
251
|
+
|
164
252
|
Returns:
|
165
|
-
|
166
|
-
the
|
167
|
-
|
168
|
-
|
169
|
-
|
253
|
+
MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
|
254
|
+
the concatenated KV cache indices corresponding to the longest
|
255
|
+
cached prefix (may be length 0). ``last_device_node`` and
|
256
|
+
``last_host_node`` (currently the same) are the tree node objects
|
257
|
+
representing the terminal node of the matched prefix. This method
|
258
|
+
may mutate internal structure by splitting an existing node if the
|
259
|
+
match ends inside a stored segment.
|
260
|
+
|
261
|
+
Internal updates:
|
262
|
+
* Refreshes access metadata (timestamps) used by the
|
263
|
+
configured eviction strategy.
|
264
|
+
* If the lookup ends inside a stored segment the node is split once
|
265
|
+
to expose a precise boundary; this structural refinement improves
|
266
|
+
subsequent match efficiency and does not duplicate data.
|
170
267
|
"""
|
171
|
-
|
268
|
+
key.token_ids = self.key_convert_fn(key.token_ids)
|
269
|
+
|
270
|
+
def empty_match_result():
|
172
271
|
return MatchResult(
|
173
272
|
device_indices=torch.empty(
|
174
273
|
(0,),
|
@@ -179,10 +278,16 @@ class RadixCache(BasePrefixCache):
|
|
179
278
|
last_host_node=self.root_node,
|
180
279
|
)
|
181
280
|
|
281
|
+
if self.disable or len(key) == 0:
|
282
|
+
return empty_match_result()
|
283
|
+
|
182
284
|
if self.page_size != 1:
|
183
285
|
page_aligned_len = len(key) // self.page_size * self.page_size
|
184
286
|
key = key[:page_aligned_len]
|
185
287
|
|
288
|
+
if len(key) == 0:
|
289
|
+
return empty_match_result()
|
290
|
+
|
186
291
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
187
292
|
if value:
|
188
293
|
value = torch.cat(value)
|
@@ -194,12 +299,19 @@ class RadixCache(BasePrefixCache):
|
|
194
299
|
last_host_node=last_node,
|
195
300
|
)
|
196
301
|
|
197
|
-
def insert(self, key:
|
302
|
+
def insert(self, key: RadixKey, value=None, chunked=False):
|
198
303
|
if self.disable:
|
199
304
|
return 0
|
200
305
|
|
306
|
+
key.token_ids = self.key_convert_fn(key.token_ids)
|
307
|
+
|
201
308
|
if value is None:
|
202
|
-
value =
|
309
|
+
value = torch.tensor(key.token_ids, dtype=torch.int64)
|
310
|
+
|
311
|
+
if self.is_eagle:
|
312
|
+
# Make sure the value len equal to the EAGLE bigram key len
|
313
|
+
value = value[: len(key)]
|
314
|
+
|
203
315
|
return self._insert_helper(self.root_node, key, value)
|
204
316
|
|
205
317
|
def cache_finished_req(self, req: Req):
|
@@ -213,27 +325,39 @@ class RadixCache(BasePrefixCache):
|
|
213
325
|
return
|
214
326
|
|
215
327
|
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
328
|
+
all_token_len = len(token_ids)
|
329
|
+
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
216
330
|
kv_indices = self.req_to_token_pool.req_to_token[
|
217
|
-
req.req_pool_idx, :
|
331
|
+
req.req_pool_idx, :all_token_len
|
218
332
|
]
|
219
333
|
|
220
334
|
if self.page_size != 1:
|
221
|
-
page_aligned_len =
|
335
|
+
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
222
336
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
223
337
|
dtype=torch.int64, copy=True
|
224
338
|
)
|
225
339
|
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
226
340
|
else:
|
227
|
-
page_aligned_len =
|
341
|
+
page_aligned_len = actual_kv_len
|
228
342
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
343
|
+
if self.is_eagle:
|
344
|
+
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
345
|
+
|
346
|
+
page_aligned_token_len = (
|
347
|
+
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
348
|
+
)
|
349
|
+
|
350
|
+
old_prefix_len = len(req.prefix_indices)
|
351
|
+
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
352
|
+
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
353
|
+
old_prefix_len -= 1
|
229
354
|
|
230
355
|
# Radix Cache takes one ref in memory pool
|
231
356
|
new_prefix_len = self.insert(
|
232
|
-
token_ids[:
|
233
|
-
|
234
|
-
self.token_to_kv_pool_allocator.free(
|
235
|
-
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
357
|
+
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
358
|
+
page_aligned_kv_indices,
|
236
359
|
)
|
360
|
+
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
237
361
|
|
238
362
|
# Remove req slot release the cache lock
|
239
363
|
self.req_to_token_pool.free(req.req_pool_idx)
|
@@ -245,45 +369,73 @@ class RadixCache(BasePrefixCache):
|
|
245
369
|
return
|
246
370
|
|
247
371
|
token_ids = req.fill_ids
|
372
|
+
all_token_len = len(token_ids)
|
373
|
+
# The actual kv len for EAGLE is len(token_ids) - 1, since EAGLE uses bigram key
|
374
|
+
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
248
375
|
kv_indices = self.req_to_token_pool.req_to_token[
|
249
|
-
req.req_pool_idx, :
|
376
|
+
req.req_pool_idx, :all_token_len
|
250
377
|
]
|
251
378
|
|
252
379
|
if self.page_size != 1:
|
253
|
-
page_aligned_len =
|
380
|
+
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
254
381
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
255
382
|
dtype=torch.int64, copy=True
|
256
383
|
)
|
257
384
|
else:
|
258
|
-
page_aligned_len =
|
385
|
+
page_aligned_len = actual_kv_len
|
259
386
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
260
|
-
|
387
|
+
|
388
|
+
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
|
389
|
+
page_aligned_token_len = (
|
390
|
+
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
391
|
+
)
|
392
|
+
page_aligned_token_ids = token_ids[:page_aligned_token_len]
|
393
|
+
|
394
|
+
old_prefix_len = len(req.prefix_indices)
|
395
|
+
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
396
|
+
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
397
|
+
old_prefix_len -= 1
|
261
398
|
|
262
399
|
# Radix Cache takes one ref in memory pool
|
263
400
|
new_prefix_len = self.insert(
|
264
|
-
page_aligned_token_ids,
|
265
|
-
|
266
|
-
|
267
|
-
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
401
|
+
RadixKey(page_aligned_token_ids, req.extra_key),
|
402
|
+
page_aligned_kv_indices,
|
403
|
+
chunked=chunked,
|
268
404
|
)
|
405
|
+
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
269
406
|
|
270
407
|
# The prefix indices could be updated, reuse it
|
271
|
-
new_indices, new_last_node, _, _ = self.match_prefix(
|
408
|
+
new_indices, new_last_node, _, _ = self.match_prefix(
|
409
|
+
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
410
|
+
)
|
272
411
|
self.req_to_token_pool.write(
|
273
|
-
(req.req_pool_idx, slice(
|
274
|
-
new_indices[
|
412
|
+
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
413
|
+
new_indices[old_prefix_len:],
|
275
414
|
)
|
276
415
|
|
416
|
+
# The last_matched_prefix_len is not always equal to len(req.prefix_indices)
|
417
|
+
# since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
|
418
|
+
# It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
|
419
|
+
# So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
|
420
|
+
req.last_matched_prefix_len = len(new_indices)
|
421
|
+
|
277
422
|
self.dec_lock_ref(req.last_node)
|
278
423
|
self.inc_lock_ref(new_last_node)
|
279
424
|
|
280
425
|
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
281
426
|
if self.page_size != 1:
|
427
|
+
# Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
|
282
428
|
req.prefix_indices = torch.cat(
|
283
429
|
[new_indices, kv_indices[len(new_indices) :]]
|
284
430
|
)
|
285
431
|
else:
|
286
|
-
|
432
|
+
if self.is_eagle:
|
433
|
+
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
|
434
|
+
req.prefix_indices = torch.cat(
|
435
|
+
[new_indices, kv_indices[actual_kv_len:]]
|
436
|
+
)
|
437
|
+
else:
|
438
|
+
req.prefix_indices = new_indices
|
287
439
|
req.last_node = new_last_node
|
288
440
|
|
289
441
|
def pretty_print(self):
|
@@ -298,11 +450,14 @@ class RadixCache(BasePrefixCache):
|
|
298
450
|
return
|
299
451
|
|
300
452
|
leaves = self._collect_leaves()
|
301
|
-
|
453
|
+
eviction_heap = [
|
454
|
+
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
455
|
+
]
|
456
|
+
heapq.heapify(eviction_heap)
|
302
457
|
|
303
458
|
num_evicted = 0
|
304
|
-
while num_evicted < num_tokens and len(
|
305
|
-
x = heapq.heappop(
|
459
|
+
while num_evicted < num_tokens and len(eviction_heap):
|
460
|
+
_priority, x = heapq.heappop(eviction_heap)
|
306
461
|
|
307
462
|
if x == self.root_node:
|
308
463
|
break
|
@@ -314,7 +469,8 @@ class RadixCache(BasePrefixCache):
|
|
314
469
|
self._delete_leaf(x)
|
315
470
|
|
316
471
|
if len(x.parent.children) == 0:
|
317
|
-
|
472
|
+
new_priority = self.eviction_strategy.get_priority(x.parent)
|
473
|
+
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
318
474
|
|
319
475
|
self._record_remove_event(x)
|
320
476
|
|
@@ -325,9 +481,9 @@ class RadixCache(BasePrefixCache):
|
|
325
481
|
delta = 0
|
326
482
|
while node != self.root_node:
|
327
483
|
if node.lock_ref == 0:
|
328
|
-
self.evictable_size_ -= len(node.
|
329
|
-
self.protected_size_ += len(node.
|
330
|
-
delta -= len(node.
|
484
|
+
self.evictable_size_ -= len(node.key)
|
485
|
+
self.protected_size_ += len(node.key)
|
486
|
+
delta -= len(node.key)
|
331
487
|
node.lock_ref += 1
|
332
488
|
node = node.parent
|
333
489
|
return delta
|
@@ -339,9 +495,9 @@ class RadixCache(BasePrefixCache):
|
|
339
495
|
delta = 0
|
340
496
|
while node != self.root_node:
|
341
497
|
if node.lock_ref == 1:
|
342
|
-
self.evictable_size_ += len(node.
|
343
|
-
self.protected_size_ -= len(node.
|
344
|
-
delta += len(node.
|
498
|
+
self.evictable_size_ += len(node.key)
|
499
|
+
self.protected_size_ -= len(node.key)
|
500
|
+
delta += len(node.key)
|
345
501
|
node.lock_ref -= 1
|
346
502
|
node = node.parent
|
347
503
|
return delta
|
@@ -366,7 +522,7 @@ class RadixCache(BasePrefixCache):
|
|
366
522
|
|
367
523
|
##### Internal Helper Functions #####
|
368
524
|
|
369
|
-
def _match_prefix_helper(self, node: TreeNode, key:
|
525
|
+
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
370
526
|
node.last_access_time = time.monotonic()
|
371
527
|
|
372
528
|
child_key = self.get_child_key_fn(key)
|
@@ -391,7 +547,7 @@ class RadixCache(BasePrefixCache):
|
|
391
547
|
|
392
548
|
return value, node
|
393
549
|
|
394
|
-
def _split_node(self, key, child: TreeNode, split_len: int):
|
550
|
+
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
395
551
|
# new_node -> child
|
396
552
|
self._record_remove_event(child)
|
397
553
|
new_node = TreeNode()
|
@@ -410,7 +566,7 @@ class RadixCache(BasePrefixCache):
|
|
410
566
|
|
411
567
|
return new_node
|
412
568
|
|
413
|
-
def _insert_helper(self, node: TreeNode, key:
|
569
|
+
def _insert_helper(self, node: TreeNode, key: RadixKey, value):
|
414
570
|
node.last_access_time = time.monotonic()
|
415
571
|
if len(key) == 0:
|
416
572
|
return 0
|
@@ -439,7 +595,7 @@ class RadixCache(BasePrefixCache):
|
|
439
595
|
new_node.key = key
|
440
596
|
new_node.value = value
|
441
597
|
node.children[child_key] = new_node
|
442
|
-
self.evictable_size_ += len(
|
598
|
+
self.evictable_size_ += len(key)
|
443
599
|
self._record_store_event(new_node)
|
444
600
|
return total_prefix_length
|
445
601
|
|
@@ -451,7 +607,7 @@ class RadixCache(BasePrefixCache):
|
|
451
607
|
print(
|
452
608
|
" " * current_indent,
|
453
609
|
len(current_node.key),
|
454
|
-
current_node.key[:10],
|
610
|
+
current_node.key.token_ids[:10],
|
455
611
|
f"r={current_node.lock_ref}",
|
456
612
|
)
|
457
613
|
for key, child in current_node.children.items():
|
@@ -503,11 +659,11 @@ class RadixCache(BasePrefixCache):
|
|
503
659
|
last_page_start = (
|
504
660
|
(len(node.parent.key) - 1) // self.page_size
|
505
661
|
) * self.page_size
|
506
|
-
parent_parent_tokens = node.parent.key[last_page_start:]
|
662
|
+
parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
|
507
663
|
parent_block_hash = hash(tuple(parent_parent_tokens))
|
508
664
|
|
509
665
|
for start in range(0, len(node.key), self.page_size):
|
510
|
-
page_tokens = node.key[start : start + self.page_size]
|
666
|
+
page_tokens = node.key.token_ids[start : start + self.page_size]
|
511
667
|
if not page_tokens:
|
512
668
|
continue
|
513
669
|
|
@@ -530,7 +686,7 @@ class RadixCache(BasePrefixCache):
|
|
530
686
|
# One BlockRemoved per chunk.
|
531
687
|
if self.enable_kv_cache_events:
|
532
688
|
for start in range(0, len(node.key), self.page_size):
|
533
|
-
page_tokens = node.key[start : start + self.page_size]
|
689
|
+
page_tokens = node.key.token_ids[start : start + self.page_size]
|
534
690
|
if not page_tokens:
|
535
691
|
continue
|
536
692
|
block_hash = hash(tuple(page_tokens))
|
@@ -556,19 +712,12 @@ class RadixCache(BasePrefixCache):
|
|
556
712
|
if __name__ == "__main__":
|
557
713
|
tree = RadixCache(None, None, page_size=1, disable=False)
|
558
714
|
|
559
|
-
|
560
|
-
tree.insert(
|
561
|
-
tree.insert(
|
562
|
-
|
563
|
-
|
715
|
+
# Example token id sequences (as lists of ints)
|
716
|
+
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
717
|
+
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
718
|
+
tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None))
|
719
|
+
tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None))
|
720
|
+
tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
|
564
721
|
tree.pretty_print()
|
565
722
|
|
566
|
-
|
567
|
-
|
568
|
-
# def evict_callback(x):
|
569
|
-
# print("evict", x)
|
570
|
-
# return len(x)
|
571
|
-
|
572
|
-
# tree.evict(5, evict_callback)
|
573
|
-
# tree.evict(10, evict_callback)
|
574
|
-
# tree.pretty_print()
|
723
|
+
print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
|
@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
|
|
13
13
|
TreeNodeCpp,
|
14
14
|
)
|
15
15
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
16
|
+
from sglang.srt.mem_cache.radix_cache import RadixKey
|
16
17
|
|
17
18
|
if TYPE_CHECKING:
|
18
19
|
from sglang.srt.managers.schedule_batch import Req
|
@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
|
|
93
94
|
raise NotImplementedError("Host cache is not supported yet")
|
94
95
|
self.tree.reset()
|
95
96
|
|
96
|
-
def match_prefix(self, key:
|
97
|
+
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
97
98
|
device_indices_vec, host_indices_length, node_gpu, node_cpu = (
|
98
|
-
self.tree.match_prefix(key)
|
99
|
+
self.tree.match_prefix(key.token_ids)
|
99
100
|
)
|
100
101
|
return MatchResult(
|
101
102
|
device_indices=self._merge_tensor(device_indices_vec),
|
@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
|
|
104
105
|
host_hit_length=host_indices_length,
|
105
106
|
)
|
106
107
|
|
107
|
-
def _insert(self, key:
|
108
|
+
def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
|
108
109
|
"""
|
109
110
|
Insert a key-value pair into the radix tree.
|
110
111
|
Args:
|
111
|
-
key (
|
112
|
+
key (RadixKey): The key to insert, represented as a RadixKey.
|
112
113
|
value (torch.Tensor): The value to associate with the key.
|
113
114
|
Returns:
|
114
115
|
int: Number of device indices that were already present in the tree before the insertion.
|
115
116
|
"""
|
116
|
-
ongoing_write, length = self.tree.writing_through(key, value)
|
117
|
+
ongoing_write, length = self.tree.writing_through(key.token_ids, value)
|
117
118
|
if self.cache_controller is None:
|
118
119
|
assert len(ongoing_write) == 0, "Implementation error"
|
119
120
|
return length
|
@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache):
|
|
160
161
|
# NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned
|
161
162
|
# it will automatically align them, but length of them should be equal
|
162
163
|
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
163
|
-
new_prefix_len = self._insert(token_ids, kv_indices)
|
164
|
+
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
164
165
|
|
165
166
|
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
166
167
|
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache):
|
|
191
192
|
# NOTE: our C++ implementation doesn't need `token_ids` and `kv_indices` to be page-aligned
|
192
193
|
# it will automatically align them, but length of them should be equal
|
193
194
|
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
194
|
-
new_prefix_len = self._insert(token_ids, kv_indices)
|
195
|
+
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
195
196
|
|
196
197
|
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
197
198
|
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
198
199
|
|
199
200
|
# TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
|
200
201
|
# The prefix indices need to be updated to reuse the kv indices in the pool
|
201
|
-
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
|
202
|
+
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
|
203
|
+
RadixKey(token_ids, req.extra_key).token_ids
|
204
|
+
)
|
202
205
|
new_indices = self._merge_tensor(new_indices_vec)
|
203
206
|
assert new_prefix_len <= len(new_indices)
|
204
207
|
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to SGLang project
|
3
|
+
|
4
|
+
"""Storage backend module for SGLang HiCache."""
|
5
|
+
|
6
|
+
from .backend_factory import StorageBackendFactory
|
7
|
+
|
8
|
+
__all__ = [
|
9
|
+
"StorageBackendFactory",
|
10
|
+
]
|