sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
import heapq
|
2
|
+
import json
|
2
3
|
import logging
|
3
4
|
import threading
|
4
5
|
import time
|
5
|
-
from queue import Queue
|
6
6
|
from typing import List, Optional
|
7
7
|
|
8
8
|
import torch
|
@@ -19,7 +19,8 @@ from sglang.srt.mem_cache.memory_pool_host import (
|
|
19
19
|
MHATokenToKVPoolHost,
|
20
20
|
MLATokenToKVPoolHost,
|
21
21
|
)
|
22
|
-
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
22
|
+
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
23
|
+
from sglang.srt.metrics.collector import StorageMetricsCollector
|
23
24
|
|
24
25
|
logger = logging.getLogger(__name__)
|
25
26
|
|
@@ -37,17 +38,20 @@ class HiRadixCache(RadixCache):
|
|
37
38
|
hicache_write_policy: str,
|
38
39
|
hicache_io_backend: str,
|
39
40
|
hicache_mem_layout: str,
|
41
|
+
enable_metrics: bool,
|
42
|
+
eviction_policy: str = "lru",
|
40
43
|
hicache_storage_backend: Optional[str] = None,
|
41
44
|
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
42
45
|
model_name: Optional[str] = None,
|
43
46
|
storage_backend_extra_config: Optional[str] = None,
|
47
|
+
is_eagle: bool = False,
|
44
48
|
):
|
45
49
|
|
46
50
|
if hicache_io_backend == "direct":
|
47
51
|
if hicache_mem_layout == "page_first":
|
48
|
-
hicache_mem_layout = "
|
52
|
+
hicache_mem_layout = "page_first_direct"
|
49
53
|
logger.warning(
|
50
|
-
"Page first layout is not supported with direct IO backend, switching to
|
54
|
+
"Page first layout is not supported with direct IO backend, switching to page first direct layout"
|
51
55
|
)
|
52
56
|
|
53
57
|
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
@@ -73,9 +77,21 @@ class HiRadixCache(RadixCache):
|
|
73
77
|
self.tp_group = tp_cache_group
|
74
78
|
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
75
79
|
self.enable_storage = hicache_storage_backend is not None
|
76
|
-
|
77
|
-
|
78
|
-
|
80
|
+
self.enable_storage_metrics = self.enable_storage and enable_metrics
|
81
|
+
|
82
|
+
(
|
83
|
+
extra_config,
|
84
|
+
prefetch_threshold,
|
85
|
+
prefetch_timeout_base,
|
86
|
+
prefetch_timeout_per_ki_token,
|
87
|
+
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
|
88
|
+
self.prefetch_threshold = prefetch_threshold
|
89
|
+
self.prefetch_timeout_base = prefetch_timeout_base
|
90
|
+
self.prefetch_timeout_per_page = (
|
91
|
+
page_size / 1024 * prefetch_timeout_per_ki_token
|
92
|
+
)
|
93
|
+
# TODO: support more timeout check functions
|
94
|
+
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
|
79
95
|
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
80
96
|
|
81
97
|
self.load_cache_event = threading.Event()
|
@@ -90,8 +106,16 @@ class HiRadixCache(RadixCache):
|
|
90
106
|
storage_backend=hicache_storage_backend,
|
91
107
|
prefetch_threshold=self.prefetch_threshold,
|
92
108
|
model_name=model_name,
|
93
|
-
storage_backend_extra_config=
|
109
|
+
storage_backend_extra_config=extra_config,
|
94
110
|
)
|
111
|
+
if self.enable_storage_metrics:
|
112
|
+
# TODO: support pp
|
113
|
+
labels = {
|
114
|
+
"storage_backend": hicache_storage_backend,
|
115
|
+
"tp_rank": self.cache_controller.tp_rank,
|
116
|
+
"dp_rank": self.cache_controller.dp_rank,
|
117
|
+
}
|
118
|
+
self.metrics_collector = StorageMetricsCollector(labels=labels)
|
95
119
|
|
96
120
|
# record the nodes with ongoing write through
|
97
121
|
self.ongoing_write_through = {}
|
@@ -105,8 +129,61 @@ class HiRadixCache(RadixCache):
|
|
105
129
|
1 if hicache_write_policy == "write_through" else 2
|
106
130
|
)
|
107
131
|
self.load_back_threshold = 10
|
132
|
+
|
108
133
|
super().__init__(
|
109
|
-
req_to_token_pool,
|
134
|
+
req_to_token_pool,
|
135
|
+
token_to_kv_pool_allocator,
|
136
|
+
page_size,
|
137
|
+
disable=False,
|
138
|
+
eviction_policy=eviction_policy,
|
139
|
+
is_eagle=is_eagle,
|
140
|
+
)
|
141
|
+
|
142
|
+
def _parse_storage_backend_extra_config(
|
143
|
+
self, storage_backend_extra_config: Optional[str]
|
144
|
+
):
|
145
|
+
"""
|
146
|
+
Parse storage backend extra config JSON and extract specific parameters.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
storage_backend_extra_config: JSON string containing extra configuration
|
150
|
+
|
151
|
+
Returns:
|
152
|
+
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
|
153
|
+
"""
|
154
|
+
# Parse extra config JSON if provided
|
155
|
+
extra_config = {}
|
156
|
+
if storage_backend_extra_config:
|
157
|
+
try:
|
158
|
+
extra_config = json.loads(storage_backend_extra_config)
|
159
|
+
except Exception as e:
|
160
|
+
logger.error(f"Invalid backend extra config JSON: {e}")
|
161
|
+
raise e
|
162
|
+
|
163
|
+
prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens
|
164
|
+
prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds
|
165
|
+
prefetch_timeout_per_ki_token = extra_config.pop(
|
166
|
+
"prefetch_timeout_per_ki_token", 0.25
|
167
|
+
) # seconds per 1024 tokens
|
168
|
+
|
169
|
+
if not isinstance(prefetch_threshold, int):
|
170
|
+
raise ValueError(
|
171
|
+
f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
|
172
|
+
)
|
173
|
+
if not isinstance(prefetch_timeout_base, (int, float)):
|
174
|
+
raise ValueError(
|
175
|
+
f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
|
176
|
+
)
|
177
|
+
if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
|
178
|
+
raise ValueError(
|
179
|
+
f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
|
180
|
+
)
|
181
|
+
|
182
|
+
return (
|
183
|
+
extra_config,
|
184
|
+
prefetch_threshold,
|
185
|
+
float(prefetch_timeout_base),
|
186
|
+
float(prefetch_timeout_per_ki_token),
|
110
187
|
)
|
111
188
|
|
112
189
|
def reset(self):
|
@@ -122,11 +199,24 @@ class HiRadixCache(RadixCache):
|
|
122
199
|
height += 1
|
123
200
|
return height
|
124
201
|
|
125
|
-
def clear_storage_backend(self):
|
202
|
+
def clear_storage_backend(self) -> bool:
|
126
203
|
if self.enable_storage:
|
127
|
-
|
128
|
-
|
129
|
-
|
204
|
+
try:
|
205
|
+
# Check if the storage backend has a clear method (for nixl backends)
|
206
|
+
if hasattr(self.cache_controller.storage_backend, "clear"):
|
207
|
+
self.cache_controller.storage_backend.clear()
|
208
|
+
logger.info(
|
209
|
+
"Hierarchical cache storage backend cleared successfully!"
|
210
|
+
)
|
211
|
+
return True
|
212
|
+
else:
|
213
|
+
logger.warning(
|
214
|
+
f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
|
215
|
+
)
|
216
|
+
return False
|
217
|
+
except Exception as e:
|
218
|
+
logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
|
219
|
+
return False
|
130
220
|
else:
|
131
221
|
logger.warning("Hierarchical cache storage backend is not enabled.")
|
132
222
|
return False
|
@@ -176,53 +266,72 @@ class HiRadixCache(RadixCache):
|
|
176
266
|
if write_back:
|
177
267
|
# blocking till all write back complete
|
178
268
|
while len(self.ongoing_write_through) > 0:
|
179
|
-
|
180
|
-
|
269
|
+
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
270
|
+
finish_event.synchronize()
|
271
|
+
for ack_id in ack_list:
|
272
|
+
del self.ongoing_write_through[ack_id]
|
273
|
+
self.cache_controller.ack_write_queue.clear()
|
274
|
+
assert len(self.ongoing_write_through) == 0
|
181
275
|
return
|
182
|
-
|
183
|
-
|
184
|
-
)
|
276
|
+
|
277
|
+
# NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
|
278
|
+
if len(self.ongoing_write_through) == 0:
|
279
|
+
return
|
280
|
+
|
281
|
+
finish_count = 0
|
282
|
+
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
283
|
+
if not finish_event.query():
|
284
|
+
break
|
285
|
+
finish_count += 1
|
286
|
+
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
|
185
287
|
if self.tp_world_size > 1:
|
186
|
-
#
|
288
|
+
# synchronize TP workers to make the same update to radix cache
|
187
289
|
torch.distributed.all_reduce(
|
188
290
|
queue_size,
|
189
291
|
op=torch.distributed.ReduceOp.MIN,
|
190
292
|
group=self.tp_group,
|
191
293
|
)
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
self.
|
196
|
-
|
197
|
-
|
198
|
-
self.
|
294
|
+
|
295
|
+
finish_count = int(queue_size.item())
|
296
|
+
while finish_count > 0:
|
297
|
+
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
298
|
+
finish_event.synchronize()
|
299
|
+
for ack_id in ack_list:
|
300
|
+
backuped_node = self.ongoing_write_through.pop(ack_id)
|
301
|
+
self.dec_lock_ref(backuped_node)
|
302
|
+
if self.enable_storage:
|
303
|
+
self.write_backup_storage(backuped_node)
|
304
|
+
finish_count -= 1
|
199
305
|
|
200
306
|
def loading_check(self):
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
self.dec_lock_ref(end_node)
|
206
|
-
while end_node != start_node:
|
207
|
-
assert end_node.loading
|
208
|
-
end_node.loading = False
|
209
|
-
end_node = end_node.parent
|
210
|
-
# clear the reference
|
211
|
-
del self.ongoing_load_back[ack_id]
|
212
|
-
except Exception:
|
307
|
+
finish_count = 0
|
308
|
+
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
|
309
|
+
if not finish_event.query():
|
310
|
+
# the KV cache loading is still ongoing
|
213
311
|
break
|
312
|
+
finish_count += 1
|
313
|
+
# no need to sync across TP workers as batch forwarding is synced
|
314
|
+
for ack_id in ack_list:
|
315
|
+
end_node = self.ongoing_load_back.pop(ack_id)
|
316
|
+
self.dec_lock_ref(end_node)
|
317
|
+
|
318
|
+
# ACK until all events are processed
|
319
|
+
del self.cache_controller.ack_load_queue[:finish_count]
|
214
320
|
|
215
321
|
def evictable_size(self):
|
216
322
|
return self.evictable_size_
|
217
323
|
|
218
324
|
def evict(self, num_tokens: int):
|
219
325
|
leaves = self._collect_leaves_device()
|
220
|
-
|
326
|
+
eviction_heap = [
|
327
|
+
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
328
|
+
]
|
329
|
+
heapq.heapify(eviction_heap)
|
221
330
|
|
222
331
|
num_evicted = 0
|
223
332
|
write_back_nodes = []
|
224
|
-
while num_evicted < num_tokens and len(
|
225
|
-
x = heapq.heappop(
|
333
|
+
while num_evicted < num_tokens and len(eviction_heap):
|
334
|
+
_priority, x = heapq.heappop(eviction_heap)
|
226
335
|
|
227
336
|
if x.lock_ref > 0:
|
228
337
|
continue
|
@@ -244,7 +353,8 @@ class HiRadixCache(RadixCache):
|
|
244
353
|
break
|
245
354
|
else:
|
246
355
|
# all children are evicted or no children
|
247
|
-
|
356
|
+
new_priority = self.eviction_strategy.get_priority(x.parent)
|
357
|
+
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
248
358
|
|
249
359
|
if self.cache_controller.write_policy == "write_back":
|
250
360
|
self.writing_check(write_back=True)
|
@@ -254,7 +364,7 @@ class HiRadixCache(RadixCache):
|
|
254
364
|
|
255
365
|
def _evict_backuped(self, node: TreeNode):
|
256
366
|
# evict a node already written to host
|
257
|
-
num_evicted = self.cache_controller.evict_device(node.value
|
367
|
+
num_evicted = self.cache_controller.evict_device(node.value)
|
258
368
|
assert num_evicted > 0
|
259
369
|
self.evictable_size_ -= num_evicted
|
260
370
|
node.value = None
|
@@ -269,11 +379,14 @@ class HiRadixCache(RadixCache):
|
|
269
379
|
|
270
380
|
def evict_host(self, num_tokens: int):
|
271
381
|
leaves = self._collect_leaves()
|
272
|
-
|
382
|
+
eviction_heap = [
|
383
|
+
(self.eviction_strategy.get_priority(node), node) for node in leaves
|
384
|
+
]
|
385
|
+
heapq.heapify(eviction_heap)
|
273
386
|
|
274
387
|
num_evicted = 0
|
275
|
-
while num_evicted < num_tokens and len(
|
276
|
-
x = heapq.heappop(
|
388
|
+
while num_evicted < num_tokens and len(eviction_heap):
|
389
|
+
_priority, x = heapq.heappop(eviction_heap)
|
277
390
|
if x == self.root_node:
|
278
391
|
break
|
279
392
|
# only evict the host value of evicted nodes
|
@@ -292,7 +405,8 @@ class HiRadixCache(RadixCache):
|
|
292
405
|
del x.parent.children[k]
|
293
406
|
|
294
407
|
if len(x.parent.children) == 0 and x.parent.evicted:
|
295
|
-
|
408
|
+
new_priority = self.eviction_strategy.get_priority(x.parent)
|
409
|
+
heapq.heappush(eviction_heap, (new_priority, x.parent))
|
296
410
|
|
297
411
|
def load_back(
|
298
412
|
self, node: TreeNode, mem_quota: Optional[int] = None
|
@@ -335,12 +449,11 @@ class HiRadixCache(RadixCache):
|
|
335
449
|
# no sufficient GPU memory to load back KV caches
|
336
450
|
return None
|
337
451
|
|
338
|
-
self.ongoing_load_back[last_hit_node.id] =
|
452
|
+
self.ongoing_load_back[last_hit_node.id] = last_hit_node
|
339
453
|
offset = 0
|
340
454
|
for node in nodes_to_load:
|
341
455
|
node.value = device_indices[offset : offset + len(node.host_value)]
|
342
456
|
offset += len(node.host_value)
|
343
|
-
node.loading = True
|
344
457
|
self.evictable_size_ += len(device_indices)
|
345
458
|
self.inc_lock_ref(last_hit_node)
|
346
459
|
|
@@ -369,16 +482,22 @@ class HiRadixCache(RadixCache):
|
|
369
482
|
last_node,
|
370
483
|
)
|
371
484
|
|
372
|
-
def ready_to_load_host_cache(self):
|
373
|
-
|
374
|
-
|
375
|
-
|
485
|
+
def ready_to_load_host_cache(self) -> int:
|
486
|
+
"""
|
487
|
+
Notify the cache controller to start the KV cache loading.
|
488
|
+
Return the consumer index for the schedule batch manager to track.
|
489
|
+
"""
|
490
|
+
return self.cache_controller.start_loading()
|
376
491
|
|
377
492
|
def check_hicache_events(self):
|
378
493
|
self.writing_check()
|
379
494
|
self.loading_check()
|
380
495
|
if self.enable_storage:
|
381
496
|
self.drain_storage_control_queues()
|
497
|
+
if self.enable_storage_metrics:
|
498
|
+
self.metrics_collector.log_storage_metrics(
|
499
|
+
self.cache_controller.storage_backend.get_stats()
|
500
|
+
)
|
382
501
|
|
383
502
|
def drain_storage_control_queues(self):
|
384
503
|
"""
|
@@ -414,10 +533,13 @@ class HiRadixCache(RadixCache):
|
|
414
533
|
|
415
534
|
# process backup acks
|
416
535
|
for _ in range(n_backup):
|
417
|
-
|
536
|
+
operation = cc.ack_backup_queue.get()
|
537
|
+
ack_id = operation.id
|
418
538
|
entry = self.ongoing_backup.pop(ack_id, None)
|
419
539
|
if entry is not None:
|
420
540
|
entry.release_host()
|
541
|
+
if self.enable_storage_metrics:
|
542
|
+
self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
|
421
543
|
|
422
544
|
# release host memory
|
423
545
|
host_indices_list = []
|
@@ -427,6 +549,15 @@ class HiRadixCache(RadixCache):
|
|
427
549
|
host_indices = torch.cat(host_indices_list, dim=0)
|
428
550
|
cc.mem_pool_host.free(host_indices)
|
429
551
|
|
552
|
+
# Timeout is linearly increasing with the number of pages
|
553
|
+
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
|
554
|
+
# If hash_value has not been computed in timeout_base seconds, terminate it.
|
555
|
+
return (
|
556
|
+
time.monotonic() - operation.start_time
|
557
|
+
> self.prefetch_timeout_base
|
558
|
+
+ len(operation.hash_value) * self.prefetch_timeout_per_page
|
559
|
+
)
|
560
|
+
|
430
561
|
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
431
562
|
can_terminate = True
|
432
563
|
|
@@ -443,22 +574,27 @@ class HiRadixCache(RadixCache):
|
|
443
574
|
if self.prefetch_stop_policy == "wait_complete":
|
444
575
|
can_terminate = completed
|
445
576
|
elif self.prefetch_stop_policy == "timeout":
|
446
|
-
can_terminate = completed or (
|
447
|
-
time.monotonic() - operation.start_time > self.prefetch_timeout
|
448
|
-
)
|
577
|
+
can_terminate = completed or self.is_prefetch_timeout(operation)
|
449
578
|
else:
|
450
579
|
# unknown prefetch stop policy, just return True
|
451
580
|
return True
|
452
581
|
|
582
|
+
operation_terminated = operation.is_terminated()
|
453
583
|
if self.tp_world_size > 1:
|
454
|
-
|
584
|
+
states = torch.tensor(
|
585
|
+
[1 - int(can_terminate), int(operation_terminated)],
|
586
|
+
dtype=torch.int,
|
587
|
+
)
|
455
588
|
torch.distributed.all_reduce(
|
456
|
-
|
457
|
-
op=torch.distributed.ReduceOp.
|
589
|
+
states,
|
590
|
+
op=torch.distributed.ReduceOp.MAX,
|
458
591
|
group=self.tp_group,
|
459
592
|
)
|
460
|
-
can_terminate =
|
461
|
-
|
593
|
+
can_terminate = states[0].item() == 0
|
594
|
+
operation_terminated = states[1].item() == 1
|
595
|
+
# the operation should be terminated if it is already terminated on any TP worker
|
596
|
+
# or it meets the termination condition on all TP workers
|
597
|
+
can_terminate = can_terminate or operation_terminated
|
462
598
|
return can_terminate
|
463
599
|
|
464
600
|
def check_prefetch_progress(self, req_id: str) -> bool:
|
@@ -468,9 +604,9 @@ class HiRadixCache(RadixCache):
|
|
468
604
|
|
469
605
|
# todo: more policies for prefetch progress such as timeout
|
470
606
|
# the current policy is to prefetch with best effort and terminate when queuing is over
|
471
|
-
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch
|
607
|
+
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
|
472
608
|
req_id
|
473
|
-
|
609
|
+
]
|
474
610
|
|
475
611
|
if operation.host_indices is None:
|
476
612
|
# prefetch has not been issued due to insufficient host memory
|
@@ -485,7 +621,7 @@ class HiRadixCache(RadixCache):
|
|
485
621
|
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
486
622
|
|
487
623
|
min_completed_tokens = completed_tokens
|
488
|
-
if self.tp_world_size > 1
|
624
|
+
if self.tp_world_size > 1:
|
489
625
|
# synchrnoize TP workers to make the same update to hiradix cache
|
490
626
|
completed_tokens_tensor = torch.tensor(
|
491
627
|
min_completed_tokens, dtype=torch.int
|
@@ -500,24 +636,31 @@ class HiRadixCache(RadixCache):
|
|
500
636
|
written_indices = host_indices[:min_completed_tokens]
|
501
637
|
matched_length = self._insert_helper_host(
|
502
638
|
last_host_node,
|
503
|
-
|
639
|
+
RadixKey(
|
640
|
+
token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
|
641
|
+
),
|
504
642
|
written_indices,
|
505
643
|
hash_value[: min_completed_tokens // self.page_size],
|
506
644
|
)
|
507
|
-
if len(written_indices):
|
508
|
-
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
|
509
645
|
|
510
646
|
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
511
647
|
self.cache_controller.append_host_mem_release(
|
512
648
|
host_indices[min_completed_tokens:completed_tokens]
|
513
649
|
)
|
514
650
|
last_host_node.release_host()
|
651
|
+
del self.ongoing_prefetch[req_id]
|
515
652
|
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
516
653
|
|
654
|
+
if self.enable_storage_metrics:
|
655
|
+
self.metrics_collector.log_prefetched_tokens(
|
656
|
+
min_completed_tokens - matched_length
|
657
|
+
)
|
658
|
+
|
517
659
|
return True
|
518
660
|
|
519
|
-
def match_prefix(self, key:
|
661
|
+
def match_prefix(self, key: RadixKey, **kwargs):
|
520
662
|
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
663
|
+
key.token_ids = self.key_convert_fn(key.token_ids)
|
521
664
|
if self.disable or len(key) == 0:
|
522
665
|
return MatchResult(
|
523
666
|
device_indices=empty_value,
|
@@ -590,7 +733,9 @@ class HiRadixCache(RadixCache):
|
|
590
733
|
)
|
591
734
|
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
|
592
735
|
|
593
|
-
def _insert_helper_host(
|
736
|
+
def _insert_helper_host(
|
737
|
+
self, node: TreeNode, key: RadixKey, host_value, hash_value
|
738
|
+
):
|
594
739
|
node.last_access_time = time.monotonic()
|
595
740
|
if len(key) == 0:
|
596
741
|
return 0
|
@@ -624,7 +769,7 @@ class HiRadixCache(RadixCache):
|
|
624
769
|
node.children[child_key] = new_node
|
625
770
|
return matched_length
|
626
771
|
|
627
|
-
def _match_prefix_helper(self, node: TreeNode, key:
|
772
|
+
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
|
628
773
|
node.last_access_time = time.monotonic()
|
629
774
|
child_key = self.get_child_key_fn(key)
|
630
775
|
value = []
|
@@ -650,14 +795,13 @@ class HiRadixCache(RadixCache):
|
|
650
795
|
|
651
796
|
return value, node
|
652
797
|
|
653
|
-
def _split_node(self, key, child: TreeNode, split_len: int):
|
798
|
+
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
|
654
799
|
# child node split into new_node -> child
|
655
800
|
new_node = TreeNode()
|
656
801
|
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
657
802
|
new_node.parent = child.parent
|
658
803
|
new_node.lock_ref = child.lock_ref
|
659
804
|
new_node.key = child.key[:split_len]
|
660
|
-
new_node.loading = child.loading
|
661
805
|
new_node.hit_count = child.hit_count
|
662
806
|
|
663
807
|
# split value and host value if exists
|
@@ -678,10 +822,16 @@ class HiRadixCache(RadixCache):
|
|
678
822
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
679
823
|
return new_node
|
680
824
|
|
681
|
-
def insert(self, key:
|
825
|
+
def insert(self, key: RadixKey, value=None, chunked=False):
|
826
|
+
key.token_ids = self.key_convert_fn(key.token_ids)
|
827
|
+
|
682
828
|
if len(key) == 0:
|
683
829
|
return 0
|
684
830
|
|
831
|
+
if self.is_eagle and value is not None:
|
832
|
+
# Make sure the value len equal to the EAGLE bigram key len
|
833
|
+
value = value[: len(key)]
|
834
|
+
|
685
835
|
node = self.root_node
|
686
836
|
child_key = self.get_child_key_fn(key)
|
687
837
|
total_prefix_length = 0
|
@@ -696,7 +846,6 @@ class HiRadixCache(RadixCache):
|
|
696
846
|
# change the reference if the node is evicted
|
697
847
|
# this often happens in the case of KV cache recomputation
|
698
848
|
node.value = value[:prefix_len]
|
699
|
-
self.token_to_kv_pool_host.update_synced(node.host_value)
|
700
849
|
self.evictable_size_ += len(node.value)
|
701
850
|
else:
|
702
851
|
self._inc_hit_count(node, chunked)
|
@@ -706,7 +855,6 @@ class HiRadixCache(RadixCache):
|
|
706
855
|
new_node = self._split_node(node.key, node, prefix_len)
|
707
856
|
if new_node.evicted:
|
708
857
|
new_node.value = value[:prefix_len]
|
709
|
-
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
710
858
|
self.evictable_size_ += len(new_node.value)
|
711
859
|
else:
|
712
860
|
self._inc_hit_count(new_node, chunked)
|
@@ -736,7 +884,7 @@ class HiRadixCache(RadixCache):
|
|
736
884
|
for idx in range(0, len(key), self.page_size):
|
737
885
|
new_node.hash_value.append(
|
738
886
|
self.cache_controller.get_hash_str(
|
739
|
-
key[idx : idx + self.page_size],
|
887
|
+
key.token_ids[idx : idx + self.page_size],
|
740
888
|
prior_hash=last_hash,
|
741
889
|
)
|
742
890
|
)
|
@@ -775,9 +923,7 @@ class HiRadixCache(RadixCache):
|
|
775
923
|
if rid not in self.ongoing_prefetch:
|
776
924
|
return
|
777
925
|
|
778
|
-
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch
|
779
|
-
rid
|
780
|
-
)
|
926
|
+
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
|
781
927
|
if operation.host_indices is None:
|
782
928
|
return
|
783
929
|
|
@@ -785,5 +931,6 @@ class HiRadixCache(RadixCache):
|
|
785
931
|
if self.tp_world_size > 1:
|
786
932
|
torch.distributed.barrier(group=self.tp_group)
|
787
933
|
last_host_node.release_host()
|
934
|
+
del self.ongoing_prefetch[rid]
|
788
935
|
self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
|
789
936
|
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|