sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ import heapq
|
|
23
23
|
import time
|
24
24
|
from collections import defaultdict
|
25
25
|
from functools import partial
|
26
|
-
from typing import TYPE_CHECKING, List, Optional
|
26
|
+
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
|
27
27
|
|
28
28
|
import torch
|
29
29
|
|
@@ -34,12 +34,37 @@ from sglang.srt.disaggregation.kv_events import (
|
|
34
34
|
)
|
35
35
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
36
36
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
37
|
+
from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
|
37
38
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
38
39
|
|
39
40
|
if TYPE_CHECKING:
|
40
41
|
from sglang.srt.managers.schedule_batch import Req
|
41
42
|
|
42
43
|
|
44
|
+
class RadixKey:
    """Lookup key for the radix tree: a token-id sequence plus a namespace tag.

    ``extra_key`` (for example a LoRA id or a cache salt) places otherwise
    identical token sequences into separate namespaces. Indexing and slicing
    always produce a new ``RadixKey`` carrying the same ``extra_key``.
    """

    def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
        # Token id sequence; the owning cache may overwrite it with a converted form.
        self.token_ids = token_ids
        # Namespace tag such as a lora_id or cache_salt; None means "no namespace".
        self.extra_key = extra_key

    def __len__(self) -> int:
        return len(self.token_ids)

    def __iter__(self) -> Iterator[int]:
        return iter(self.token_ids)

    def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
        # A plain integer index still yields a RadixKey, wrapping the single
        # element in a list so the result is always sequence-shaped.
        selected = self.token_ids[idx]
        if not isinstance(idx, slice):
            selected = [selected]
        return RadixKey(selected, self.extra_key)

    def __repr__(self) -> str:
        # Only the first 10 token ids are shown to keep log lines short.
        head = self.token_ids[:10]
        more = "..." if len(self.token_ids) > 10 else ""
        return f"RadixKey(extra_key={self.extra_key!r}, token_ids={head}{more})"
|
66
|
+
|
67
|
+
|
43
68
|
class TreeNode:
|
44
69
|
|
45
70
|
counter = 0
|
@@ -47,14 +72,12 @@ class TreeNode:
|
|
47
72
|
def __init__(self, id: Optional[int] = None):
|
48
73
|
self.children = defaultdict(TreeNode)
|
49
74
|
self.parent: TreeNode = None
|
50
|
-
self.key:
|
75
|
+
self.key: RadixKey = None
|
51
76
|
self.value: Optional[torch.Tensor] = None
|
52
77
|
self.lock_ref = 0
|
53
78
|
self.last_access_time = time.monotonic()
|
54
79
|
|
55
80
|
self.hit_count = 0
|
56
|
-
# indicating the node is loading KV cache from host
|
57
|
-
self.loading = False
|
58
81
|
# indicating the node is locked to protect from eviction
|
59
82
|
# incremented when the node is referenced by a storage operation
|
60
83
|
self.host_ref_counter = 0
|
@@ -95,27 +118,57 @@ class TreeNode:
|
|
95
118
|
return self.last_access_time < other.last_access_time
|
96
119
|
|
97
120
|
|
98
|
-
def
|
121
|
+
def _check_extra_key(key0: RadixKey, key1: RadixKey):
    """Ensure two keys live in the same namespace before they are compared.

    Raises:
        ValueError: if ``key0.extra_key`` and ``key1.extra_key`` differ.
    """
    if key0.extra_key == key1.extra_key:
        return
    raise ValueError(
        f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
    )
|
126
|
+
|
127
|
+
|
128
|
+
def _key_match_page_size1(key0: RadixKey, key1: RadixKey):
    """Return how many leading token ids ``key0`` and ``key1`` have in common.

    Used when the cache page size is 1, so matching is element by element.

    Raises:
        ValueError: (from ``_check_extra_key``) if the keys' ``extra_key`` differ.
    """
    _check_extra_key(key0, key1)
    pairs = zip(key0.token_ids, key1.token_ids)
    first_mismatch = next((pos for pos, (a, b) in enumerate(pairs) if a != b), None)
    if first_mismatch is not None:
        return first_mismatch
    # No mismatch: the whole shorter sequence is a shared prefix.
    return min(len(key0.token_ids), len(key1.token_ids))
|
105
136
|
|
106
137
|
|
107
|
-
def _key_match_paged(key0:
|
138
|
+
def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int):
    """Return the shared-prefix length of two keys, compared a page at a time.

    Token ids are compared in consecutive windows of ``page_size``. The result
    is the start offset of the first differing window, so it is always a
    multiple of ``page_size``.

    NOTE(review): if both keys agree on a trailing partial page of the same
    length, the result can exceed ``min(len(key0), len(key1))``; callers are
    presumably passing page-aligned keys — confirm.

    Raises:
        ValueError: (from ``_check_extra_key``) if the keys' ``extra_key`` differ.
    """
    _check_extra_key(key0, key1)
    tokens0, tokens1 = key0.token_ids, key1.token_ids
    shared = min(len(key0), len(key1))
    pos = 0
    while pos < shared and tokens0[pos : pos + page_size] == tokens1[pos : pos + page_size]:
        pos += page_size
    return pos
|
117
149
|
|
118
150
|
|
151
|
+
def get_child_key(key: RadixKey, page_size: int = 1):
    """Return the dictionary key under which a child node for ``key`` is stored.

    The key is built from the leading token: a bare token id when ``page_size``
    is 1, otherwise a tuple of the first ``page_size`` token ids. When
    ``key.extra_key`` is set, it is paired with that value so children from
    different namespaces never collide.
    """
    tokens = key.token_ids
    plain_key = tokens[0] if page_size == 1 else tuple(tokens[:page_size])
    return plain_key if key.extra_key is None else (key.extra_key, plain_key)
|
160
|
+
|
161
|
+
|
162
|
+
def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
|
163
|
+
# EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
|
164
|
+
# [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
|
165
|
+
if len(tokens) < 2:
|
166
|
+
return []
|
167
|
+
if isinstance(tokens[0], tuple):
|
168
|
+
return tokens
|
169
|
+
return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
|
170
|
+
|
171
|
+
|
119
172
|
class RadixCache(BasePrefixCache):
|
120
173
|
def __init__(
|
121
174
|
self,
|
@@ -124,6 +177,8 @@ class RadixCache(BasePrefixCache):
|
|
124
177
|
page_size: int,
|
125
178
|
disable: bool = False,
|
126
179
|
enable_kv_cache_events: bool = False,
|
180
|
+
eviction_policy: str = "lru",
|
181
|
+
is_eagle: bool = False,
|
127
182
|
):
|
128
183
|
self.req_to_token_pool = req_to_token_pool
|
129
184
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
@@ -131,6 +186,7 @@ class RadixCache(BasePrefixCache):
|
|
131
186
|
self.disable = disable
|
132
187
|
self.enable_kv_cache_events = enable_kv_cache_events
|
133
188
|
self.kv_event_queue = []
|
189
|
+
self.is_eagle = is_eagle
|
134
190
|
|
135
191
|
if self.token_to_kv_pool_allocator:
|
136
192
|
self.device = self.token_to_kv_pool_allocator.device
|
@@ -139,17 +195,31 @@ class RadixCache(BasePrefixCache):
|
|
139
195
|
|
140
196
|
if self.page_size == 1:
|
141
197
|
self.key_match_fn = _key_match_page_size1
|
142
|
-
self.get_child_key_fn =
|
198
|
+
self.get_child_key_fn = get_child_key
|
143
199
|
else:
|
144
200
|
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
145
|
-
self.get_child_key_fn =
|
201
|
+
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
202
|
+
|
203
|
+
if is_eagle:
|
204
|
+
self.key_convert_fn = _convert_to_bigram_key
|
205
|
+
else:
|
206
|
+
self.key_convert_fn = lambda key: key
|
207
|
+
|
208
|
+
if eviction_policy.lower() == "lru":
|
209
|
+
self.eviction_strategy: EvictionStrategy = LRUStrategy()
|
210
|
+
elif eviction_policy.lower() == "lfu":
|
211
|
+
self.eviction_strategy: EvictionStrategy = LFUStrategy()
|
212
|
+
else:
|
213
|
+
raise ValueError(
|
214
|
+
f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
|
215
|
+
)
|
146
216
|
self.reset()
|
147
217
|
|
148
218
|
##### Public API #####
|
149
219
|
|
150
220
|
def reset(self):
|
151
221
|
self.root_node = TreeNode()
|
152
|
-
self.root_node.key = []
|
222
|
+
self.root_node.key = RadixKey(token_ids=[], extra_key=None)
|
153
223
|
self.root_node.value = []
|
154
224
|
self.root_node.host_value = []
|
155
225
|
self.root_node.lock_ref = 1
|
@@ -157,18 +227,47 @@ class RadixCache(BasePrefixCache):
|
|
157
227
|
self.protected_size_ = 0
|
158
228
|
self._record_all_cleared_event()
|
159
229
|
|
160
|
-
def match_prefix(self, key:
|
161
|
-
"""Find the
|
230
|
+
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
231
|
+
"""Find the longest cached prefix of ``key`` in the radix tree.
|
232
|
+
|
233
|
+
The logical namespace for prefix matching is determined by both the
|
234
|
+
token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
|
235
|
+
Entries that share identical leading token ids but have *different*
|
236
|
+
``extra_key`` values are intentionally kept disjoint and never share
|
237
|
+
prefix nodes. This is useful to:
|
238
|
+
|
239
|
+
* Isolate KV cache lines for different LoRA / adapter IDs.
|
240
|
+
* Separate requests that intentionally should not share state (e.g.,
|
241
|
+
different sampling salt, cache version, or retrieval augmentation
|
242
|
+
context) by supplying a distinct ``extra_key``.
|
243
|
+
|
162
244
|
Args:
|
163
|
-
key:
|
245
|
+
key (RadixKey): The lookup key containing a list of token ids and an
|
246
|
+
optional ``extra_key`` namespace tag. If ``page_size > 1`` the
|
247
|
+
length is internally truncated to a multiple of ``page_size``
|
248
|
+
before matching. Passing an empty key returns an empty result
|
249
|
+
with the root as the last node.
|
250
|
+
**kwargs: Reserved for future extensions (ignored currently).
|
251
|
+
|
164
252
|
Returns:
|
165
|
-
|
166
|
-
the
|
167
|
-
|
168
|
-
|
169
|
-
|
253
|
+
MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
|
254
|
+
the concatenated KV cache indices corresponding to the longest
|
255
|
+
cached prefix (may be length 0). ``last_device_node`` and
|
256
|
+
``last_host_node`` (currently the same) are the tree node objects
|
257
|
+
representing the terminal node of the matched prefix. This method
|
258
|
+
may mutate internal structure by splitting an existing node if the
|
259
|
+
match ends inside a stored segment.
|
260
|
+
|
261
|
+
Internal updates:
|
262
|
+
* Refreshes access metadata (timestamps) used by the
|
263
|
+
configured eviction strategy.
|
264
|
+
* If the lookup ends inside a stored segment the node is split once
|
265
|
+
to expose a precise boundary; this structural refinement improves
|
266
|
+
subsequent match efficiency and does not duplicate data.
|
170
267
|
"""
|
171
|
-
|
268
|
+
key.token_ids = self.key_convert_fn(key.token_ids)
|
269
|
+
|
270
|
+
def empty_match_result():
|
172
271
|
return MatchResult(
|
173
272
|
device_indices=torch.empty(
|
174
273
|
(0,),
|
@@ -179,10 +278,16 @@ class RadixCache(BasePrefixCache):
|
|
179
278
|
last_host_node=self.root_node,
|
180
279
|
)
|
181
280
|
|
281
|
+
if self.disable or len(key) == 0:
|
282
|
+
return empty_match_result()
|
283
|
+
|
182
284
|
if self.page_size != 1:
|
183
285
|
page_aligned_len = len(key) // self.page_size * self.page_size
|
184
286
|
key = key[:page_aligned_len]
|
185
287
|
|
288
|
+
if len(key) == 0:
|
289
|
+
return empty_match_result()
|
290
|
+
|
186
291
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
187
292
|
if value:
|
188
293
|
value = torch.cat(value)
|
@@ -194,12 +299,19 @@ class RadixCache(BasePrefixCache):
|
|
194
299
|
last_host_node=last_node,
|
195
300
|
)
|
196
301
|
|
197
|
-
def insert(self, key:
|
302
|
+
def insert(self, key: RadixKey, value=None, chunked=False):
|
198
303
|
if self.disable:
|
199
304
|
return 0
|
200
305
|
|
306
|
+
key.token_ids = self.key_convert_fn(key.token_ids)
|
307
|
+
|
201
308
|
if value is None:
|
202
|
-
value =
|
309
|
+
value = torch.tensor(key.token_ids, dtype=torch.int64)
|
310
|
+
|
311
|
+
if self.is_eagle:
|
312
|
+
# Make sure the value len equal to the EAGLE bigram key len
|
313
|
+
value = value[: len(key)]
|
314
|
+
|
203
315
|
return self._insert_helper(self.root_node, key, value)
|
204
316
|
|
205
317
|
def cache_finished_req(self, req: Req):
|
@@ -213,27 +325,42 @@ class RadixCache(BasePrefixCache):
|
|
213
325
|
return
|
214
326
|
|
215
327
|
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
328
|
+
all_token_len = len(token_ids)
|
329
|
+
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
|
330
|
+
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
|
331
|
+
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
216
332
|
kv_indices = self.req_to_token_pool.req_to_token[
|
217
|
-
req.req_pool_idx, :
|
333
|
+
req.req_pool_idx, :all_token_len
|
218
334
|
]
|
219
335
|
|
220
336
|
if self.page_size != 1:
|
221
|
-
page_aligned_len =
|
337
|
+
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
222
338
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
223
339
|
dtype=torch.int64, copy=True
|
224
340
|
)
|
225
341
|
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
226
342
|
else:
|
227
|
-
page_aligned_len =
|
343
|
+
page_aligned_len = actual_kv_len
|
228
344
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
345
|
+
if self.is_eagle:
|
346
|
+
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
347
|
+
|
348
|
+
page_aligned_token_len = (
|
349
|
+
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
350
|
+
)
|
351
|
+
|
352
|
+
old_prefix_len = len(req.prefix_indices)
|
353
|
+
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
354
|
+
# In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
|
355
|
+
# Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
|
356
|
+
old_prefix_len -= 1
|
229
357
|
|
230
358
|
# Radix Cache takes one ref in memory pool
|
231
359
|
new_prefix_len = self.insert(
|
232
|
-
token_ids[:
|
233
|
-
|
234
|
-
self.token_to_kv_pool_allocator.free(
|
235
|
-
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
360
|
+
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
361
|
+
page_aligned_kv_indices,
|
236
362
|
)
|
363
|
+
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
237
364
|
|
238
365
|
# Remove req slot release the cache lock
|
239
366
|
self.req_to_token_pool.free(req.req_pool_idx)
|
@@ -245,45 +372,75 @@ class RadixCache(BasePrefixCache):
|
|
245
372
|
return
|
246
373
|
|
247
374
|
token_ids = req.fill_ids
|
375
|
+
all_token_len = len(token_ids)
|
376
|
+
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
|
377
|
+
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
|
378
|
+
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
248
379
|
kv_indices = self.req_to_token_pool.req_to_token[
|
249
|
-
req.req_pool_idx, :
|
380
|
+
req.req_pool_idx, :all_token_len
|
250
381
|
]
|
251
382
|
|
252
383
|
if self.page_size != 1:
|
253
|
-
page_aligned_len =
|
384
|
+
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
254
385
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
255
386
|
dtype=torch.int64, copy=True
|
256
387
|
)
|
257
388
|
else:
|
258
|
-
page_aligned_len =
|
389
|
+
page_aligned_len = actual_kv_len
|
259
390
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
260
|
-
|
391
|
+
|
392
|
+
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
|
393
|
+
page_aligned_token_len = (
|
394
|
+
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
395
|
+
)
|
396
|
+
page_aligned_token_ids = token_ids[:page_aligned_token_len]
|
397
|
+
|
398
|
+
old_prefix_len = len(req.prefix_indices)
|
399
|
+
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
400
|
+
# In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
|
401
|
+
# Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
|
402
|
+
old_prefix_len -= 1
|
261
403
|
|
262
404
|
# Radix Cache takes one ref in memory pool
|
263
405
|
new_prefix_len = self.insert(
|
264
|
-
page_aligned_token_ids,
|
265
|
-
|
266
|
-
|
267
|
-
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
406
|
+
RadixKey(page_aligned_token_ids, req.extra_key),
|
407
|
+
page_aligned_kv_indices,
|
408
|
+
chunked=chunked,
|
268
409
|
)
|
410
|
+
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
269
411
|
|
270
412
|
# The prefix indices could be updated, reuse it
|
271
|
-
new_indices, new_last_node, _, _ = self.match_prefix(
|
413
|
+
new_indices, new_last_node, _, _ = self.match_prefix(
|
414
|
+
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
415
|
+
)
|
272
416
|
self.req_to_token_pool.write(
|
273
|
-
(req.req_pool_idx, slice(
|
274
|
-
new_indices[
|
417
|
+
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
418
|
+
new_indices[old_prefix_len:],
|
275
419
|
)
|
276
420
|
|
421
|
+
# The last_matched_prefix_len is not always equal to len(req.prefix_indices)
|
422
|
+
# since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
|
423
|
+
# It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
|
424
|
+
# So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
|
425
|
+
req.last_matched_prefix_len = len(new_indices)
|
426
|
+
|
277
427
|
self.dec_lock_ref(req.last_node)
|
278
428
|
self.inc_lock_ref(new_last_node)
|
279
429
|
|
280
430
|
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
281
431
|
if self.page_size != 1:
|
432
|
+
# Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
|
282
433
|
req.prefix_indices = torch.cat(
|
283
434
|
[new_indices, kv_indices[len(new_indices) :]]
|
284
435
|
)
|
285
436
|
else:
|
286
|
-
|
437
|
+
if self.is_eagle:
|
438
|
+
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
|
439
|
+
req.prefix_indices = torch.cat(
|
440
|
+
[new_indices, kv_indices[actual_kv_len:]]
|
441
|
+
)
|
442
|
+
else:
|
443
|
+
req.prefix_indices = new_indices
|
287
444
|
req.last_node = new_last_node
|
288
445
|
|
289
446
|
def pretty_print(self):
|
@@ -298,11 +455,14 @@ class RadixCache(BasePrefixCache):
|
|
298
455
|
return
|
299
456
|
|
300
457
|
leaves = self._collect_leaves()
|
301
|
-
|
458
|
+
eviction_heap = [
|
459
|
+
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
460
|
+
]
|
461
|
+
heapq.heapify(eviction_heap)
|
302
462
|
|
303
463
|
num_evicted = 0
|
304
|
-
while num_evicted < num_tokens and len(
|
305
|
-
x = heapq.heappop(
|
464
|
+
while num_evicted < num_tokens and len(eviction_heap):
|
465
|
+
_priority, x = heapq.heappop(eviction_heap)
|
306
466
|
|
307
467
|
if x == self.root_node:
|
308
468
|
break
|
@@ -314,7 +474,8 @@ class RadixCache(BasePrefixCache):
|
|
314
474
|
self._delete_leaf(x)
|
315
475
|
|
316
476
|
if len(x.parent.children) == 0:
|
317
|
-
|
477
|
+
new_priority = self.eviction_strategy.get_priority(x.parent)
|
478
|
+
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
318
479
|
|
319
480
|
self._record_remove_event(x)
|
320
481
|
|
@@ -325,9 +486,9 @@ class RadixCache(BasePrefixCache):
|
|
325
486
|
delta = 0
|
326
487
|
while node != self.root_node:
|
327
488
|
if node.lock_ref == 0:
|
328
|
-
self.evictable_size_ -= len(node.
|
329
|
-
self.protected_size_ += len(node.
|
330
|
-
delta -= len(node.
|
489
|
+
self.evictable_size_ -= len(node.key)
|
490
|
+
self.protected_size_ += len(node.key)
|
491
|
+
delta -= len(node.key)
|
331
492
|
node.lock_ref += 1
|
332
493
|
node = node.parent
|
333
494
|
return delta
|
@@ -339,9 +500,9 @@ class RadixCache(BasePrefixCache):
|
|
339
500
|
delta = 0
|
340
501
|
while node != self.root_node:
|
341
502
|
if node.lock_ref == 1:
|
342
|
-
self.evictable_size_ += len(node.
|
343
|
-
self.protected_size_ -= len(node.
|
344
|
-
delta += len(node.
|
503
|
+
self.evictable_size_ += len(node.key)
|
504
|
+
self.protected_size_ -= len(node.key)
|
505
|
+
delta += len(node.key)
|
345
506
|
node.lock_ref -= 1
|
346
507
|
node = node.parent
|
347
508
|
return delta
|
@@ -366,7 +527,7 @@ class RadixCache(BasePrefixCache):
|
|
366
527
|
|
367
528
|
##### Internal Helper Functions #####
|
368
529
|
|
369
|
-
def _match_prefix_helper(self, node: TreeNode, key:
|
530
|
+
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
370
531
|
node.last_access_time = time.monotonic()
|
371
532
|
|
372
533
|
child_key = self.get_child_key_fn(key)
|
@@ -391,7 +552,7 @@ class RadixCache(BasePrefixCache):
|
|
391
552
|
|
392
553
|
return value, node
|
393
554
|
|
394
|
-
def _split_node(self, key, child: TreeNode, split_len: int):
|
555
|
+
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
395
556
|
# new_node -> child
|
396
557
|
self._record_remove_event(child)
|
397
558
|
new_node = TreeNode()
|
@@ -410,7 +571,7 @@ class RadixCache(BasePrefixCache):
|
|
410
571
|
|
411
572
|
return new_node
|
412
573
|
|
413
|
-
def _insert_helper(self, node: TreeNode, key:
|
574
|
+
def _insert_helper(self, node: TreeNode, key: RadixKey, value):
|
414
575
|
node.last_access_time = time.monotonic()
|
415
576
|
if len(key) == 0:
|
416
577
|
return 0
|
@@ -439,7 +600,7 @@ class RadixCache(BasePrefixCache):
|
|
439
600
|
new_node.key = key
|
440
601
|
new_node.value = value
|
441
602
|
node.children[child_key] = new_node
|
442
|
-
self.evictable_size_ += len(
|
603
|
+
self.evictable_size_ += len(key)
|
443
604
|
self._record_store_event(new_node)
|
444
605
|
return total_prefix_length
|
445
606
|
|
@@ -451,7 +612,7 @@ class RadixCache(BasePrefixCache):
|
|
451
612
|
print(
|
452
613
|
" " * current_indent,
|
453
614
|
len(current_node.key),
|
454
|
-
current_node.key[:10],
|
615
|
+
current_node.key.token_ids[:10],
|
455
616
|
f"r={current_node.lock_ref}",
|
456
617
|
)
|
457
618
|
for key, child in current_node.children.items():
|
@@ -503,11 +664,11 @@ class RadixCache(BasePrefixCache):
|
|
503
664
|
last_page_start = (
|
504
665
|
(len(node.parent.key) - 1) // self.page_size
|
505
666
|
) * self.page_size
|
506
|
-
parent_parent_tokens = node.parent.key[last_page_start:]
|
667
|
+
parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
|
507
668
|
parent_block_hash = hash(tuple(parent_parent_tokens))
|
508
669
|
|
509
670
|
for start in range(0, len(node.key), self.page_size):
|
510
|
-
page_tokens = node.key[start : start + self.page_size]
|
671
|
+
page_tokens = node.key.token_ids[start : start + self.page_size]
|
511
672
|
if not page_tokens:
|
512
673
|
continue
|
513
674
|
|
@@ -530,7 +691,7 @@ class RadixCache(BasePrefixCache):
|
|
530
691
|
# One BlockRemoved per chunk.
|
531
692
|
if self.enable_kv_cache_events:
|
532
693
|
for start in range(0, len(node.key), self.page_size):
|
533
|
-
page_tokens = node.key[start : start + self.page_size]
|
694
|
+
page_tokens = node.key.token_ids[start : start + self.page_size]
|
534
695
|
if not page_tokens:
|
535
696
|
continue
|
536
697
|
block_hash = hash(tuple(page_tokens))
|
@@ -556,19 +717,12 @@ class RadixCache(BasePrefixCache):
|
|
556
717
|
if __name__ == "__main__":
|
557
718
|
tree = RadixCache(None, None, page_size=1, disable=False)
|
558
719
|
|
559
|
-
|
560
|
-
tree.insert(
|
561
|
-
tree.insert(
|
562
|
-
|
563
|
-
|
720
|
+
# Example token id sequences (as lists of ints)
|
721
|
+
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
722
|
+
tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
|
723
|
+
tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None))
|
724
|
+
tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None))
|
725
|
+
tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
|
564
726
|
tree.pretty_print()
|
565
727
|
|
566
|
-
|
567
|
-
|
568
|
-
# def evict_callback(x):
|
569
|
-
# print("evict", x)
|
570
|
-
# return len(x)
|
571
|
-
|
572
|
-
# tree.evict(5, evict_callback)
|
573
|
-
# tree.evict(10, evict_callback)
|
574
|
-
# tree.pretty_print()
|
728
|
+
print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
|
@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
|
|
13
13
|
TreeNodeCpp,
|
14
14
|
)
|
15
15
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
16
|
+
from sglang.srt.mem_cache.radix_cache import RadixKey
|
16
17
|
|
17
18
|
if TYPE_CHECKING:
|
18
19
|
from sglang.srt.managers.schedule_batch import Req
|
@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
|
|
93
94
|
raise NotImplementedError("Host cache is not supported yet")
|
94
95
|
self.tree.reset()
|
95
96
|
|
96
|
-
def match_prefix(self, key:
|
97
|
+
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
|
97
98
|
device_indices_vec, host_indices_length, node_gpu, node_cpu = (
|
98
|
-
self.tree.match_prefix(key)
|
99
|
+
self.tree.match_prefix(key.token_ids)
|
99
100
|
)
|
100
101
|
return MatchResult(
|
101
102
|
device_indices=self._merge_tensor(device_indices_vec),
|
@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
|
|
104
105
|
host_hit_length=host_indices_length,
|
105
106
|
)
|
106
107
|
|
107
|
-
def _insert(self, key:
|
108
|
+
def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
|
108
109
|
"""
|
109
110
|
Insert a key-value pair into the radix tree.
|
110
111
|
Args:
|
111
|
-
key (
|
112
|
+
key (RadixKey): The key to insert, represented as a RadixKey.
|
112
113
|
value (torch.Tensor): The value to associate with the key.
|
113
114
|
Returns:
|
114
115
|
int: Number of device indices that were already present in the tree before the insertion.
|
115
116
|
"""
|
116
|
-
ongoing_write, length = self.tree.writing_through(key, value)
|
117
|
+
ongoing_write, length = self.tree.writing_through(key.token_ids, value)
|
117
118
|
if self.cache_controller is None:
|
118
119
|
assert len(ongoing_write) == 0, "Implementation error"
|
119
120
|
return length
|
@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache):
|
|
160
161
|
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
|
161
162
|
# it will automatically align them, but length of them should be equal
|
162
163
|
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
163
|
-
new_prefix_len = self._insert(token_ids, kv_indices)
|
164
|
+
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
164
165
|
|
165
166
|
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
166
167
|
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache):
|
|
191
192
|
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
|
192
193
|
# it will automatically align them, but length of them should be equal
|
193
194
|
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
194
|
-
new_prefix_len = self._insert(token_ids, kv_indices)
|
195
|
+
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
195
196
|
|
196
197
|
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
197
198
|
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
198
199
|
|
199
200
|
# TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
|
200
201
|
# The prefix indices need to updated to reuse the kv indices in the pool
|
201
|
-
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
|
202
|
+
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
|
203
|
+
RadixKey(token_ids, req.extra_key).token_ids
|
204
|
+
)
|
202
205
|
new_indices = self._merge_tensor(new_indices_vec)
|
203
206
|
assert new_prefix_len <= len(new_indices)
|
204
207
|
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to SGLang project
|
3
|
+
|
4
|
+
"""Storage backend module for SGLang HiCache."""
|
5
|
+
|
6
|
+
from .backend_factory import StorageBackendFactory
|
7
|
+
|
8
|
+
__all__ = [
|
9
|
+
"StorageBackendFactory",
|
10
|
+
]
|