sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to the public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py (new file)

@@ -0,0 +1,284 @@
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
+
+try:
+    from lmcache.integration.sglang.sglang_adapter import (
+        LMCacheLayerwiseConnector,
+        LoadMetadata,
+        StoreMetadata,
+    )
+except ImportError as e:
+    raise RuntimeError(
+        "LMCache is not installed. Please install it by running `pip install lmcache`"
+    ) from e
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.managers.schedule_batch import Req
+
+logger = logging.getLogger(__name__)
+
+
+class LayerTransferCounter:
+    """Minimal adapter that lets the memory pool notify LMCache per-layer.
+
+    The KV pool calls `wait_until(layer_id)` after finishing a layer, which we
+    translate into a `load_kv_layerwise(layer_id)` call on the LMCache connector
+    within the provided CUDA stream.
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        load_stream: torch.cuda.Stream,
+        lmc_connector: LMCacheLayerwiseConnector,
+        printable: bool = False,
+    ):
+        self.num_layers = num_layers
+        self.load_stream = load_stream
+        self.lmc_connector = lmc_connector
+
+    def wait_until(self, layer_id: int):
+        # Ensure ordering of the async loads wrt compute stream(s).
+        self.load_stream.synchronize()
+        with self.load_stream:
+            self.lmc_connector.load_kv_layerwise(layer_id)
+
+
+class LMCRadixCache(RadixCache):
+    """RadixCache + LMCache IO.
+
+    This subclass adds:
+    - LMCache connector setup (device/host buffers, TP rank/size)
+    - Two CUDA streams for async load/store
+    - Layer-wise transfer executor wiring to the KV cache
+    - Overridden `match_prefix` to fetch missing prefix chunks from LMCache
+    - Extended cache_finalization paths to store back into LMCache
+    - Eviction barrier that respects any in-flight host->device stores
+    """
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+        page_size: int,
+        disable: bool = False,
+        enable_kv_cache_events: bool = False,
+        model_config: Optional["ModelConfig"] = None,
+        tp_size: int = 1,
+        rank: int = 0,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
+        eviction_policy: str = "lru",
+    ):
+        super().__init__(
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
+            page_size=page_size,
+            disable=disable,
+            enable_kv_cache_events=enable_kv_cache_events,
+            eviction_policy=eviction_policy,
+        )
+
+        kvcache = self.token_to_kv_pool_allocator.get_kvcache()
+        self.lmcache_connector = LMCacheLayerwiseConnector(
+            sgl_config=model_config,
+            tp_size=tp_size,
+            rank=rank,
+            # NOTE: The original implementation accessed private buffers via
+            # `_kvcache.k_buffer` / `.v_buffer`. We prefer public accessors when
+            # available; fall back to private fields if needed.
+            k_pool=getattr(
+                kvcache,
+                "k_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
+            ),
+            v_pool=getattr(
+                kvcache,
+                "v_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"),
+            ),
+            tp_group=tp_group,
+        )
+
+        self.load_stream = torch.cuda.Stream()
+        self.store_stream = torch.cuda.Stream()
+
+        self.layer_done_executor = LayerTransferCounter(
+            num_layers=(
+                model_config.num_hidden_layers if model_config is not None else 0
+            ),
+            load_stream=self.load_stream,
+            lmc_connector=self.lmcache_connector,
+        )
+        kvcache.register_layer_transfer_counter(self.layer_done_executor)
+
+        self._in_flight_nodes: list[TreeNode] = []
+        self._node_lock = threading.Lock()
+
+    def reset(self):  # type: ignore[override]
+        super().reset()
+        if hasattr(self, "_in_flight_nodes"):
+            with self._node_lock:
+                self._in_flight_nodes.clear()
+
+    def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:  # type: ignore[override]
+        """Match cached prefix; if there's a tail miss, prefetch from LMCache.
+
+        Reuses the base matching logic to obtain (value, last_node). If there
+        remains a *page-aligned* uncached suffix and there is room (or after
+        eviction), we allocate token slots and trigger an async LMCache load
+        into those slots, then materialize a new child node for the retrieved
+        chunk.
+        """
+        if self.disable or not key:
+            return super().match_prefix(key, **kwargs)
+
+        if self.page_size != 1:
+            aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:aligned_len]
+
+        base_res = super().match_prefix(key, **kwargs)
+        value: torch.Tensor = base_res.device_indices
+        last_node: TreeNode = base_res.last_device_node
+
+        if value.numel() == len(key):
+            return base_res
+
+        uncached_len = len(key) - value.numel()
+        if uncached_len == 0:
+            return base_res
+
+        chunk_size = self.lmcache_connector.chunk_size()
+        prefix_pad = value.numel() % chunk_size
+
+        if self.token_to_kv_pool_allocator.available_size() < uncached_len:
+            self.evict(uncached_len)
+
+        token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
+        if token_slots is None:
+            return base_res
+
+        slot_mapping = torch.cat(
+            [
+                torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device),
+                token_slots.detach().clone().to(torch.int64).to(self.device),
+            ]
+        )
+
+        with torch.cuda.stream(self.load_stream):
+            num_retrieved = self.lmcache_connector.start_load_kv(
+                LoadMetadata(
+                    token_ids=key.token_ids,  # full page-aligned key
+                    slot_mapping=slot_mapping,
+                    offset=value.numel() - prefix_pad,  # LMCache offset convention
+                )
+            )
+        logger.debug("num_retrieved_tokens: %s", num_retrieved)
+
+        if num_retrieved > 0:
+            self.token_to_kv_pool_allocator.free(
+                token_slots[(num_retrieved - prefix_pad) :]
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(token_slots)
+
+        if num_retrieved > 0:
+            fetched = num_retrieved - prefix_pad
+            new_node = TreeNode()
+            start = value.numel()
+            end = start + fetched
+            new_node.key = key[start:end]
+            new_node.value = token_slots[:fetched]
+            new_node.parent = last_node
+            last_node.children[self.get_child_key_fn(new_node.key)] = new_node
+            last_node = new_node
+
+            value = torch.cat([value, token_slots[:fetched]])
+            self.evictable_size_ += fetched
+
+            self._record_store_event(new_node.parent)
+            self._record_store_event(new_node)
+
+            return MatchResult(
+                device_indices=value,
+                last_device_node=last_node,
+                last_host_node=last_node,
+            )
+
+        return base_res
+
+    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
+        """On request completion, insert device KV into radix and store to LMCache."""
+        super().cache_finished_req(req)
+
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        _, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key))
+        assert new_last_node is not None
+
+        self.inc_lock_ref(new_last_node)
+        store_md = StoreMetadata(
+            last_node=new_last_node,
+            token_ids=token_ids,
+            kv_indices=kv_indices,
+            offset=0,
+        )
+        with torch.cuda.stream(self.store_stream):
+            self.lmcache_connector.store_kv(store_md)
+        with self._node_lock:
+            self._in_flight_nodes.append(new_last_node)
+
+    def evict(self, num_tokens: int) -> None:  # type: ignore[override]
+        """Before base eviction, wait for any outstanding stores and release locks."""
+        if self.disable:
+            return
+
+        self.store_stream.synchronize()
+        with self._node_lock:
+            for node in self._in_flight_nodes:
+                self.dec_lock_ref(node)
+            self._in_flight_nodes.clear()
+
+        super().evict(num_tokens)
+
+    def pretty_print(self):  # type: ignore[override]
+        super().pretty_print()
+        try:
+            logger.debug(
+                "evictable=%d protected=%d", self.evictable_size_, self.protected_size_
+            )
+        except Exception:  # pragma: no cover
+            pass
+
+
+if __name__ == "__main__":
+    cache = LMCRadixCache(
+        req_to_token_pool=None,
+        token_to_kv_pool_allocator=None,
+        page_size=1,
+        disable=False,
+        enable_kv_cache_events=False,
+        model_config=None,
+        tp_size=1,
+        rank=0,
+        tp_group=None,
+    )
+    cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 11, 12], dtype=torch.int64))
+    cache.insert(
+        RadixKey([1, 2, 3, 4]), torch.tensor([10, 11, 12, 13], dtype=torch.int64)
+    )
+    cache.pretty_print()
sglang/srt/mem_cache/storage/lmcache/unit_test.py (new file)

@@ -0,0 +1,121 @@
+try:
+    from lmcache.integration.sglang.sglang_adapter import (
+        LMCacheLayerwiseConnector,
+        LoadMetadata,
+        StoreMetadata,
+    )
+except ImportError:
+    raise RuntimeError(
+        "LMCache is not installed. Please install it by running `pip install lmcache` in the root directory of LMCache"
+    )
+
+import os
+
+import torch
+
+from sglang.srt.configs.model_config import ModelConfig
+
+os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
+os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"
+
+
+def test_load_store_metadata():
+    model_config = ModelConfig(
+        model_path="Qwen/Qwen3-4B",
+    )
+
+    # Generate Dummy KV Cache
+    head_num = model_config.num_key_value_heads
+    head_dim = model_config.head_dim
+    layer_num = model_config.num_hidden_layers
+    buffer_size = 256
+    input_id_len = 16
+
+    k_buffer = [
+        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+    v_buffer = [
+        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+
+    connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)
+
+    fake_token_ids = torch.randint(0, model_config.vocab_size, (input_id_len,)).tolist()
+    fake_kv_indices = torch.randint(0, buffer_size, (input_id_len,))
+    offset = 0
+
+    store_metadata = StoreMetadata(
+        last_node=None,
+        token_ids=fake_token_ids,
+        kv_indices=fake_kv_indices,
+        offset=offset,
+    )
+
+    load_metadata = LoadMetadata(
+        token_ids=fake_token_ids,
+        slot_mapping=fake_kv_indices,
+        offset=offset,
+    )
+
+    current_stream = torch.cuda.current_stream()
+
+    retrieve_token_num = connector.start_load_kv(load_metadata)
+    assert retrieve_token_num == 0
+
+    connector.store_kv(store_metadata)
+    current_stream.synchronize()
+
+    # check retrieve
+    gt_key_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+    gt_value_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+
+    for i in range(layer_num):
+        gt_key_buffer[i] = k_buffer[i][fake_kv_indices]
+        gt_value_buffer[i] = v_buffer[i][fake_kv_indices]
+
+    # clear the k_buffer and v_buffer
+    for _ in range(layer_num):
+        k_buffer[i].zero_()
+        v_buffer[i].zero_()
+
+    retrieve_token_num = connector.start_load_kv(load_metadata)
+    assert retrieve_token_num == input_id_len
+
+    for i in range(layer_num):
+        current_stream.synchronize()
+        connector.load_kv_layerwise(i)
+
+    current_stream.synchronize()
+    test_key_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+    test_value_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+
+    for i in range(layer_num):
+        test_key_buffer[i] = k_buffer[i][fake_kv_indices]
+        test_value_buffer[i] = v_buffer[i][fake_kv_indices]
+
+    for i in range(layer_num):
+        assert torch.allclose(test_key_buffer[i], gt_key_buffer[i])
+        assert torch.allclose(test_value_buffer[i], gt_value_buffer[i])
+
+    print("================================================")
+    print("TEST_LOAD_STORE_METADATA PASSED!")
+    print("================================================")
+    connector.close()
+
+
+if __name__ == "__main__":
+    test_load_store_metadata()
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py

@@ -7,11 +7,16 @@ from typing import Any, List, Optional
 
 import torch
 
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
+from sglang.srt.mem_cache.hicache_storage import (
+    HiCacheStorage,
+    HiCacheStorageConfig,
+    HiCacheStorageExtraInfo,
+)
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
-
+DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "SGLANG_HICACHE_MOONCAKE_CONFIG_PATH"
 logger = logging.getLogger(__name__)
 
 
@@ -28,13 +33,13 @@ class MooncakeStoreConfig:
     @staticmethod
     def from_file() -> "MooncakeStoreConfig":
         """Load the config from a JSON file."""
-        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
-        if file_path is None:
-            raise ValueError(
-                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
-            )
-        with open(file_path) as fin:
-            config = json.load(fin)
+        file_path = os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV)
+        try:
+            with open(file_path) as fin:
+                config = json.load(fin)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load config from {file_path}: {str(e)}")
+
         return MooncakeStoreConfig(
             local_hostname=config.get("local_hostname"),
             metadata_server=config.get("metadata_server"),
@@ -72,6 +77,26 @@ class MooncakeStoreConfig:
             master_server_address=os.getenv("MOONCAKE_MASTER"),
         )
 
+    @staticmethod
+    def load_from_extra_config(extra_config: dict) -> "MooncakeStoreConfig":
+        """Load config from extra_config dictionary."""
+        if "master_server_address" not in extra_config:
+            raise ValueError("master_server_address is required in extra_config")
+
+        return MooncakeStoreConfig(
+            local_hostname=extra_config.get("local_hostname", "localhost"),
+            metadata_server=extra_config.get("metadata_server", "P2PHANDSHAKE"),
+            global_segment_size=extra_config.get(
+                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
+            ),
+            local_buffer_size=extra_config.get(
+                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
+            ),
+            protocol=extra_config.get("protocol", "tcp"),
+            device_name=extra_config.get("device_name", "auto"),
+            master_server_address=extra_config["master_server_address"],
+        )
+
     def __post_init__(self):
         if self.device_name == "auto":
             os.environ["MC_MS_AUTO_DISC"] = "1"
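
A short usage sketch of the new `load_from_extra_config` path (not part of the diff; values hypothetical, and it assumes the module is importable): only `master_server_address` is required, everything else falls back to the defaults shown above.

from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
    DEFAULT_GLOBAL_SEGMENT_SIZE,
    MooncakeStoreConfig,
)

cfg = MooncakeStoreConfig.load_from_extra_config(
    {"master_server_address": "mooncake-master:50051", "protocol": "rdma"}
)
assert cfg.metadata_server == "P2PHANDSHAKE"            # default handshake mode
assert cfg.global_segment_size == DEFAULT_GLOBAL_SEGMENT_SIZE

try:
    MooncakeStoreConfig.load_from_extra_config({})      # missing the required key
except ValueError as e:
    assert "master_server_address" in str(e)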
@@ -81,6 +106,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
+
     def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
@@ -93,14 +119,43 @@ class MooncakeStore(HiCacheStorage):
 
         try:
             self.store = MooncakeDistributedStore()
-            self.config = MooncakeStoreConfig.load_from_env()
-            logger.info("Mooncake Configuration loaded from env successfully.")
+
+            extra_config = (
+                getattr(storage_config, "extra_config", None)
+                if storage_config
+                else None
+            )
+            # Load configuration with master_server_address prioritized from extra_config if available
+            if (
+                extra_config is not None
+                and extra_config.get("master_server_address") is not None
+            ):
+                # Load from extra_config
+                self.config = MooncakeStoreConfig.load_from_extra_config(extra_config)
+                logger.info(
+                    "Mooncake Configuration loaded from extra_config successfully."
+                )
+            elif os.getenv(DEFAULT_MOONCAKE_CONFIG_PATH_ENV):
+                # Load from config file
+                self.config = MooncakeStoreConfig.from_file()
+                logger.info("Mooncake Configuration loaded from file successfully.")
+            else:
+                # Load from environment variables
+                self.config = MooncakeStoreConfig.load_from_env()
+                logger.info("Mooncake Configuration loaded from env successfully.")
+
+            tp_scale_factor = 1 if storage_config is None else storage_config.tp_size
+
+            per_tp_global_segment_size = (
+                self.config.global_segment_size // tp_scale_factor
+            )
+            per_tp_local_buffer_size = self.config.local_buffer_size // tp_scale_factor
 
             ret_code = self.store.setup(
                 self.config.local_hostname,
                 self.config.metadata_server,
-                self.config.global_segment_size,
-                self.config.local_buffer_size,
+                per_tp_global_segment_size,
+                per_tp_local_buffer_size,
                 self.config.protocol,
                 self.config.device_name,
                 self.config.master_server_address,
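
The per-TP sizing added above divides a single segment budget evenly across tensor-parallel ranks, so the total host memory registered with Mooncake stays constant as `tp_size` grows. A quick arithmetic check (not part of the diff), restating the module defaults locally:

DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB, as defined above
DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024          # 16 MB, as defined above

tp_size = 4
assert DEFAULT_GLOBAL_SEGMENT_SIZE // tp_size == 1024**3        # 1 GiB per rank
assert DEFAULT_LOCAL_BUFFER_SIZE // tp_size == 4 * 1024 * 1024  # 4 MB per rank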
@@ -133,7 +188,13 @@ class MooncakeStore(HiCacheStorage):
         assert self.store.is_exist(warmup_key) == 1
         assert self.store.get(warmup_key) == warmup_value
 
-    def register_buffer(self, buffer: torch.Tensor) -> None:
+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        super().register_mem_pool_host(mem_pool_host)
+        assert self.mem_pool_host.layout in [
+            "page_first",
+            "page_first_direct",
+        ], "mooncake store storage backend only support page first or page first direct layout"
+        buffer = self.mem_pool_host.kv_buffer
         try:
             buffer_ptr = buffer.data_ptr()
             buffer_size = buffer.numel() * buffer.element_size()
@@ -144,6 +205,97 @@ class MooncakeStore(HiCacheStorage):
             logger.error("Failed to register buffer to Mooncake Store: %s", err)
             raise TypeError("Mooncake Store Register Buffer Error.") from err
 
+    def _get_mha_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_{self.local_rank}_k")
+            key_list.append(f"{key_}_{self.local_rank}_v")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _get_mla_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_k")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _batch_preprocess(self, keys, host_indices):
+        assert len(keys) > 0
+        assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
+        if self.is_mla_backend:
+            return self._get_mla_buffer_meta(keys, host_indices)
+        else:
+            return self._get_mha_buffer_meta(keys, host_indices)
+
+    def _batch_postprocess(self, results: List[int], is_set_operate=False):
+        """
+        refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
+        for batch_get_into, results is Vector of integers,
+        where each element is the number of bytes read on success, or a negative value on error
+        for batch_put_from, results is Vector of integers,
+        where each element is 0 on success, or a negative value on error
+        """
+        if self.is_mla_backend:
+            return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
+        else:
+            kv_pairs = zip(results[::2], results[1::2])
+            return [
+                (
+                    (k_res == 0 and v_res == 0)
+                    if is_set_operate
+                    else (k_res > 0 and v_res > 0)
+                )
+                for k_res, v_res in kv_pairs
+            ]
+
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        get_results = self._get_batch_zero_copy_impl(
+            key_strs, buffer_ptrs, buffer_sizes
+        )
+        return self._batch_postprocess(get_results, is_set_operate=False)
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        exist_result = self._batch_exist(key_strs)
+
+        set_keys = []
+        set_buffer_ptrs = []
+        set_buffer_sizes = []
+        set_indices = []
+        set_results = [-1] * len(key_strs)
+        for i in range(len(key_strs)):
+            if exist_result[i] != 1:
+                set_keys.append(key_strs[i])
+                set_buffer_ptrs.append(buffer_ptrs[i])
+                set_buffer_sizes.append(buffer_sizes[i])
+                set_indices.append(i)
+            else:
+                set_results[i] = 0
+
+        # Only set non-existing keys to storage
+        if len(set_keys) > 0:
+            put_results = self._put_batch_zero_copy_impl(
+                set_keys, set_buffer_ptrs, set_buffer_sizes
+            )
+            for i in range(len(set_indices)):
+                set_results[set_indices[i]] = put_results[i]
+
+        return self._batch_postprocess(set_results, is_set_operate=True)
+
     def set(
         self,
         key,
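
A worked example of the `_batch_postprocess` pairing above (not part of the diff; values hypothetical). In the MHA layout each logical key expands to `<key>_<rank>_k` and `<key>_<rank>_v` sub-keys, so raw results arrive interleaved and are re-paired via `results[::2]` / `results[1::2]`:

# get path: bytes read per sub-key, ordered k0, v0, k1, v1; negative = error
get_results = [16384, 16384, 16384, -1]
kv_pairs = list(zip(get_results[::2], get_results[1::2]))
assert kv_pairs == [(16384, 16384), (16384, -1)]
assert [(k > 0 and v > 0) for k, v in kv_pairs] == [True, False]

# set path: 0 per sub-key means success
put_results = [0, 0, 0, -1]
assert [
    (k == 0 and v == 0) for k, v in zip(put_results[::2], put_results[1::2])
] == [True, False]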
@@ -264,9 +416,6 @@ class MooncakeStore(HiCacheStorage):
             return i // key_multiplier
         return len(query_keys) // key_multiplier
 
-    def delete(self, key) -> None:
-        raise (NotImplementedError)
-
     def close(self):
         # MooncakeDistributedStore will automatically call the destructor, so
         # it is unnecessary to close it manually.