sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,27 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import asyncio
|
4
3
|
import concurrent.futures
|
5
4
|
import ctypes
|
6
5
|
import dataclasses
|
7
6
|
import logging
|
8
7
|
import os
|
9
|
-
import queue
|
10
|
-
import socket
|
11
8
|
import struct
|
12
9
|
import threading
|
13
10
|
import time
|
14
11
|
from collections import defaultdict
|
15
|
-
from
|
16
|
-
from typing import Dict, List, Optional, Tuple, Union
|
12
|
+
from typing import Dict, List, Optional, Tuple
|
17
13
|
|
18
14
|
import numpy as np
|
19
15
|
import numpy.typing as npt
|
20
16
|
import requests
|
21
17
|
import zmq
|
22
|
-
|
23
|
-
|
24
|
-
from sglang.srt.disaggregation.
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
KVArgs,
|
30
|
-
KVPoll,
|
18
|
+
|
19
|
+
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
|
20
|
+
from sglang.srt.disaggregation.common.conn import (
|
21
|
+
CommonKVBootstrapServer,
|
22
|
+
CommonKVManager,
|
23
|
+
CommonKVReceiver,
|
24
|
+
CommonKVSender,
|
31
25
|
)
|
32
26
|
from sglang.srt.disaggregation.common.utils import (
|
33
27
|
FastQueue,
|
@@ -35,23 +29,12 @@ from sglang.srt.disaggregation.common.utils import (
|
|
35
29
|
)
|
36
30
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
37
31
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
38
|
-
from sglang.srt.distributed import get_pp_group
|
39
|
-
from sglang.srt.layers.dp_attention import (
|
40
|
-
get_attention_dp_rank,
|
41
|
-
get_attention_dp_size,
|
42
|
-
get_attention_tp_rank,
|
43
|
-
get_attention_tp_size,
|
44
|
-
)
|
45
32
|
from sglang.srt.server_args import ServerArgs
|
46
33
|
from sglang.srt.utils import (
|
47
34
|
format_tcp_address,
|
48
35
|
get_bool_env_var,
|
49
|
-
get_free_port,
|
50
36
|
get_int_env_var,
|
51
|
-
get_ip,
|
52
|
-
get_local_ip_auto,
|
53
37
|
is_valid_ipv6_address,
|
54
|
-
maybe_wrap_ipv6_address,
|
55
38
|
)
|
56
39
|
|
57
40
|
logger = logging.getLogger(__name__)
|
@@ -159,7 +142,7 @@ class AuxDataCodec:
|
|
159
142
|
return
|
160
143
|
|
161
144
|
|
162
|
-
class MooncakeKVManager(
|
145
|
+
class MooncakeKVManager(CommonKVManager):
|
163
146
|
AUX_DATA_HEADER = b"AUX_DATA"
|
164
147
|
|
165
148
|
def __init__(
|
@@ -169,42 +152,14 @@ class MooncakeKVManager(BaseKVManager):
|
|
169
152
|
server_args: ServerArgs,
|
170
153
|
is_mla_backend: Optional[bool] = False,
|
171
154
|
):
|
172
|
-
|
173
|
-
self.local_ip = get_local_ip_auto()
|
174
|
-
self.is_mla_backend = is_mla_backend
|
175
|
-
self.disaggregation_mode = disaggregation_mode
|
155
|
+
super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
|
176
156
|
self.init_engine()
|
177
|
-
# for p/d multi node infer
|
178
|
-
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
179
|
-
self.dist_init_addr = server_args.dist_init_addr
|
180
|
-
self.attn_tp_size = get_attention_tp_size()
|
181
|
-
self.attn_tp_rank = get_attention_tp_rank()
|
182
|
-
self.attn_dp_size = get_attention_dp_size()
|
183
|
-
self.attn_dp_rank = get_attention_dp_rank()
|
184
|
-
self.system_dp_size = (
|
185
|
-
1 if server_args.enable_dp_attention else server_args.dp_size
|
186
|
-
)
|
187
|
-
self.system_dp_rank = (
|
188
|
-
self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
|
189
|
-
)
|
190
|
-
self.pp_size = server_args.pp_size
|
191
|
-
self.pp_rank = self.kv_args.pp_rank
|
192
|
-
self.request_status: Dict[int, KVPoll] = {}
|
193
|
-
self.rank_port = None
|
194
|
-
self.server_socket = zmq.Context().socket(zmq.PULL)
|
195
|
-
if is_valid_ipv6_address(self.local_ip):
|
196
|
-
self.server_socket.setsockopt(zmq.IPV6, 1)
|
197
|
-
|
198
157
|
self.register_buffer_to_engine()
|
199
158
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
200
|
-
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
201
|
-
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
202
159
|
self.start_prefill_thread()
|
203
|
-
self._register_to_bootstrap()
|
204
160
|
self.session_failures = defaultdict(int)
|
205
161
|
self.failed_sessions = set()
|
206
162
|
self.session_lock = threading.Lock()
|
207
|
-
self.pp_group = get_pp_group()
|
208
163
|
# Determine the number of threads to use for kv sender
|
209
164
|
cpu_count = os.cpu_count()
|
210
165
|
transfer_thread_pool_size = get_int_env_var(
|
@@ -244,8 +199,6 @@ class MooncakeKVManager(BaseKVManager):
|
|
244
199
|
self.session_pool = defaultdict(requests.Session)
|
245
200
|
self.session_pool_lock = threading.Lock()
|
246
201
|
self.addr_to_rooms_tracker = defaultdict(set)
|
247
|
-
self.connection_lock = threading.Lock()
|
248
|
-
self.required_prefill_response_num_table: Dict[int, int] = {}
|
249
202
|
self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
|
250
203
|
# Heartbeat interval should be at least 2 seconds
|
251
204
|
self.heartbeat_interval = max(
|
@@ -256,20 +209,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
256
209
|
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
|
257
210
|
)
|
258
211
|
self.start_decode_thread()
|
259
|
-
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
260
|
-
self.prefill_attn_tp_size_table: Dict[str, int] = {}
|
261
|
-
self.prefill_dp_size_table: Dict[str, int] = {}
|
262
|
-
self.prefill_pp_size_table: Dict[str, int] = {}
|
263
212
|
# If a timeout happens on the decode side, it means decode instances
|
264
213
|
# fail to receive the KV Cache transfer done signal after bootstrapping.
|
265
214
|
# These timeout requests should be aborted to release the tree cache.
|
266
215
|
self.waiting_timeout = get_int_env_var(
|
267
216
|
"SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
|
268
217
|
)
|
269
|
-
else:
|
270
|
-
raise ValueError(
|
271
|
-
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
272
|
-
)
|
273
218
|
|
274
219
|
self.failure_records: Dict[int, str] = {}
|
275
220
|
self.failure_lock = threading.Lock()
|
@@ -294,14 +239,6 @@ class MooncakeKVManager(BaseKVManager):
|
|
294
239
|
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
295
240
|
)
|
296
241
|
|
297
|
-
@cache
|
298
|
-
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
299
|
-
socket = zmq.Context().socket(zmq.PUSH)
|
300
|
-
if is_ipv6:
|
301
|
-
socket.setsockopt(zmq.IPV6, 1)
|
302
|
-
socket.connect(endpoint)
|
303
|
-
return socket
|
304
|
-
|
305
242
|
def _transfer_data(self, mooncake_session_id, transfer_blocks):
|
306
243
|
if not transfer_blocks:
|
307
244
|
return 0
|
@@ -327,12 +264,10 @@ class MooncakeKVManager(BaseKVManager):
|
|
327
264
|
layers_params = None
|
328
265
|
|
329
266
|
# pp is not supported on the decode side yet
|
330
|
-
start_layer = self.kv_args.prefill_start_layer
|
331
|
-
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
|
332
267
|
if self.is_mla_backend:
|
333
|
-
src_kv_ptrs =
|
334
|
-
|
335
|
-
|
268
|
+
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
269
|
+
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
270
|
+
)
|
336
271
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
337
272
|
layers_params = [
|
338
273
|
(
|
@@ -340,18 +275,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
340
275
|
dst_kv_ptrs[layer_id],
|
341
276
|
kv_item_len,
|
342
277
|
)
|
343
|
-
for layer_id in range(
|
278
|
+
for layer_id in range(layers_current_pp_stage)
|
344
279
|
]
|
345
280
|
else:
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
350
|
-
layers_per_pp_stage = len(src_k_ptrs)
|
351
|
-
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
352
|
-
dst_v_ptrs = dst_kv_ptrs[
|
353
|
-
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
354
|
-
]
|
281
|
+
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
282
|
+
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
283
|
+
)
|
355
284
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
356
285
|
layers_params = [
|
357
286
|
(
|
@@ -359,14 +288,14 @@ class MooncakeKVManager(BaseKVManager):
|
|
359
288
|
dst_k_ptrs[layer_id],
|
360
289
|
kv_item_len,
|
361
290
|
)
|
362
|
-
for layer_id in range(
|
291
|
+
for layer_id in range(layers_current_pp_stage)
|
363
292
|
] + [
|
364
293
|
(
|
365
294
|
src_v_ptrs[layer_id],
|
366
295
|
dst_v_ptrs[layer_id],
|
367
296
|
kv_item_len,
|
368
297
|
)
|
369
|
-
for layer_id in range(
|
298
|
+
for layer_id in range(layers_current_pp_stage)
|
370
299
|
]
|
371
300
|
assert layers_params is not None
|
372
301
|
|
@@ -458,22 +387,15 @@ class MooncakeKVManager(BaseKVManager):
|
|
458
387
|
dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
|
459
388
|
else:
|
460
389
|
# Send KVCache from 1 prefill instance to multiple decode instances
|
461
|
-
src_head_start_offset =
|
390
|
+
src_head_start_offset = (
|
391
|
+
dst_tp_rank_in_group * dst_heads_per_rank
|
392
|
+
) % src_heads_per_rank
|
462
393
|
num_heads_to_send = dst_heads_per_rank
|
463
394
|
dst_head_start_offset = 0
|
464
395
|
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
469
|
-
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
470
|
-
layers_per_pp_stage = len(src_k_ptrs)
|
471
|
-
start_layer = self.pp_rank * layers_per_pp_stage
|
472
|
-
end_layer = start_layer + layers_per_pp_stage
|
473
|
-
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
474
|
-
dst_v_ptrs = dst_kv_ptrs[
|
475
|
-
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
476
|
-
]
|
396
|
+
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
397
|
+
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
398
|
+
)
|
477
399
|
|
478
400
|
# Calculate precise byte offset and length for the sub-slice within the token
|
479
401
|
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
@@ -499,7 +421,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
499
421
|
dst_head_slice_offset,
|
500
422
|
heads_bytes_per_token_to_send,
|
501
423
|
)
|
502
|
-
for layer_id in range(
|
424
|
+
for layer_id in range(layers_current_pp_stage)
|
503
425
|
] + [
|
504
426
|
(
|
505
427
|
src_v_ptrs[layer_id],
|
@@ -510,7 +432,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
510
432
|
dst_head_slice_offset,
|
511
433
|
heads_bytes_per_token_to_send,
|
512
434
|
)
|
513
|
-
for layer_id in range(
|
435
|
+
for layer_id in range(layers_current_pp_stage)
|
514
436
|
]
|
515
437
|
|
516
438
|
def process_layer_tp_aware(layer_params):
|
@@ -651,6 +573,26 @@ class MooncakeKVManager(BaseKVManager):
|
|
651
573
|
]
|
652
574
|
)
|
653
575
|
|
576
|
+
def _handle_aux_data(self, msg: List[bytes]):
|
577
|
+
"""Handle AUX_DATA messages received by the decode thread."""
|
578
|
+
room = int(msg[1].decode("ascii"))
|
579
|
+
buffer_index = int(msg[2].decode("ascii"))
|
580
|
+
aux_index = int(msg[3].decode("ascii"))
|
581
|
+
data_length = struct.unpack(">I", msg[4])[0]
|
582
|
+
data = msg[5]
|
583
|
+
|
584
|
+
if len(data) != data_length:
|
585
|
+
logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
|
586
|
+
return
|
587
|
+
|
588
|
+
AuxDataCodec.deserialize_data_to_buffer(
|
589
|
+
self.kv_args, buffer_index, aux_index, data
|
590
|
+
)
|
591
|
+
|
592
|
+
logger.debug(
|
593
|
+
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
|
594
|
+
)
|
595
|
+
|
654
596
|
def sync_status_to_decode_endpoint(
|
655
597
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
656
598
|
):
|
@@ -799,11 +741,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
799
741
|
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
800
742
|
)
|
801
743
|
|
802
|
-
def _bind_server_socket(self):
|
803
|
-
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
804
|
-
|
805
744
|
def start_prefill_thread(self):
|
806
|
-
self.rank_port = get_free_port()
|
807
745
|
self._bind_server_socket()
|
808
746
|
|
809
747
|
def bootstrap_thread():
|
@@ -841,28 +779,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
841
779
|
|
842
780
|
threading.Thread(target=bootstrap_thread).start()
|
843
781
|
|
844
|
-
def _handle_aux_data(self, msg: List[bytes]):
|
845
|
-
"""Handle AUX_DATA messages received by the decode thread."""
|
846
|
-
room = int(msg[1].decode("ascii"))
|
847
|
-
buffer_index = int(msg[2].decode("ascii"))
|
848
|
-
aux_index = int(msg[3].decode("ascii"))
|
849
|
-
data_length = struct.unpack(">I", msg[4])[0]
|
850
|
-
data = msg[5]
|
851
|
-
|
852
|
-
if len(data) != data_length:
|
853
|
-
logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
|
854
|
-
return
|
855
|
-
|
856
|
-
AuxDataCodec.deserialize_data_to_buffer(
|
857
|
-
self.kv_args, buffer_index, aux_index, data
|
858
|
-
)
|
859
|
-
|
860
|
-
logger.debug(
|
861
|
-
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
|
862
|
-
)
|
863
|
-
|
864
782
|
def start_decode_thread(self):
|
865
|
-
self.rank_port = get_free_port()
|
866
783
|
self._bind_server_socket()
|
867
784
|
|
868
785
|
def decode_thread():
|
@@ -1017,49 +934,6 @@ class MooncakeKVManager(BaseKVManager):
|
|
1017
934
|
def get_session_id(self):
|
1018
935
|
return self.engine.get_session_id()
|
1019
936
|
|
1020
|
-
def _register_to_bootstrap(self):
|
1021
|
-
"""Register KVSender to bootstrap server via HTTP POST."""
|
1022
|
-
if self.dist_init_addr:
|
1023
|
-
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
1024
|
-
if self.dist_init_addr.endswith("]"):
|
1025
|
-
host = self.dist_init_addr
|
1026
|
-
else:
|
1027
|
-
host, _ = self.dist_init_addr.rsplit(":", 1)
|
1028
|
-
else:
|
1029
|
-
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
1030
|
-
else:
|
1031
|
-
host = get_ip()
|
1032
|
-
host = maybe_wrap_ipv6_address(host)
|
1033
|
-
|
1034
|
-
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
1035
|
-
url = f"http://{bootstrap_server_url}/route"
|
1036
|
-
payload = {
|
1037
|
-
"role": "Prefill",
|
1038
|
-
"attn_tp_size": self.attn_tp_size,
|
1039
|
-
"attn_tp_rank": self.attn_tp_rank,
|
1040
|
-
"attn_dp_size": self.attn_dp_size,
|
1041
|
-
"attn_dp_rank": self.attn_dp_rank,
|
1042
|
-
"pp_size": self.pp_size,
|
1043
|
-
"pp_rank": self.pp_rank,
|
1044
|
-
"system_dp_size": self.system_dp_size,
|
1045
|
-
"system_dp_rank": self.system_dp_rank,
|
1046
|
-
"rank_ip": self.local_ip,
|
1047
|
-
"rank_port": self.rank_port,
|
1048
|
-
}
|
1049
|
-
|
1050
|
-
try:
|
1051
|
-
response = requests.put(url, json=payload, timeout=5)
|
1052
|
-
if response.status_code == 200:
|
1053
|
-
logger.debug("Prefill successfully registered to bootstrap server.")
|
1054
|
-
else:
|
1055
|
-
logger.error(
|
1056
|
-
f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
|
1057
|
-
)
|
1058
|
-
except Exception as e:
|
1059
|
-
logger.error(
|
1060
|
-
f"Prefill instance failed to register to bootstrap server: {e}"
|
1061
|
-
)
|
1062
|
-
|
1063
937
|
def _handle_node_failure(self, failed_bootstrap_addr):
|
1064
938
|
with self.connection_lock:
|
1065
939
|
keys_to_remove = [
|
@@ -1098,7 +972,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
1098
972
|
)
|
1099
973
|
|
1100
974
|
|
1101
|
-
class MooncakeKVSender(BaseKVSender):
|
975
|
+
class MooncakeKVSender(CommonKVSender):
|
1102
976
|
|
1103
977
|
def __init__(
|
1104
978
|
self,
|
@@ -1108,19 +982,9 @@ class MooncakeKVSender(BaseKVSender):
|
|
1108
982
|
dest_tp_ranks: List[int],
|
1109
983
|
pp_rank: int,
|
1110
984
|
):
|
1111
|
-
|
1112
|
-
self.bootstrap_room = bootstrap_room
|
1113
|
-
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
|
1114
|
-
self.aux_index = None
|
1115
|
-
self.bootstrap_server_url = bootstrap_addr
|
985
|
+
super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
|
1116
986
|
self.conclude_state = None
|
1117
987
|
self.init_time = time.time()
|
1118
|
-
# inner state
|
1119
|
-
self.curr_idx = 0
|
1120
|
-
|
1121
|
-
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
1122
|
-
self.num_kv_indices = num_kv_indices
|
1123
|
-
self.aux_index = aux_index
|
1124
988
|
|
1125
989
|
def send(
|
1126
990
|
self,
|
@@ -1198,7 +1062,7 @@ class MooncakeKVSender(BaseKVSender):
|
|
1198
1062
|
self.conclude_state = KVPoll.Failed
|
1199
1063
|
|
1200
1064
|
|
1201
|
-
class MooncakeKVReceiver(BaseKVReceiver):
|
1065
|
+
class MooncakeKVReceiver(CommonKVReceiver):
|
1202
1066
|
_ctx = zmq.Context()
|
1203
1067
|
_socket_cache = {}
|
1204
1068
|
_socket_locks = {}
|
@@ -1209,166 +1073,13 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1209
1073
|
mgr: MooncakeKVManager,
|
1210
1074
|
bootstrap_addr: str,
|
1211
1075
|
bootstrap_room: Optional[int] = None,
|
1212
|
-
|
1076
|
+
prefill_dp_rank: Optional[int] = None,
|
1213
1077
|
):
|
1214
|
-
self.
|
1215
|
-
self.bootstrap_addr = bootstrap_addr
|
1216
|
-
self.kv_mgr = mgr
|
1217
|
-
self.session_id = self.kv_mgr.get_session_id()
|
1218
|
-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
1078
|
+
self.session_id = mgr.get_session_id()
|
1219
1079
|
self.conclude_state = None
|
1220
1080
|
self.init_time = None
|
1221
|
-
|
1081
|
+
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
1222
1082
|
|
1223
|
-
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
1224
|
-
(
|
1225
|
-
self.prefill_attn_tp_size,
|
1226
|
-
self.prefill_dp_size,
|
1227
|
-
self.prefill_pp_size,
|
1228
|
-
) = self._get_prefill_parallel_info_from_server()
|
1229
|
-
if (
|
1230
|
-
self.prefill_attn_tp_size is None
|
1231
|
-
or self.prefill_dp_size is None
|
1232
|
-
or self.prefill_pp_size is None
|
1233
|
-
):
|
1234
|
-
self.kv_mgr.record_failure(
|
1235
|
-
self.bootstrap_room,
|
1236
|
-
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
|
1237
|
-
)
|
1238
|
-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
|
1239
|
-
return
|
1240
|
-
else:
|
1241
|
-
logger.debug(
|
1242
|
-
f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
|
1243
|
-
)
|
1244
|
-
self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
|
1245
|
-
self.prefill_attn_tp_size
|
1246
|
-
)
|
1247
|
-
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
1248
|
-
self.prefill_dp_size
|
1249
|
-
)
|
1250
|
-
self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
|
1251
|
-
self.prefill_pp_size
|
1252
|
-
)
|
1253
|
-
else:
|
1254
|
-
self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
|
1255
|
-
self.bootstrap_addr
|
1256
|
-
]
|
1257
|
-
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
1258
|
-
self.bootstrap_addr
|
1259
|
-
]
|
1260
|
-
self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
|
1261
|
-
self.bootstrap_addr
|
1262
|
-
]
|
1263
|
-
|
1264
|
-
# Currently, we don't allow prefill instance and decode instance to
|
1265
|
-
# have different TP sizes per DP rank, except for models using MLA.
|
1266
|
-
if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
|
1267
|
-
self.target_tp_rank = (
|
1268
|
-
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
1269
|
-
)
|
1270
|
-
self.required_dst_info_num = 1
|
1271
|
-
self.required_prefill_response_num = 1 * (
|
1272
|
-
self.prefill_pp_size // self.kv_mgr.pp_size
|
1273
|
-
)
|
1274
|
-
self.target_tp_ranks = [self.target_tp_rank]
|
1275
|
-
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
|
1276
|
-
if not self.kv_mgr.is_mla_backend:
|
1277
|
-
logger.warning_once(
|
1278
|
-
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
|
1279
|
-
)
|
1280
|
-
self.target_tp_rank = (
|
1281
|
-
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
1282
|
-
) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
|
1283
|
-
self.required_dst_info_num = (
|
1284
|
-
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
|
1285
|
-
)
|
1286
|
-
self.required_prefill_response_num = 1 * (
|
1287
|
-
self.prefill_pp_size // self.kv_mgr.pp_size
|
1288
|
-
)
|
1289
|
-
self.target_tp_ranks = [self.target_tp_rank]
|
1290
|
-
else:
|
1291
|
-
if not self.kv_mgr.is_mla_backend:
|
1292
|
-
logger.warning_once(
|
1293
|
-
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
|
1294
|
-
)
|
1295
|
-
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
|
1296
|
-
self.target_tp_ranks = [
|
1297
|
-
rank
|
1298
|
-
for rank in range(
|
1299
|
-
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
|
1300
|
-
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
|
1301
|
-
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
|
1302
|
-
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
|
1303
|
-
)
|
1304
|
-
]
|
1305
|
-
|
1306
|
-
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
|
1307
|
-
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
|
1308
|
-
# or the KVPoll will never be set correctly
|
1309
|
-
self.target_tp_rank = self.target_tp_ranks[0]
|
1310
|
-
self.required_dst_info_num = 1
|
1311
|
-
if self.kv_mgr.is_mla_backend:
|
1312
|
-
self.required_prefill_response_num = (
|
1313
|
-
self.prefill_pp_size // self.kv_mgr.pp_size
|
1314
|
-
)
|
1315
|
-
else:
|
1316
|
-
self.required_prefill_response_num = (
|
1317
|
-
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
1318
|
-
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
1319
|
-
|
1320
|
-
if self.data_parallel_rank is not None:
|
1321
|
-
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
1322
|
-
self.target_dp_group = self.data_parallel_rank
|
1323
|
-
else:
|
1324
|
-
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
1325
|
-
|
1326
|
-
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
1327
|
-
self.required_prefill_response_num
|
1328
|
-
)
|
1329
|
-
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
1330
|
-
bootstrap_key = (
|
1331
|
-
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
|
1332
|
-
)
|
1333
|
-
|
1334
|
-
if bootstrap_key not in self.kv_mgr.connection_pool:
|
1335
|
-
bootstrap_infos = []
|
1336
|
-
for target_tp_rank in self.target_tp_ranks:
|
1337
|
-
for target_pp_rank in range(self.prefill_pp_size):
|
1338
|
-
bootstrap_info = self._get_bootstrap_info_from_server(
|
1339
|
-
target_tp_rank, self.target_dp_group, target_pp_rank
|
1340
|
-
)
|
1341
|
-
if bootstrap_info is not None:
|
1342
|
-
if self.kv_mgr.is_mla_backend:
|
1343
|
-
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
|
1344
|
-
bootstrap_info["is_dummy"] = not bool(
|
1345
|
-
target_tp_rank == self.target_tp_rank
|
1346
|
-
or self.target_tp_rank is None
|
1347
|
-
)
|
1348
|
-
else:
|
1349
|
-
# For non-MLA: all target_tp_ranks are selected real ranks
|
1350
|
-
bootstrap_info["is_dummy"] = False
|
1351
|
-
logger.debug(
|
1352
|
-
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
|
1353
|
-
)
|
1354
|
-
bootstrap_infos.append(bootstrap_info)
|
1355
|
-
else:
|
1356
|
-
self.kv_mgr.record_failure(
|
1357
|
-
self.bootstrap_room,
|
1358
|
-
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
|
1359
|
-
)
|
1360
|
-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
|
1361
|
-
return
|
1362
|
-
|
1363
|
-
self.bootstrap_infos = bootstrap_infos
|
1364
|
-
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
1365
|
-
|
1366
|
-
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
1367
|
-
self._register_kv_args()
|
1368
|
-
else:
|
1369
|
-
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
|
1370
|
-
|
1371
|
-
assert len(self.bootstrap_infos) > 0
|
1372
1083
|
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
|
1373
1084
|
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
|
1374
1085
|
|
@@ -1391,29 +1102,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1391
1102
|
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
1392
1103
|
return None
|
1393
1104
|
|
1394
|
-
def _get_prefill_parallel_info_from_server(
|
1395
|
-
self,
|
1396
|
-
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
|
1397
|
-
"""Fetch the prefill parallel info from the bootstrap server."""
|
1398
|
-
try:
|
1399
|
-
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
|
1400
|
-
response = requests.get(url)
|
1401
|
-
if response.status_code == 200:
|
1402
|
-
prefill_parallel_info = response.json()
|
1403
|
-
return (
|
1404
|
-
int(prefill_parallel_info["prefill_attn_tp_size"]),
|
1405
|
-
int(prefill_parallel_info["prefill_dp_size"]),
|
1406
|
-
int(prefill_parallel_info["prefill_pp_size"]),
|
1407
|
-
)
|
1408
|
-
else:
|
1409
|
-
logger.error(
|
1410
|
-
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
|
1411
|
-
)
|
1412
|
-
return None, None, None
|
1413
|
-
except Exception as e:
|
1414
|
-
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
|
1415
|
-
return None, None, None
|
1416
|
-
|
1417
1105
|
def _register_kv_args(self):
|
1418
1106
|
for bootstrap_info in self.bootstrap_infos:
|
1419
1107
|
packed_kv_data_ptrs = b"".join(
|
@@ -1445,28 +1133,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1445
1133
|
]
|
1446
1134
|
)
|
1447
1135
|
|
1448
|
-
@classmethod
|
1449
|
-
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
1450
|
-
with cls._global_lock:
|
1451
|
-
if endpoint not in cls._socket_cache:
|
1452
|
-
sock = cls._ctx.socket(zmq.PUSH)
|
1453
|
-
if is_ipv6:
|
1454
|
-
sock.setsockopt(zmq.IPV6, 1)
|
1455
|
-
sock.connect(endpoint)
|
1456
|
-
cls._socket_cache[endpoint] = sock
|
1457
|
-
cls._socket_locks[endpoint] = threading.Lock()
|
1458
|
-
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
1459
|
-
|
1460
|
-
@classmethod
|
1461
|
-
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
1462
|
-
ip_address = bootstrap_info["rank_ip"]
|
1463
|
-
port = bootstrap_info["rank_port"]
|
1464
|
-
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
1465
|
-
sock, lock = cls._connect(
|
1466
|
-
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
1467
|
-
)
|
1468
|
-
return sock, lock
|
1469
|
-
|
1470
1136
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
1471
1137
|
for bootstrap_info in self.bootstrap_infos:
|
1472
1138
|
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
@@ -1544,153 +1210,5 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1544
1210
|
self.conclude_state = KVPoll.Failed
|
1545
1211
|
|
1546
1212
|
|
1547
|
-
class MooncakeKVBootstrapServer(
|
1548
|
-
|
1549
|
-
self.port = port
|
1550
|
-
self.app = web.Application()
|
1551
|
-
self.store = dict()
|
1552
|
-
self.lock = asyncio.Lock()
|
1553
|
-
self._setup_routes()
|
1554
|
-
self.pp_size = None
|
1555
|
-
self.attn_tp_size = None
|
1556
|
-
self.dp_size = None
|
1557
|
-
self.prefill_port_table: Dict[
|
1558
|
-
int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
|
1559
|
-
] = {}
|
1560
|
-
|
1561
|
-
# Start bootstrap server
|
1562
|
-
self.thread = threading.Thread(target=self._run_server, daemon=True)
|
1563
|
-
self.run()
|
1564
|
-
|
1565
|
-
def run(self):
|
1566
|
-
self.thread.start()
|
1567
|
-
|
1568
|
-
def _setup_routes(self):
|
1569
|
-
self.app.router.add_route("*", "/route", self._handle_route)
|
1570
|
-
self.app.router.add_get("/health", self._handle_health_check)
|
1571
|
-
|
1572
|
-
async def _handle_health_check(self, request):
|
1573
|
-
return web.Response(text="OK", status=200)
|
1574
|
-
|
1575
|
-
async def _handle_route(self, request: web.Request):
|
1576
|
-
method = request.method
|
1577
|
-
if method == "PUT":
|
1578
|
-
return await self._handle_route_put(request)
|
1579
|
-
elif method == "GET":
|
1580
|
-
return await self._handle_route_get(request)
|
1581
|
-
else:
|
1582
|
-
return web.Response(
|
1583
|
-
text="Method not allowed", status=405, content_type="application/json"
|
1584
|
-
)
|
1585
|
-
|
1586
|
-
async def _handle_route_put(self, request: web.Request):
|
1587
|
-
data = await request.json()
|
1588
|
-
role = data["role"]
|
1589
|
-
attn_tp_size = data["attn_tp_size"]
|
1590
|
-
attn_tp_rank = data["attn_tp_rank"]
|
1591
|
-
attn_dp_size = data["attn_dp_size"]
|
1592
|
-
attn_dp_rank = data["attn_dp_rank"]
|
1593
|
-
pp_size = data["pp_size"]
|
1594
|
-
pp_rank = data["pp_rank"]
|
1595
|
-
system_dp_size = data["system_dp_size"]
|
1596
|
-
system_dp_rank = data["system_dp_rank"]
|
1597
|
-
rank_ip = data["rank_ip"]
|
1598
|
-
rank_port = int(data["rank_port"])
|
1599
|
-
|
1600
|
-
if self.attn_tp_size is None:
|
1601
|
-
self.attn_tp_size = attn_tp_size
|
1602
|
-
|
1603
|
-
if self.dp_size is None:
|
1604
|
-
self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
|
1605
|
-
|
1606
|
-
if self.pp_size is None:
|
1607
|
-
self.pp_size = pp_size
|
1608
|
-
|
1609
|
-
if role == "Prefill":
|
1610
|
-
if system_dp_size == 1:
|
1611
|
-
dp_group = attn_dp_rank
|
1612
|
-
else:
|
1613
|
-
dp_group = system_dp_rank
|
1614
|
-
|
1615
|
-
# Add lock to make sure thread-safe
|
1616
|
-
async with self.lock:
|
1617
|
-
if dp_group not in self.prefill_port_table:
|
1618
|
-
self.prefill_port_table[dp_group] = {}
|
1619
|
-
if attn_tp_rank not in self.prefill_port_table[dp_group]:
|
1620
|
-
self.prefill_port_table[dp_group][attn_tp_rank] = {}
|
1621
|
-
|
1622
|
-
self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
|
1623
|
-
"rank_ip": rank_ip,
|
1624
|
-
"rank_port": rank_port,
|
1625
|
-
}
|
1626
|
-
logger.debug(
|
1627
|
-
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
1628
|
-
)
|
1629
|
-
|
1630
|
-
return web.Response(text="OK", status=200)
|
1631
|
-
|
1632
|
-
async def _handle_route_get(self, request: web.Request):
|
1633
|
-
engine_rank = request.query.get("engine_rank")
|
1634
|
-
target_dp_group = request.query.get("target_dp_group")
|
1635
|
-
target_pp_rank = request.query.get("target_pp_rank")
|
1636
|
-
if not engine_rank or not target_dp_group or not target_pp_rank:
|
1637
|
-
return web.Response(text="Missing inputs for bootstrap server.", status=400)
|
1638
|
-
|
1639
|
-
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
1640
|
-
if (
|
1641
|
-
int(engine_rank) == -1
|
1642
|
-
and int(target_dp_group) == -1
|
1643
|
-
and int(target_pp_rank) == -1
|
1644
|
-
):
|
1645
|
-
prefill_parallel_info = {
|
1646
|
-
"prefill_attn_tp_size": self.attn_tp_size,
|
1647
|
-
"prefill_dp_size": self.dp_size,
|
1648
|
-
"prefill_pp_size": self.pp_size,
|
1649
|
-
}
|
1650
|
-
return web.json_response(prefill_parallel_info, status=200)
|
1651
|
-
|
1652
|
-
# Find corresponding prefill info
|
1653
|
-
async with self.lock:
|
1654
|
-
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
1655
|
-
int(engine_rank)
|
1656
|
-
][int(target_pp_rank)]
|
1657
|
-
|
1658
|
-
if bootstrap_info is not None:
|
1659
|
-
return web.json_response(bootstrap_info, status=200)
|
1660
|
-
else:
|
1661
|
-
return web.Response(text="Bootstrap info not Found", status=404)
|
1662
|
-
|
1663
|
-
def _run_server(self):
|
1664
|
-
try:
|
1665
|
-
# Event Loop
|
1666
|
-
self._loop = asyncio.new_event_loop()
|
1667
|
-
asyncio.set_event_loop(self._loop)
|
1668
|
-
|
1669
|
-
access_log = None
|
1670
|
-
if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
|
1671
|
-
access_log = self.app.logger
|
1672
|
-
|
1673
|
-
self._runner = web.AppRunner(self.app, access_log=access_log)
|
1674
|
-
self._loop.run_until_complete(self._runner.setup())
|
1675
|
-
|
1676
|
-
site = web.TCPSite(self._runner, port=self.port)
|
1677
|
-
self._loop.run_until_complete(site.start())
|
1678
|
-
self._loop.run_forever()
|
1679
|
-
except Exception as e:
|
1680
|
-
logger.error(f"Server error: {str(e)}")
|
1681
|
-
finally:
|
1682
|
-
# Cleanup
|
1683
|
-
self._loop.run_until_complete(self._runner.cleanup())
|
1684
|
-
self._loop.close()
|
1685
|
-
|
1686
|
-
def close(self):
|
1687
|
-
"""Shutdown"""
|
1688
|
-
if self._loop is not None and self._loop.is_running():
|
1689
|
-
self._loop.call_soon_threadsafe(self._loop.stop)
|
1690
|
-
logger.info("Stopping server loop...")
|
1691
|
-
|
1692
|
-
if self.thread.is_alive():
|
1693
|
-
self.thread.join(timeout=2)
|
1694
|
-
logger.info("Server thread stopped")
|
1695
|
-
|
1696
|
-
def poll(self) -> KVPoll: ...
|
1213
|
+
class MooncakeKVBootstrapServer(CommonKVBootstrapServer):
|
1214
|
+
pass
|