sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
--- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py (0.5.2rc1)
+++ sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py (0.5.3)
@@ -1,4 +1,3 @@
-import hashlib
 import json
 import logging
 import os
@@ -6,15 +5,18 @@ import uuid
 from dataclasses import dataclass
 from typing import Any, List, Optional
 
-import numpy as np
 import torch
 
-from sglang.srt.
-
+from sglang.srt.mem_cache.hicache_storage import (
+    HiCacheStorage,
+    HiCacheStorageConfig,
+    HiCacheStorageExtraInfo,
+)
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
-
+DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "SGLANG_HICACHE_MOONCAKE_CONFIG_PATH"
 logger = logging.getLogger(__name__)
 
 
@@ -31,13 +33,13 @@ class MooncakeStoreConfig:
     @staticmethod
     def from_file() -> "MooncakeStoreConfig":
         """Load the config from a JSON file."""
-        file_path = os.getenv(
-
-
-
-
-
-
+        file_path = os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV)
+        try:
+            with open(file_path) as fin:
+                config = json.load(fin)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load config from {file_path}: {str(e)}")
+
         return MooncakeStoreConfig(
             local_hostname=config.get("local_hostname"),
             metadata_server=config.get("metadata_server"),
@@ -75,6 +77,26 @@ class MooncakeStoreConfig:
             master_server_address=os.getenv("MOONCAKE_MASTER"),
         )
 
+    @staticmethod
+    def load_from_extra_config(extra_config: dict) -> "MooncakeStoreConfig":
+        """Load config from extra_config dictionary."""
+        if "master_server_address" not in extra_config:
+            raise ValueError("master_server_address is required in extra_config")
+
+        return MooncakeStoreConfig(
+            local_hostname=extra_config.get("local_hostname", "localhost"),
+            metadata_server=extra_config.get("metadata_server", "P2PHANDSHAKE"),
+            global_segment_size=extra_config.get(
+                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
+            ),
+            local_buffer_size=extra_config.get(
+                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
+            ),
+            protocol=extra_config.get("protocol", "tcp"),
+            device_name=extra_config.get("device_name", "auto"),
+            master_server_address=extra_config["master_server_address"],
+        )
+
     def __post_init__(self):
         if self.device_name == "auto":
             os.environ["MC_MS_AUTO_DISC"] = "1"
@@ -84,6 +106,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
+
     def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
@@ -96,14 +119,43 @@ class MooncakeStore(HiCacheStorage):
 
         try:
             self.store = MooncakeDistributedStore()
-
-
+
+            extra_config = (
+                getattr(storage_config, "extra_config", None)
+                if storage_config
+                else None
+            )
+            # Load configuration with master_server_address prioritized from extra_config if available
+            if (
+                extra_config is not None
+                and extra_config.get("master_server_address") is not None
+            ):
+                # Load from extra_config
+                self.config = MooncakeStoreConfig.load_from_extra_config(extra_config)
+                logger.info(
+                    "Mooncake Configuration loaded from extra_config successfully."
+                )
+            elif os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV):
+                # Load from config file
+                self.config = MooncakeStoreConfig.from_file()
+                logger.info("Mooncake Configuration loaded from file successfully.")
+            else:
+                # Load from environment variables
+                self.config = MooncakeStoreConfig.load_from_env()
+                logger.info("Mooncake Configuration loaded from env successfully.")
+
+            tp_scale_factor = 1 if storage_config is None else storage_config.tp_size
+
+            per_tp_global_segment_size = (
+                self.config.global_segment_size // tp_scale_factor
+            )
+            per_tp_local_buffer_size = self.config.local_buffer_size // tp_scale_factor
 
             ret_code = self.store.setup(
                 self.config.local_hostname,
                 self.config.metadata_server,
-
-
+                per_tp_global_segment_size,
+                per_tp_local_buffer_size,
                 self.config.protocol,
                 self.config.device_name,
                 self.config.master_server_address,
@@ -136,7 +188,13 @@ class MooncakeStore(HiCacheStorage):
         assert self.store.is_exist(warmup_key) == 1
         assert self.store.get(warmup_key) == warmup_value
 
-    def
+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        super().register_mem_pool_host(mem_pool_host)
+        assert self.mem_pool_host.layout in [
+            "page_first",
+            "page_first_direct",
+        ], "mooncake store storage backend only support page first or page first direct layout"
+        buffer = self.mem_pool_host.kv_buffer
         try:
             buffer_ptr = buffer.data_ptr()
             buffer_size = buffer.numel() * buffer.element_size()
@@ -147,6 +205,97 @@ class MooncakeStore(HiCacheStorage):
             logger.error("Failed to register buffer to Mooncake Store: %s", err)
             raise TypeError("Mooncake Store Register Buffer Error.") from err
 
+    def _get_mha_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_{self.local_rank}_k")
+            key_list.append(f"{key_}_{self.local_rank}_v")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _get_mla_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_k")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _batch_preprocess(self, keys, host_indices):
+        assert len(keys) > 0
+        assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
+        if self.is_mla_backend:
+            return self._get_mla_buffer_meta(keys, host_indices)
+        else:
+            return self._get_mha_buffer_meta(keys, host_indices)
+
+    def _batch_postprocess(self, results: List[int], is_set_operate=False):
+        """
+        refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
+        for batch_get_into, results is Vector of integers,
+        where each element is the number of bytes read on success, or a negative value on error
+        for batch_put_from, results is Vector of integers,
+        where each element is 0 on success, or a negative value on error
+        """
+        if self.is_mla_backend:
+            return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
+        else:
+            kv_pairs = zip(results[::2], results[1::2])
+            return [
+                (
+                    (k_res == 0 and v_res == 0)
+                    if is_set_operate
+                    else (k_res > 0 and v_res > 0)
+                )
+                for k_res, v_res in kv_pairs
+            ]
+
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        get_results = self._get_batch_zero_copy_impl(
+            key_strs, buffer_ptrs, buffer_sizes
+        )
+        return self._batch_postprocess(get_results, is_set_operate=False)
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        exist_result = self._batch_exist(key_strs)
+
+        set_keys = []
+        set_buffer_ptrs = []
+        set_buffer_sizes = []
+        set_indices = []
+        set_results = [-1] * len(key_strs)
+        for i in range(len(key_strs)):
+            if exist_result[i] != 1:
+                set_keys.append(key_strs[i])
+                set_buffer_ptrs.append(buffer_ptrs[i])
+                set_buffer_sizes.append(buffer_sizes[i])
+                set_indices.append(i)
+            else:
+                set_results[i] = 0
+
+        # Only set non-existing keys to storage
+        if len(set_keys) > 0:
+            put_results = self._put_batch_zero_copy_impl(
+                set_keys, set_buffer_ptrs, set_buffer_sizes
+            )
+            for i in range(len(set_indices)):
+                set_results[set_indices[i]] = put_results[i]
+
+        return self._batch_postprocess(set_results, is_set_operate=True)
+
     def set(
         self,
         key,
@@ -154,21 +303,36 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-
+        # Only support zero copy set for now
+        assert target_location is not None and target_sizes is not None
+        exist_result = self._batch_exist([key])
+        if exist_result[0] == 1:
+            return True
+        put_result = self._put_batch_zero_copy_impl(
+            [key], [target_location], [target_sizes]
+        )
+        return put_result[0] == 0
 
     def batch_set(
         self,
         keys: List[str],
         values: Optional[List[torch.Tensor]] = None,
-
+        target_locations: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-
+        # Only support zero copy set for now
+        assert target_locations is not None and target_sizes is not None
+        assert len(keys) == len(target_locations) == len(target_sizes)
+
         if len(keys) == 0:
             return False
 
         for i in range(len(keys)):
-            if
+            if (
+                keys[i] is None
+                or target_locations[i] is None
+                or target_sizes[i] is None
+            ):
                 return False
 
         exist_result = self._batch_exist(keys)
@@ -179,7 +343,7 @@ class MooncakeStore(HiCacheStorage):
         for i in range(len(keys)):
             if exist_result[i] != 1:
                 set_keys.append(keys[i])
-                set_target_locations.append(
+                set_target_locations.append(target_locations[i])
                 set_target_sizes.append(target_sizes[i])
                 set_indices.append(i)
         # Only set non-existing keys to storage
@@ -204,18 +368,24 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
-
+        assert target_location is not None and target_sizes is not None
+        get_result = self._get_batch_zero_copy_impl(
+            [key], [target_location], [target_sizes]
+        )
+        return get_result[0] >= 0
 
     def batch_get(
        self,
        keys: List[str],
-
+        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> int:
-        assert len(keys) == len(
+        assert len(keys) == len(target_locations) == len(target_sizes)
        if len(keys) == 0:
            return 0
-        get_result = self._get_batch_zero_copy_impl(
+        get_result = self._get_batch_zero_copy_impl(
+            keys, target_locations, target_sizes
+        )
        if self.is_mla_backend:
            key_multiplier = 1
        else:
@@ -226,7 +396,8 @@ class MooncakeStore(HiCacheStorage):
            return len(keys) // key_multiplier
 
    def exists(self, key) -> bool:
-
+        exist_result = self._batch_exist([key])
+        return exist_result[0] == 1
 
    def batch_exists(self, keys) -> int:
        if self.is_mla_backend:
@@ -245,9 +416,6 @@ class MooncakeStore(HiCacheStorage):
                return i // key_multiplier
        return len(query_keys) // key_multiplier
 
-    def delete(self, key) -> None:
-        raise (NotImplementedError)
-
    def close(self):
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
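The constructor changes above give `MooncakeStore` a three-way configuration precedence: an `extra_config` dict carried by the storage config (when it supplies `master_server_address`), then a JSON file pointed to by `SGLANG_HICACHE_MOONCAKE_CONFIG_PATH`, then plain environment variables; the configured segment and buffer sizes are then divided by the tensor-parallel size before `store.setup()`. Below is a minimal sketch of exercising the new `load_from_extra_config` helper on its own; the address and protocol values are placeholders, not taken from any tested deployment.

```python
# Hypothetical usage sketch (not from the package) of the config loader added
# in this release. Only "master_server_address" is required; every other key
# falls back to the defaults visible in load_from_extra_config above.
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
    MooncakeStoreConfig,
)

extra_config = {
    "master_server_address": "10.0.0.1:50051",  # placeholder, required
    "protocol": "rdma",  # overrides the "tcp" default
}

config = MooncakeStoreConfig.load_from_extra_config(extra_config)
assert config.metadata_server == "P2PHANDSHAKE"  # default when unspecified
assert config.global_segment_size == 4 * 1024 * 1024 * 1024  # 4 GiB default
```

Keys omitted from the dict fall back to the defaults shown in the diff (localhost hostname, "P2PHANDSHAKE" metadata server, "tcp" protocol, "auto" device discovery, and the 4 GiB / 16 MB segment and buffer sizes).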
--- /dev/null
+++ sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py (0.5.3)
@@ -0,0 +1,161 @@
+import logging
+import uuid
+
+import torch
+from mooncake_store import MooncakeStore
+
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def generate_batch_query_keys(kv_num: int, config: HiCacheStorageConfig):
+    keys = []
+    for _ in range(kv_num):
+        key = "test_" + str(uuid.uuid4())
+        keys.append(key)
+    set_keys = []
+    for key in keys:
+        if config.is_mla_model:
+            set_keys.append(key + "_k")
+        else:
+            set_keys.append(key + f"_{config.tp_rank}_k")
+            set_keys.append(key + f"_{config.tp_rank}_v")
+    get_keys = set_keys
+    exist_keys = keys
+    return set_keys, get_keys, exist_keys
+
+
+def test_single_operation():
+    """Test the set API with a single key-value pair."""
+    print("=" * 100)
+    print("Testing single operation")
+
+    buffer_size = 1024 * 1024 * 16  # 16MB
+    value_elements = 1024
+    store = MooncakeStore()
+    buffer = torch.randn(buffer_size, dtype=torch.float32)
+    store.register_buffer(buffer)
+    value_size = value_elements * buffer.element_size()
+
+    key = str(uuid.uuid4())
+    set_slice = buffer[:value_elements]
+    get_slice = buffer[value_elements : 2 * value_elements]
+    set_location = set_slice.data_ptr()
+    get_location = get_slice.data_ptr()
+
+    # Test set operation
+    result = store.set(key, target_location=set_location, target_sizes=value_size)
+    assert result is True, f"❌set operation failed for key: {key}"
+
+    # Test exists operation
+    assert store.exists(key), f"❌key {key} should exist after set operation"
+
+    # Test get operation
+    result = store.get(key, target_location=get_location, target_sizes=value_size)
+    assert result is True, f"❌get operation failed for key: {key}"
+
+    # Compare the data using proper tensor indices
+    assert torch.allclose(
+        set_slice, get_slice, atol=1e-6
+    ), f"❌get operation failed for key: {key}"
+
+    logger.info(f"✅ Single operation passed")
+
+
+def test_batch_operation(config: HiCacheStorageConfig):
+    """Test the batch set/get APIs with multiple key-value pairs."""
+    print("=" * 100)
+    print(f"Testing batch operation with config: {config}")
+
+    buffer_size = 1024 * 1024 * 16  # 16MB
+    value_elements = 256
+    kv_num = 13
+    store = MooncakeStore(config)
+    buffer = torch.randn(buffer_size, dtype=torch.float32)
+    store.register_buffer(buffer)
+    value_size = value_elements * buffer.element_size()
+
+    set_keys, get_keys, exist_keys = generate_batch_query_keys(kv_num, config)
+    set_slices = [
+        buffer[i * value_elements : (i + 1) * value_elements]
+        for i in range(len(set_keys))
+    ]
+    set_locations = [set_slice.data_ptr() for set_slice in set_slices]
+    target_sizes = [value_size for _ in range(len(set_keys))]
+
+    # Test batch set operation
+    result = store.batch_set(
+        set_keys, target_locations=set_locations, target_sizes=target_sizes
+    )
+    assert result is True, f"❌batch set operation failed"
+
+    # Test batch exists operation
+    assert store.batch_exists(
+        exist_keys
+    ), f"❌keys should exist after batch set operation"
+
+    # Test batch get operation
+    get_slices = [
+        buffer[
+            (len(set_keys) + i)
+            * value_elements : (len(set_keys) + i + 1)
+            * value_elements
+        ]
+        for i in range(len(get_keys))
+    ]
+    get_locations = [get_slice.data_ptr() for get_slice in get_slices]
+    result = store.batch_get(
+        get_keys, target_locations=get_locations, target_sizes=target_sizes
+    )
+    assert result == kv_num, f"❌batch get operation failed"
+    for i in range(len(get_keys)):
+        assert torch.allclose(
+            set_slices[i], get_slices[i], atol=1e-6
+        ), f"❌batch get operation failed for key: {get_keys[i]}"
+
+    logger.info(f"✅ Batch operation passed")
+
+
+if __name__ == "__main__":
+    test_single_operation()
+    test_batch_operation(
+        HiCacheStorageConfig(
+            is_mla_model=False,
+            tp_rank=0,
+            tp_size=1,
+            model_name=None,
+            is_page_first_layout=True,
+        )
+    )
+    test_batch_operation(
+        HiCacheStorageConfig(
+            is_mla_model=True,
+            tp_rank=0,
+            tp_size=1,
+            model_name=None,
+            is_page_first_layout=True,
+        )
+    )
+    test_batch_operation(
+        HiCacheStorageConfig(
+            is_mla_model=False,
+            tp_rank=1,
+            tp_size=4,
+            model_name=None,
+            is_page_first_layout=True,
+        )
+    )
+    test_batch_operation(
+        HiCacheStorageConfig(
+            is_mla_model=True,
+            tp_rank=3,
+            tp_size=8,
+            model_name=None,
+            is_page_first_layout=True,
+        )
+    )
+    logger.info(f"✅ All tests passed")
--- sglang/srt/mem_cache/swa_radix_cache.py (0.5.2rc1)
+++ sglang/srt/mem_cache/swa_radix_cache.py (0.5.3)
@@ -30,6 +30,12 @@ import torch
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import (
+    RadixKey,
+    _key_match_page_size1,
+    _key_match_paged,
+    get_child_key,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -47,7 +53,7 @@ class TreeNode:
     def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
         self.parent: TreeNode = None
-        self.key:
+        self.key: RadixKey = None
         self.value: Optional[torch.Tensor] = None
         # swa_tombstone is used to indicate the kv indices have been freed for swa layers
         self.swa_tombstone = False
@@ -60,8 +66,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()
 
         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # store the host indices of KV cache
         self.host_value = None
 
@@ -89,27 +93,6 @@ class TreeNode:
         return self.last_access_time < other.last_access_time
 
 
-def _key_match_page_size1(key0: List, key1: List):
-    i = 0
-    for k0, k1 in zip(key0, key1):
-        if k0 != k1:
-            break
-        i += 1
-    return i
-
-
-def _key_match_paged(key0: List, key1: List, page_size: int):
-    min_len = min(len(key0), len(key1))
-
-    i = 0
-    while i < min_len:
-        if key0[i : i + page_size] != key1[i : i + page_size]:
-            break
-        i += page_size
-
-    return i
-
-
 def gen_swa_uuid() -> int:
     TreeNode.swa_uuid_counter += 1
     return TreeNode.swa_uuid_counter
@@ -358,10 +341,10 @@ class SWARadixCache(BasePrefixCache):
 
         if self.page_size == 1:
             self.key_match_fn = _key_match_page_size1
-            self.get_child_key_fn =
+            self.get_child_key_fn = get_child_key
         else:
             self.key_match_fn = partial(_key_match_paged, page_size=page_size)
-            self.get_child_key_fn =
+            self.get_child_key_fn = partial(get_child_key, page_size=page_size)
 
         self.sliding_window_size = sliding_window_size
         self.reset()
@@ -382,10 +365,10 @@ class SWARadixCache(BasePrefixCache):
         self.full_lru_list = LRUList(swa=False)
         self.swa_lru_list = LRUList(swa=True)
 
-    def match_prefix(self, key:
+    def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
         """Find the matching prefix from the radix tree.
         Args:
-            key: A
+            key: A RadixKey contains token IDs to find a matching prefix.
         Returns:
             A tuple of a tensor of matching prefix token IDs and
             the last node that contains the prefix values. Note that
@@ -419,12 +402,12 @@ class SWARadixCache(BasePrefixCache):
             last_host_node=last_node,
         )
 
-    def insert(self, key:
+    def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
         if self.disable:
             return 0
 
         if value is None:
-            value = [x for x in key]
+            value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
         return self._insert_helper(self.root_node, key, value, prev_prefix_len)
 
     def cache_finished_req(self, req: Req) -> None:
@@ -455,7 +438,7 @@ class SWARadixCache(BasePrefixCache):
         # insert the token_ids and kv_indices into the radix tree
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
-            token_ids[:page_aligned_len],
+            RadixKey(token_ids[:page_aligned_len], req.extra_key),
             page_aligned_kv_indices,
             len(req.prefix_indices),
         )
@@ -491,11 +474,15 @@ class SWARadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
-            page_aligned_token_ids,
+            RadixKey(page_aligned_token_ids, req.extra_key),
+            page_aligned_kv_indices,
+            len(req.prefix_indices),
         )
 
         # The prefix indices could be updated, reuse it
-        new_indices, new_last_node, _, _ = self.match_prefix(
+        new_indices, new_last_node, _, _ = self.match_prefix(
+            RadixKey(page_aligned_token_ids, req.extra_key)
+        )
         assert len(req.prefix_indices) <= len(
             new_indices
         ), f"{req.prefix_indices=}, {new_indices=}"
@@ -734,7 +721,9 @@ class SWARadixCache(BasePrefixCache):
 
     ##### Internal Helper Functions #####
 
-    def _match_prefix_helper(
+    def _match_prefix_helper(
+        self, key: RadixKey
+    ) -> Tuple[List[torch.Tensor], TreeNode]:
         """
         SWA prefix matching helper. It factors in the sliding window size such that
         the matched node is guaranteed to either 1. connected to root without swa tombstone,
@@ -798,7 +787,7 @@ class SWARadixCache(BasePrefixCache):
 
         return value[:best_value_len], best_last_node
 
-    def _split_node(self, key:
+    def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {self.get_child_key_fn(key[split_len:]): child}
@@ -833,7 +822,7 @@ class SWARadixCache(BasePrefixCache):
         return new_node
 
     def _insert_helper(
-        self, node: TreeNode, key:
+        self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int
     ) -> int:
         # Update the last access time from root to leaf, so that
         # swa will tombstone the node closer to root first