sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
import heapq
|
2
|
+
import json
|
2
3
|
import logging
|
3
4
|
import threading
|
4
5
|
import time
|
5
|
-
from queue import Queue
|
6
6
|
from typing import List, Optional
|
7
7
|
|
8
8
|
import torch
|
@@ -19,7 +19,8 @@ from sglang.srt.mem_cache.memory_pool_host import (
|
|
19
19
|
MHATokenToKVPoolHost,
|
20
20
|
MLATokenToKVPoolHost,
|
21
21
|
)
|
22
|
-
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
22
|
+
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
23
|
+
from sglang.srt.metrics.collector import StorageMetricsCollector
|
23
24
|
|
24
25
|
logger = logging.getLogger(__name__)
|
25
26
|
|
@@ -37,17 +38,20 @@ class HiRadixCache(RadixCache):
|
|
37
38
|
hicache_write_policy: str,
|
38
39
|
hicache_io_backend: str,
|
39
40
|
hicache_mem_layout: str,
|
41
|
+
enable_metrics: bool,
|
42
|
+
eviction_policy: str = "lru",
|
40
43
|
hicache_storage_backend: Optional[str] = None,
|
41
44
|
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
42
45
|
model_name: Optional[str] = None,
|
43
46
|
storage_backend_extra_config: Optional[str] = None,
|
47
|
+
is_eagle: bool = False,
|
44
48
|
):
|
45
49
|
|
46
50
|
if hicache_io_backend == "direct":
|
47
51
|
if hicache_mem_layout == "page_first":
|
48
|
-
hicache_mem_layout = "
|
52
|
+
hicache_mem_layout = "page_first_direct"
|
49
53
|
logger.warning(
|
50
|
-
"Page first layout is not supported with direct IO backend, switching to
|
54
|
+
"Page first layout is not supported with direct IO backend, switching to page first direct layout"
|
51
55
|
)
|
52
56
|
|
53
57
|
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
@@ -73,9 +77,21 @@ class HiRadixCache(RadixCache):
|
|
73
77
|
self.tp_group = tp_cache_group
|
74
78
|
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
75
79
|
self.enable_storage = hicache_storage_backend is not None
|
76
|
-
|
77
|
-
|
78
|
-
|
80
|
+
self.enable_storage_metrics = self.enable_storage and enable_metrics
|
81
|
+
|
82
|
+
(
|
83
|
+
extra_config,
|
84
|
+
prefetch_threshold,
|
85
|
+
prefetch_timeout_base,
|
86
|
+
prefetch_timeout_per_ki_token,
|
87
|
+
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
|
88
|
+
self.prefetch_threshold = prefetch_threshold
|
89
|
+
self.prefetch_timeout_base = prefetch_timeout_base
|
90
|
+
self.prefetch_timeout_per_page = (
|
91
|
+
page_size / 1024 * prefetch_timeout_per_ki_token
|
92
|
+
)
|
93
|
+
# TODO: support more timeout check functions
|
94
|
+
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
|
79
95
|
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
80
96
|
|
81
97
|
self.load_cache_event = threading.Event()
|
@@ -90,8 +106,16 @@ class HiRadixCache(RadixCache):
|
|
90
106
|
storage_backend=hicache_storage_backend,
|
91
107
|
prefetch_threshold=self.prefetch_threshold,
|
92
108
|
model_name=model_name,
|
93
|
-
storage_backend_extra_config=
|
109
|
+
storage_backend_extra_config=extra_config,
|
94
110
|
)
|
111
|
+
if self.enable_storage_metrics:
|
112
|
+
# TODO: support pp
|
113
|
+
labels = {
|
114
|
+
"storage_backend": hicache_storage_backend,
|
115
|
+
"tp_rank": self.cache_controller.tp_rank,
|
116
|
+
"dp_rank": self.cache_controller.dp_rank,
|
117
|
+
}
|
118
|
+
self.metrics_collector = StorageMetricsCollector(labels=labels)
|
95
119
|
|
96
120
|
# record the nodes with ongoing write through
|
97
121
|
self.ongoing_write_through = {}
|
@@ -105,8 +129,61 @@ class HiRadixCache(RadixCache):
|
|
105
129
|
1 if hicache_write_policy == "write_through" else 2
|
106
130
|
)
|
107
131
|
self.load_back_threshold = 10
|
132
|
+
|
108
133
|
super().__init__(
|
109
|
-
req_to_token_pool,
|
134
|
+
req_to_token_pool,
|
135
|
+
token_to_kv_pool_allocator,
|
136
|
+
page_size,
|
137
|
+
disable=False,
|
138
|
+
eviction_policy=eviction_policy,
|
139
|
+
is_eagle=is_eagle,
|
140
|
+
)
|
141
|
+
|
142
|
+
def _parse_storage_backend_extra_config(
|
143
|
+
self, storage_backend_extra_config: Optional[str]
|
144
|
+
):
|
145
|
+
"""
|
146
|
+
Parse storage backend extra config JSON and extract specific parameters.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
storage_backend_extra_config: JSON string containing extra configuration
|
150
|
+
|
151
|
+
Returns:
|
152
|
+
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
|
153
|
+
"""
|
154
|
+
# Parse extra config JSON if provided
|
155
|
+
extra_config = {}
|
156
|
+
if storage_backend_extra_config:
|
157
|
+
try:
|
158
|
+
extra_config = json.loads(storage_backend_extra_config)
|
159
|
+
except Exception as e:
|
160
|
+
logger.error(f"Invalid backend extra config JSON: {e}")
|
161
|
+
raise e
|
162
|
+
|
163
|
+
prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens
|
164
|
+
prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds
|
165
|
+
prefetch_timeout_per_ki_token = extra_config.pop(
|
166
|
+
"prefetch_timeout_per_ki_token", 0.25
|
167
|
+
) # seconds per 1024 tokens
|
168
|
+
|
169
|
+
if not isinstance(prefetch_threshold, int):
|
170
|
+
raise ValueError(
|
171
|
+
f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
|
172
|
+
)
|
173
|
+
if not isinstance(prefetch_timeout_base, (int, float)):
|
174
|
+
raise ValueError(
|
175
|
+
f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
|
176
|
+
)
|
177
|
+
if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
|
178
|
+
raise ValueError(
|
179
|
+
f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
|
180
|
+
)
|
181
|
+
|
182
|
+
return (
|
183
|
+
extra_config,
|
184
|
+
prefetch_threshold,
|
185
|
+
float(prefetch_timeout_base),
|
186
|
+
float(prefetch_timeout_per_ki_token),
|
110
187
|
)
|
111
188
|
|
112
189
|
def reset(self):
|
@@ -122,11 +199,24 @@ class HiRadixCache(RadixCache):
|
|
122
199
|
height += 1
|
123
200
|
return height
|
124
201
|
|
125
|
-
def clear_storage_backend(self):
|
202
|
+
def clear_storage_backend(self) -> bool:
|
126
203
|
if self.enable_storage:
|
127
|
-
|
128
|
-
|
129
|
-
|
204
|
+
try:
|
205
|
+
# Check if the storage backend has a clear method (for nixl backends)
|
206
|
+
if hasattr(self.cache_controller.storage_backend, "clear"):
|
207
|
+
self.cache_controller.storage_backend.clear()
|
208
|
+
logger.info(
|
209
|
+
"Hierarchical cache storage backend cleared successfully!"
|
210
|
+
)
|
211
|
+
return True
|
212
|
+
else:
|
213
|
+
logger.warning(
|
214
|
+
f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
|
215
|
+
)
|
216
|
+
return False
|
217
|
+
except Exception as e:
|
218
|
+
logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
|
219
|
+
return False
|
130
220
|
else:
|
131
221
|
logger.warning("Hierarchical cache storage backend is not enabled.")
|
132
222
|
return False
|
@@ -176,53 +266,72 @@ class HiRadixCache(RadixCache):
|
|
176
266
|
if write_back:
|
177
267
|
# blocking till all write back complete
|
178
268
|
while len(self.ongoing_write_through) > 0:
|
179
|
-
|
180
|
-
|
269
|
+
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
270
|
+
finish_event.synchronize()
|
271
|
+
for ack_id in ack_list:
|
272
|
+
del self.ongoing_write_through[ack_id]
|
273
|
+
self.cache_controller.ack_write_queue.clear()
|
274
|
+
assert len(self.ongoing_write_through) == 0
|
181
275
|
return
|
182
|
-
|
183
|
-
|
184
|
-
)
|
276
|
+
|
277
|
+
# NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
|
278
|
+
if len(self.ongoing_write_through) == 0:
|
279
|
+
return
|
280
|
+
|
281
|
+
finish_count = 0
|
282
|
+
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
283
|
+
if not finish_event.query():
|
284
|
+
break
|
285
|
+
finish_count += 1
|
286
|
+
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
|
185
287
|
if self.tp_world_size > 1:
|
186
|
-
#
|
288
|
+
# synchronize TP workers to make the same update to radix cache
|
187
289
|
torch.distributed.all_reduce(
|
188
290
|
queue_size,
|
189
291
|
op=torch.distributed.ReduceOp.MIN,
|
190
292
|
group=self.tp_group,
|
191
293
|
)
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
self.
|
196
|
-
|
197
|
-
|
198
|
-
self.
|
294
|
+
|
295
|
+
finish_count = int(queue_size.item())
|
296
|
+
while finish_count > 0:
|
297
|
+
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
298
|
+
finish_event.synchronize()
|
299
|
+
for ack_id in ack_list:
|
300
|
+
backuped_node = self.ongoing_write_through.pop(ack_id)
|
301
|
+
self.dec_lock_ref(backuped_node)
|
302
|
+
if self.enable_storage:
|
303
|
+
self.write_backup_storage(backuped_node)
|
304
|
+
finish_count -= 1
|
199
305
|
|
200
306
|
def loading_check(self):
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
self.dec_lock_ref(end_node)
|
206
|
-
while end_node != start_node:
|
207
|
-
assert end_node.loading
|
208
|
-
end_node.loading = False
|
209
|
-
end_node = end_node.parent
|
210
|
-
# clear the reference
|
211
|
-
del self.ongoing_load_back[ack_id]
|
212
|
-
except Exception:
|
307
|
+
finish_count = 0
|
308
|
+
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
|
309
|
+
if not finish_event.query():
|
310
|
+
# the KV cache loading is still ongoing
|
213
311
|
break
|
312
|
+
finish_count += 1
|
313
|
+
# no need to sync across TP workers as batch forwarding is synced
|
314
|
+
for ack_id in ack_list:
|
315
|
+
end_node = self.ongoing_load_back.pop(ack_id)
|
316
|
+
self.dec_lock_ref(end_node)
|
317
|
+
|
318
|
+
# ACK until all events are processed
|
319
|
+
del self.cache_controller.ack_load_queue[:finish_count]
|
214
320
|
|
215
321
|
def evictable_size(self):
|
216
322
|
return self.evictable_size_
|
217
323
|
|
218
324
|
def evict(self, num_tokens: int):
|
219
325
|
leaves = self._collect_leaves_device()
|
220
|
-
|
326
|
+
eviction_heap = [
|
327
|
+
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
328
|
+
]
|
329
|
+
heapq.heapify(eviction_heap)
|
221
330
|
|
222
331
|
num_evicted = 0
|
223
332
|
write_back_nodes = []
|
224
|
-
while num_evicted < num_tokens and len(
|
225
|
-
x = heapq.heappop(
|
333
|
+
while num_evicted < num_tokens and len(eviction_heap):
|
334
|
+
_priority, x = heapq.heappop(eviction_heap)
|
226
335
|
|
227
336
|
if x.lock_ref > 0:
|
228
337
|
continue
|
@@ -244,7 +353,8 @@ class HiRadixCache(RadixCache):
|
|
244
353
|
break
|
245
354
|
else:
|
246
355
|
# all children are evicted or no children
|
247
|
-
|
356
|
+
new_priority = self.eviction_strategy.get_priority(x.parent)
|
357
|
+
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
248
358
|
|
249
359
|
if self.cache_controller.write_policy == "write_back":
|
250
360
|
self.writing_check(write_back=True)
|
@@ -254,7 +364,7 @@ class HiRadixCache(RadixCache):
|
|
254
364
|
|
255
365
|
def _evict_backuped(self, node: TreeNode):
|
256
366
|
# evict a node already written to host
|
257
|
-
num_evicted = self.cache_controller.evict_device(node.value
|
367
|
+
num_evicted = self.cache_controller.evict_device(node.value)
|
258
368
|
assert num_evicted > 0
|
259
369
|
self.evictable_size_ -= num_evicted
|
260
370
|
node.value = None
|
@@ -269,11 +379,14 @@ class HiRadixCache(RadixCache):
|
|
269
379
|
|
270
380
|
def evict_host(self, num_tokens: int):
|
271
381
|
leaves = self._collect_leaves()
|
272
|
-
|
382
|
+
eviction_heap = [
|
383
|
+
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
384
|
+
]
|
385
|
+
heapq.heapify(eviction_heap)
|
273
386
|
|
274
387
|
num_evicted = 0
|
275
|
-
while num_evicted < num_tokens and len(
|
276
|
-
x = heapq.heappop(
|
388
|
+
while num_evicted < num_tokens and len(eviction_heap):
|
389
|
+
_priority, x = heapq.heappop(eviction_heap)
|
277
390
|
if x == self.root_node:
|
278
391
|
break
|
279
392
|
# only evict the host value of evicted nodes
|
@@ -292,7 +405,8 @@ class HiRadixCache(RadixCache):
|
|
292
405
|
del x.parent.children[k]
|
293
406
|
|
294
407
|
if len(x.parent.children) == 0 and x.parent.evicted:
|
295
|
-
|
408
|
+
new_priority = self.eviction_strategy.get_priority(x.parent)
|
409
|
+
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
296
410
|
|
297
411
|
def load_back(
|
298
412
|
self, node: TreeNode, mem_quota: Optional[int] = None
|
@@ -335,12 +449,11 @@ class HiRadixCache(RadixCache):
|
|
335
449
|
# no sufficient GPU memory to load back KV caches
|
336
450
|
return None
|
337
451
|
|
338
|
-
self.ongoing_load_back[last_hit_node.id] =
|
452
|
+
self.ongoing_load_back[last_hit_node.id] = last_hit_node
|
339
453
|
offset = 0
|
340
454
|
for node in nodes_to_load:
|
341
455
|
node.value = device_indices[offset : offset + len(node.host_value)]
|
342
456
|
offset += len(node.host_value)
|
343
|
-
node.loading = True
|
344
457
|
self.evictable_size_ += len(device_indices)
|
345
458
|
self.inc_lock_ref(last_hit_node)
|
346
459
|
|
@@ -369,16 +482,22 @@ class HiRadixCache(RadixCache):
|
|
369
482
|
last_node,
|
370
483
|
)
|
371
484
|
|
372
|
-
def ready_to_load_host_cache(self):
|
373
|
-
|
374
|
-
|
375
|
-
|
485
|
+
def ready_to_load_host_cache(self) -> int:
|
486
|
+
"""
|
487
|
+
Notify the cache controller to start the KV cache loading.
|
488
|
+
Return the consumer index for the schedule batch manager to track.
|
489
|
+
"""
|
490
|
+
return self.cache_controller.start_loading()
|
376
491
|
|
377
492
|
def check_hicache_events(self):
|
378
493
|
self.writing_check()
|
379
494
|
self.loading_check()
|
380
495
|
if self.enable_storage:
|
381
496
|
self.drain_storage_control_queues()
|
497
|
+
if self.enable_storage_metrics:
|
498
|
+
self.metrics_collector.log_storage_metrics(
|
499
|
+
self.cache_controller.storage_backend.get_stats()
|
500
|
+
)
|
382
501
|
|
383
502
|
def drain_storage_control_queues(self):
|
384
503
|
"""
|
@@ -414,10 +533,13 @@ class HiRadixCache(RadixCache):
|
|
414
533
|
|
415
534
|
# process backup acks
|
416
535
|
for _ in range(n_backup):
|
417
|
-
|
536
|
+
operation = cc.ack_backup_queue.get()
|
537
|
+
ack_id = operation.id
|
418
538
|
entry = self.ongoing_backup.pop(ack_id, None)
|
419
539
|
if entry is not None:
|
420
540
|
entry.release_host()
|
541
|
+
if self.enable_storage_metrics:
|
542
|
+
self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
|
421
543
|
|
422
544
|
# release host memory
|
423
545
|
host_indices_list = []
|
@@ -427,6 +549,15 @@ class HiRadixCache(RadixCache):
|
|
427
549
|
host_indices = torch.cat(host_indices_list, dim=0)
|
428
550
|
cc.mem_pool_host.free(host_indices)
|
429
551
|
|
552
|
+
# Timeout is linearly increasing with the number of pages
|
553
|
+
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
|
554
|
+
# If hash_value has not been computed in timeout_base seconds, terminate it.
|
555
|
+
return (
|
556
|
+
time.monotonic() - operation.start_time
|
557
|
+
> self.prefetch_timeout_base
|
558
|
+
+ len(operation.hash_value) * self.prefetch_timeout_per_page
|
559
|
+
)
|
560
|
+
|
430
561
|
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
431
562
|
can_terminate = True
|
432
563
|
|
@@ -443,22 +574,27 @@ class HiRadixCache(RadixCache):
|
|
443
574
|
if self.prefetch_stop_policy == "wait_complete":
|
444
575
|
can_terminate = completed
|
445
576
|
elif self.prefetch_stop_policy == "timeout":
|
446
|
-
can_terminate = completed or (
|
447
|
-
time.monotonic() - operation.start_time > self.prefetch_timeout
|
448
|
-
)
|
577
|
+
can_terminate = completed or self.is_prefetch_timeout(operation)
|
449
578
|
else:
|
450
579
|
# unknown prefetch stop policy, just return True
|
451
580
|
return True
|
452
581
|
|
582
|
+
operation_terminated = operation.is_terminated()
|
453
583
|
if self.tp_world_size > 1:
|
454
|
-
|
584
|
+
states = torch.tensor(
|
585
|
+
[1 - int(can_terminate), int(operation_terminated)],
|
586
|
+
dtype=torch.int,
|
587
|
+
)
|
455
588
|
torch.distributed.all_reduce(
|
456
|
-
|
457
|
-
op=torch.distributed.ReduceOp.
|
589
|
+
states,
|
590
|
+
op=torch.distributed.ReduceOp.MAX,
|
458
591
|
group=self.tp_group,
|
459
592
|
)
|
460
|
-
can_terminate =
|
461
|
-
|
593
|
+
can_terminate = states[0].item() == 0
|
594
|
+
operation_terminated = states[1].item() == 1
|
595
|
+
# the operation should be terminated if it is already terminated on any TP worker
|
596
|
+
# or it meets the termination condition on all TP workers
|
597
|
+
can_terminate = can_terminate or operation_terminated
|
462
598
|
return can_terminate
|
463
599
|
|
464
600
|
def check_prefetch_progress(self, req_id: str) -> bool:
|
@@ -485,7 +621,7 @@ class HiRadixCache(RadixCache):
|
|
485
621
|
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
486
622
|
|
487
623
|
min_completed_tokens = completed_tokens
|
488
|
-
if self.tp_world_size > 1
|
624
|
+
if self.tp_world_size > 1:
|
489
625
|
# synchrnoize TP workers to make the same update to hiradix cache
|
490
626
|
completed_tokens_tensor = torch.tensor(
|
491
627
|
min_completed_tokens, dtype=torch.int
|
@@ -500,12 +636,12 @@ class HiRadixCache(RadixCache):
|
|
500
636
|
written_indices = host_indices[:min_completed_tokens]
|
501
637
|
matched_length = self._insert_helper_host(
|
502
638
|
last_host_node,
|
503
|
-
|
639
|
+
RadixKey(
|
640
|
+
token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
|
641
|
+
),
|
504
642
|
written_indices,
|
505
643
|
hash_value[: min_completed_tokens // self.page_size],
|
506
644
|
)
|
507
|
-
if len(written_indices):
|
508
|
-
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
|
509
645
|
|
510
646
|
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
511
647
|
self.cache_controller.append_host_mem_release(
|
@@ -515,10 +651,16 @@ class HiRadixCache(RadixCache):
|
|
515
651
|
del self.ongoing_prefetch[req_id]
|
516
652
|
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
517
653
|
|
654
|
+
if self.enable_storage_metrics:
|
655
|
+
self.metrics_collector.log_prefetched_tokens(
|
656
|
+
min_completed_tokens - matched_length
|
657
|
+
)
|
658
|
+
|
518
659
|
return True
|
519
660
|
|
520
|
-
def match_prefix(self, key:
|
661
|
+
def match_prefix(self, key: RadixKey, **kwargs):
|
521
662
|
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
663
|
+
key.token_ids = self.key_convert_fn(key.token_ids)
|
522
664
|
if self.disable or len(key) == 0:
|
523
665
|
return MatchResult(
|
524
666
|
device_indices=empty_value,
|
@@ -591,7 +733,9 @@ class HiRadixCache(RadixCache):
|
|
591
733
|
)
|
592
734
|
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
|
593
735
|
|
594
|
-
def _insert_helper_host(
|
736
|
+
def _insert_helper_host(
|
737
|
+
self, node: TreeNode, key: RadixKey, host_value, hash_value
|
738
|
+
):
|
595
739
|
node.last_access_time = time.monotonic()
|
596
740
|
if len(key) == 0:
|
597
741
|
return 0
|
@@ -625,7 +769,7 @@ class HiRadixCache(RadixCache):
|
|
625
769
|
node.children[child_key] = new_node
|
626
770
|
return matched_length
|
627
771
|
|
628
|
-
def _match_prefix_helper(self, node: TreeNode, key:
|
772
|
+
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
629
773
|
node.last_access_time = time.monotonic()
|
630
774
|
child_key = self.get_child_key_fn(key)
|
631
775
|
value = []
|
@@ -651,14 +795,13 @@ class HiRadixCache(RadixCache):
|
|
651
795
|
|
652
796
|
return value, node
|
653
797
|
|
654
|
-
def _split_node(self, key, child: TreeNode, split_len: int):
|
798
|
+
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
655
799
|
# child node split into new_node -> child
|
656
800
|
new_node = TreeNode()
|
657
801
|
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
658
802
|
new_node.parent = child.parent
|
659
803
|
new_node.lock_ref = child.lock_ref
|
660
804
|
new_node.key = child.key[:split_len]
|
661
|
-
new_node.loading = child.loading
|
662
805
|
new_node.hit_count = child.hit_count
|
663
806
|
|
664
807
|
# split value and host value if exists
|
@@ -679,10 +822,16 @@ class HiRadixCache(RadixCache):
|
|
679
822
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
680
823
|
return new_node
|
681
824
|
|
682
|
-
def insert(self, key:
|
825
|
+
def insert(self, key: RadixKey, value=None, chunked=False):
|
826
|
+
key.token_ids = self.key_convert_fn(key.token_ids)
|
827
|
+
|
683
828
|
if len(key) == 0:
|
684
829
|
return 0
|
685
830
|
|
831
|
+
if self.is_eagle and value is not None:
|
832
|
+
# Make sure the value len equal to the EAGLE bigram key len
|
833
|
+
value = value[: len(key)]
|
834
|
+
|
686
835
|
node = self.root_node
|
687
836
|
child_key = self.get_child_key_fn(key)
|
688
837
|
total_prefix_length = 0
|
@@ -697,7 +846,6 @@ class HiRadixCache(RadixCache):
|
|
697
846
|
# change the reference if the node is evicted
|
698
847
|
# this often happens in the case of KV cache recomputation
|
699
848
|
node.value = value[:prefix_len]
|
700
|
-
self.token_to_kv_pool_host.update_synced(node.host_value)
|
701
849
|
self.evictable_size_ += len(node.value)
|
702
850
|
else:
|
703
851
|
self._inc_hit_count(node, chunked)
|
@@ -707,7 +855,6 @@ class HiRadixCache(RadixCache):
|
|
707
855
|
new_node = self._split_node(node.key, node, prefix_len)
|
708
856
|
if new_node.evicted:
|
709
857
|
new_node.value = value[:prefix_len]
|
710
|
-
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
711
858
|
self.evictable_size_ += len(new_node.value)
|
712
859
|
else:
|
713
860
|
self._inc_hit_count(new_node, chunked)
|
@@ -737,7 +884,7 @@ class HiRadixCache(RadixCache):
|
|
737
884
|
for idx in range(0, len(key), self.page_size):
|
738
885
|
new_node.hash_value.append(
|
739
886
|
self.cache_controller.get_hash_str(
|
740
|
-
key[idx : idx + self.page_size],
|
887
|
+
key.token_ids[idx : idx + self.page_size],
|
741
888
|
prior_hash=last_hash,
|
742
889
|
)
|
743
890
|
)
|