sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -22,12 +22,18 @@ from sglang.srt.disaggregation.base.conn import (
|
|
22
22
|
KVPoll,
|
23
23
|
)
|
24
24
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
25
|
+
from sglang.srt.distributed import get_pp_group
|
26
|
+
from sglang.srt.layers.dp_attention import (
|
27
|
+
get_attention_dp_rank,
|
28
|
+
get_attention_dp_size,
|
29
|
+
get_attention_tp_rank,
|
30
|
+
get_attention_tp_size,
|
31
|
+
)
|
25
32
|
from sglang.srt.server_args import ServerArgs
|
26
33
|
from sglang.srt.utils import (
|
27
34
|
format_tcp_address,
|
28
35
|
get_free_port,
|
29
|
-
|
30
|
-
get_local_ip_by_remote,
|
36
|
+
get_local_ip_auto,
|
31
37
|
is_valid_ipv6_address,
|
32
38
|
maybe_wrap_ipv6_address,
|
33
39
|
)
|
@@ -47,31 +53,52 @@ class CommonKVManager(BaseKVManager):
|
|
47
53
|
self.is_mla_backend = is_mla_backend
|
48
54
|
self.disaggregation_mode = disaggregation_mode
|
49
55
|
# for p/d multi node infer
|
56
|
+
self.bootstrap_host = server_args.host
|
50
57
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
51
58
|
self.dist_init_addr = server_args.dist_init_addr
|
52
|
-
self.
|
53
|
-
self.
|
54
|
-
self.
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
59
|
+
self.attn_tp_size = get_attention_tp_size()
|
60
|
+
self.attn_tp_rank = get_attention_tp_rank()
|
61
|
+
self.attn_dp_size = get_attention_dp_size()
|
62
|
+
self.attn_dp_rank = get_attention_dp_rank()
|
63
|
+
self.system_dp_size = (
|
64
|
+
1 if server_args.enable_dp_attention else server_args.dp_size
|
65
|
+
)
|
66
|
+
self.system_dp_rank = (
|
67
|
+
self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
|
68
|
+
)
|
69
|
+
self.pp_size = server_args.pp_size
|
70
|
+
self.pp_rank = self.kv_args.pp_rank
|
60
71
|
self.rank_port = get_free_port()
|
72
|
+
self.local_ip = get_local_ip_auto()
|
73
|
+
self.server_socket = zmq.Context().socket(zmq.PULL)
|
74
|
+
if is_valid_ipv6_address(self.local_ip):
|
75
|
+
self.server_socket.setsockopt(zmq.IPV6, 1)
|
76
|
+
self.request_status: Dict[int, KVPoll] = {}
|
77
|
+
|
61
78
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
62
79
|
self._register_to_bootstrap()
|
80
|
+
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
81
|
+
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
82
|
+
self.pp_group = get_pp_group()
|
63
83
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
64
84
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
65
|
-
self.
|
85
|
+
self.connection_lock = threading.Lock()
|
86
|
+
self.required_prefill_response_num_table: Dict[int, int] = {}
|
87
|
+
self.prefill_attn_tp_size_table: Dict[str, int] = {}
|
66
88
|
self.prefill_dp_size_table: Dict[str, int] = {}
|
89
|
+
self.prefill_pp_size_table: Dict[str, int] = {}
|
67
90
|
else:
|
68
91
|
raise ValueError(
|
69
92
|
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
70
93
|
)
|
71
94
|
|
95
|
+
def _bind_server_socket(self):
|
96
|
+
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
97
|
+
|
72
98
|
def _register_to_bootstrap(self):
|
73
99
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
74
100
|
if self.dist_init_addr:
|
101
|
+
# Multi-node case: bootstrap server's host is dist_init_addr
|
75
102
|
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
76
103
|
if self.dist_init_addr.endswith("]"):
|
77
104
|
host = self.dist_init_addr
|
@@ -80,30 +107,38 @@ class CommonKVManager(BaseKVManager):
|
|
80
107
|
else:
|
81
108
|
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
82
109
|
else:
|
83
|
-
host
|
110
|
+
# Single-node case: bootstrap server's host is the same as http server's host
|
111
|
+
host = self.bootstrap_host
|
84
112
|
host = maybe_wrap_ipv6_address(host)
|
85
113
|
|
86
114
|
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
87
115
|
url = f"http://{bootstrap_server_url}/route"
|
88
116
|
payload = {
|
89
117
|
"role": "Prefill",
|
90
|
-
"
|
91
|
-
"
|
92
|
-
"
|
118
|
+
"attn_tp_size": self.attn_tp_size,
|
119
|
+
"attn_tp_rank": self.attn_tp_rank,
|
120
|
+
"attn_dp_size": self.attn_dp_size,
|
121
|
+
"attn_dp_rank": self.attn_dp_rank,
|
122
|
+
"pp_size": self.pp_size,
|
123
|
+
"pp_rank": self.pp_rank,
|
124
|
+
"system_dp_size": self.system_dp_size,
|
125
|
+
"system_dp_rank": self.system_dp_rank,
|
126
|
+
"rank_ip": self.local_ip,
|
93
127
|
"rank_port": self.rank_port,
|
94
|
-
"engine_rank": self.kv_args.engine_rank,
|
95
128
|
}
|
96
129
|
|
97
130
|
try:
|
98
|
-
response = requests.put(url, json=payload)
|
131
|
+
response = requests.put(url, json=payload, timeout=5)
|
99
132
|
if response.status_code == 200:
|
100
133
|
logger.debug("Prefill successfully registered to bootstrap server.")
|
101
134
|
else:
|
102
135
|
logger.error(
|
103
|
-
f"Prefill
|
136
|
+
f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
|
104
137
|
)
|
105
138
|
except Exception as e:
|
106
|
-
logger.error(
|
139
|
+
logger.error(
|
140
|
+
f"Prefill instance failed to register to bootstrap server: {e}"
|
141
|
+
)
|
107
142
|
|
108
143
|
@cache
|
109
144
|
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
@@ -113,6 +148,68 @@ class CommonKVManager(BaseKVManager):
|
|
113
148
|
socket.connect(endpoint)
|
114
149
|
return socket
|
115
150
|
|
151
|
+
def get_mha_kv_ptrs_with_pp(
|
152
|
+
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
|
153
|
+
) -> Tuple[List[int], List[int], List[int], List[int], int]:
|
154
|
+
# pp is not supported on the decode side yet
|
155
|
+
start_layer = self.kv_args.prefill_start_layer
|
156
|
+
num_kv_layers = len(src_kv_ptrs) // 2
|
157
|
+
end_layer = start_layer + num_kv_layers
|
158
|
+
dst_num_total_layers = len(dst_kv_ptrs) // 2
|
159
|
+
src_k_ptrs = src_kv_ptrs[:num_kv_layers]
|
160
|
+
src_v_ptrs = src_kv_ptrs[num_kv_layers:]
|
161
|
+
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
162
|
+
dst_v_ptrs = dst_kv_ptrs[
|
163
|
+
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
164
|
+
]
|
165
|
+
layers_current_pp_stage = len(src_k_ptrs)
|
166
|
+
return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
|
167
|
+
|
168
|
+
def get_mla_kv_ptrs_with_pp(
|
169
|
+
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
|
170
|
+
) -> Tuple[List[int], List[int], int]:
|
171
|
+
# pp is not supported on the decode side yet
|
172
|
+
start_layer = self.kv_args.prefill_start_layer
|
173
|
+
end_layer = start_layer + len(src_kv_ptrs)
|
174
|
+
sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
175
|
+
layers_current_pp_stage = len(src_kv_ptrs)
|
176
|
+
return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
|
177
|
+
|
178
|
+
|
179
|
+
class CommonKVSender(BaseKVSender):
|
180
|
+
|
181
|
+
def __init__(
|
182
|
+
self,
|
183
|
+
mgr: BaseKVManager,
|
184
|
+
bootstrap_addr: str,
|
185
|
+
bootstrap_room: int,
|
186
|
+
dest_tp_ranks: List[int],
|
187
|
+
pp_rank: int,
|
188
|
+
):
|
189
|
+
self.kv_mgr = mgr
|
190
|
+
self.bootstrap_room = bootstrap_room
|
191
|
+
self.aux_index = None
|
192
|
+
self.bootstrap_server_url = bootstrap_addr
|
193
|
+
# inner state
|
194
|
+
self.curr_idx = 0
|
195
|
+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
196
|
+
|
197
|
+
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
198
|
+
self.num_kv_indices = num_kv_indices
|
199
|
+
self.aux_index = aux_index
|
200
|
+
|
201
|
+
def send(
|
202
|
+
self,
|
203
|
+
kv_indices: npt.NDArray[np.int32],
|
204
|
+
):
|
205
|
+
pass
|
206
|
+
|
207
|
+
def poll(self) -> KVPoll:
|
208
|
+
pass
|
209
|
+
|
210
|
+
def failure_exception(self):
|
211
|
+
raise Exception("Fake KVReceiver Exception")
|
212
|
+
|
116
213
|
|
117
214
|
class CommonKVReceiver(BaseKVReceiver):
|
118
215
|
_ctx = zmq.Context()
|
@@ -125,70 +222,93 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
125
222
|
mgr: BaseKVManager,
|
126
223
|
bootstrap_addr: str,
|
127
224
|
bootstrap_room: Optional[int] = None,
|
128
|
-
|
225
|
+
prefill_dp_rank: Optional[int] = None,
|
129
226
|
):
|
130
227
|
self.bootstrap_room = bootstrap_room
|
131
228
|
self.bootstrap_addr = bootstrap_addr
|
132
229
|
self.kv_mgr = mgr
|
133
|
-
self.
|
230
|
+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
134
231
|
|
135
232
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
136
|
-
|
137
|
-
self.
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
233
|
+
(
|
234
|
+
self.prefill_attn_tp_size,
|
235
|
+
self.prefill_dp_size,
|
236
|
+
self.prefill_pp_size,
|
237
|
+
) = self._get_prefill_parallel_info_from_server()
|
238
|
+
if (
|
239
|
+
self.prefill_attn_tp_size is None
|
240
|
+
or self.prefill_dp_size is None
|
241
|
+
or self.prefill_pp_size is None
|
242
|
+
):
|
243
|
+
self.kv_mgr.record_failure(
|
244
|
+
self.bootstrap_room,
|
245
|
+
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
|
142
246
|
)
|
247
|
+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
|
248
|
+
return
|
143
249
|
else:
|
144
|
-
|
145
|
-
self.
|
250
|
+
logger.debug(
|
251
|
+
f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
|
252
|
+
)
|
253
|
+
self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
|
254
|
+
self.prefill_attn_tp_size
|
146
255
|
)
|
147
256
|
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
148
257
|
self.prefill_dp_size
|
149
258
|
)
|
259
|
+
self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
|
260
|
+
self.prefill_pp_size
|
261
|
+
)
|
150
262
|
else:
|
151
|
-
self.
|
263
|
+
self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
|
152
264
|
self.bootstrap_addr
|
153
265
|
]
|
154
266
|
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
155
267
|
self.bootstrap_addr
|
156
268
|
]
|
269
|
+
self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
|
270
|
+
self.bootstrap_addr
|
271
|
+
]
|
157
272
|
|
158
273
|
# Currently, we don't allow prefill instance and decode instance to
|
159
274
|
# have different TP sizes per DP rank, except for models using MLA.
|
160
|
-
|
161
|
-
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
|
162
|
-
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
|
275
|
+
if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
|
163
276
|
self.target_tp_rank = (
|
164
|
-
self.kv_mgr.kv_args.engine_rank %
|
277
|
+
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
165
278
|
)
|
166
279
|
self.required_dst_info_num = 1
|
280
|
+
self.required_prefill_response_num = 1 * (
|
281
|
+
self.prefill_pp_size // self.kv_mgr.pp_size
|
282
|
+
)
|
167
283
|
self.target_tp_ranks = [self.target_tp_rank]
|
168
|
-
elif
|
169
|
-
|
170
|
-
|
171
|
-
|
284
|
+
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
|
285
|
+
if not self.kv_mgr.is_mla_backend:
|
286
|
+
logger.warning_once(
|
287
|
+
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
|
288
|
+
)
|
172
289
|
self.target_tp_rank = (
|
173
|
-
self.kv_mgr.kv_args.engine_rank %
|
174
|
-
) // (
|
290
|
+
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
291
|
+
) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
|
175
292
|
self.required_dst_info_num = (
|
176
|
-
|
293
|
+
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
|
294
|
+
)
|
295
|
+
self.required_prefill_response_num = 1 * (
|
296
|
+
self.prefill_pp_size // self.kv_mgr.pp_size
|
177
297
|
)
|
178
298
|
self.target_tp_ranks = [self.target_tp_rank]
|
179
299
|
else:
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
300
|
+
if not self.kv_mgr.is_mla_backend:
|
301
|
+
logger.warning_once(
|
302
|
+
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
|
303
|
+
)
|
184
304
|
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
|
185
305
|
self.target_tp_ranks = [
|
186
306
|
rank
|
187
307
|
for rank in range(
|
188
|
-
(self.kv_mgr.kv_args.engine_rank %
|
189
|
-
* (
|
190
|
-
(self.kv_mgr.kv_args.engine_rank %
|
191
|
-
* (
|
308
|
+
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
|
309
|
+
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
|
310
|
+
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
|
311
|
+
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
|
192
312
|
)
|
193
313
|
]
|
194
314
|
|
@@ -197,13 +317,27 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
197
317
|
# or the KVPoll will never be set correctly
|
198
318
|
self.target_tp_rank = self.target_tp_ranks[0]
|
199
319
|
self.required_dst_info_num = 1
|
320
|
+
if self.kv_mgr.is_mla_backend:
|
321
|
+
self.required_prefill_response_num = (
|
322
|
+
self.prefill_pp_size // self.kv_mgr.pp_size
|
323
|
+
)
|
324
|
+
else:
|
325
|
+
self.required_prefill_response_num = (
|
326
|
+
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
327
|
+
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
200
328
|
|
201
|
-
if
|
202
|
-
logger.debug(f"Targeting DP rank: {
|
203
|
-
self.
|
329
|
+
if prefill_dp_rank is not None:
|
330
|
+
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
331
|
+
self.prefill_dp_rank = prefill_dp_rank
|
204
332
|
else:
|
205
|
-
self.
|
333
|
+
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
|
334
|
+
|
335
|
+
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
336
|
+
self.target_dp_group = self.prefill_dp_rank
|
206
337
|
|
338
|
+
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
339
|
+
self.required_prefill_response_num
|
340
|
+
)
|
207
341
|
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
208
342
|
bootstrap_key = (
|
209
343
|
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
|
@@ -212,41 +346,49 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
212
346
|
if bootstrap_key not in self.kv_mgr.connection_pool:
|
213
347
|
bootstrap_infos = []
|
214
348
|
for target_tp_rank in self.target_tp_ranks:
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
)
|
219
|
-
if bootstrap_info is not None:
|
220
|
-
# NOTE: only support MLA for now: select one prefill rank as real rank
|
221
|
-
bootstrap_info["is_dummy"] = not bool(
|
222
|
-
target_tp_rank == self.target_tp_rank
|
223
|
-
or self.target_tp_rank is None
|
224
|
-
)
|
225
|
-
bootstrap_infos.append(bootstrap_info)
|
226
|
-
else:
|
227
|
-
logger.error(
|
228
|
-
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
|
349
|
+
for target_pp_rank in range(self.prefill_pp_size):
|
350
|
+
bootstrap_info = self._get_bootstrap_info_from_server(
|
351
|
+
target_tp_rank, self.target_dp_group, target_pp_rank
|
229
352
|
)
|
353
|
+
if bootstrap_info is not None:
|
354
|
+
if self.kv_mgr.is_mla_backend:
|
355
|
+
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
|
356
|
+
bootstrap_info["is_dummy"] = not bool(
|
357
|
+
target_tp_rank == self.target_tp_rank
|
358
|
+
or self.target_tp_rank is None
|
359
|
+
)
|
360
|
+
else:
|
361
|
+
# For non-MLA: all target_tp_ranks are selected real ranks
|
362
|
+
bootstrap_info["is_dummy"] = False
|
363
|
+
logger.debug(
|
364
|
+
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
|
365
|
+
)
|
366
|
+
bootstrap_infos.append(bootstrap_info)
|
367
|
+
else:
|
368
|
+
self.kv_mgr.record_failure(
|
369
|
+
self.bootstrap_room,
|
370
|
+
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
|
371
|
+
)
|
372
|
+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
|
373
|
+
return
|
374
|
+
|
230
375
|
self.bootstrap_infos = bootstrap_infos
|
376
|
+
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
231
377
|
|
232
|
-
|
233
|
-
|
234
|
-
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
235
|
-
)
|
236
|
-
else:
|
237
|
-
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
238
|
-
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
239
|
-
self._register_kv_args()
|
378
|
+
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
379
|
+
self._register_kv_args()
|
240
380
|
else:
|
241
381
|
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
|
242
382
|
|
243
383
|
assert len(self.bootstrap_infos) > 0
|
244
384
|
|
245
|
-
def _get_bootstrap_info_from_server(
|
385
|
+
def _get_bootstrap_info_from_server(
|
386
|
+
self, engine_rank, target_dp_group, target_pp_rank
|
387
|
+
):
|
246
388
|
"""Fetch the bootstrap info from the bootstrap server."""
|
247
389
|
try:
|
248
|
-
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
|
249
|
-
response = requests.get(url)
|
390
|
+
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
|
391
|
+
response = requests.get(url, timeout=5)
|
250
392
|
if response.status_code == 200:
|
251
393
|
bootstrap_info = response.json()
|
252
394
|
return bootstrap_info
|
@@ -259,24 +401,28 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
259
401
|
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
260
402
|
return None
|
261
403
|
|
262
|
-
def
|
404
|
+
def _get_prefill_parallel_info_from_server(
|
405
|
+
self,
|
406
|
+
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
|
263
407
|
"""Fetch the prefill parallel info from the bootstrap server."""
|
264
408
|
try:
|
265
|
-
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
|
409
|
+
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
|
266
410
|
response = requests.get(url)
|
267
411
|
if response.status_code == 200:
|
268
412
|
prefill_parallel_info = response.json()
|
269
|
-
return
|
270
|
-
prefill_parallel_info["
|
413
|
+
return (
|
414
|
+
int(prefill_parallel_info["prefill_attn_tp_size"]),
|
415
|
+
int(prefill_parallel_info["prefill_dp_size"]),
|
416
|
+
int(prefill_parallel_info["prefill_pp_size"]),
|
271
417
|
)
|
272
418
|
else:
|
273
419
|
logger.error(
|
274
420
|
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
|
275
421
|
)
|
276
|
-
return None
|
422
|
+
return None, None, None
|
277
423
|
except Exception as e:
|
278
424
|
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
|
279
|
-
return None
|
425
|
+
return None, None, None
|
280
426
|
|
281
427
|
@classmethod
|
282
428
|
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
@@ -308,16 +454,19 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
308
454
|
|
309
455
|
|
310
456
|
class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
311
|
-
def __init__(self, port: int):
|
457
|
+
def __init__(self, host: str, port: int):
|
458
|
+
self.host = host
|
312
459
|
self.port = port
|
313
460
|
self.app = web.Application()
|
314
461
|
self.store = dict()
|
315
462
|
self.lock = asyncio.Lock()
|
316
463
|
self._setup_routes()
|
317
|
-
self.
|
464
|
+
self.pp_size = None
|
465
|
+
self.attn_tp_size = None
|
318
466
|
self.dp_size = None
|
319
|
-
self.
|
320
|
-
|
467
|
+
self.prefill_port_table: Dict[
|
468
|
+
int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
|
469
|
+
] = {}
|
321
470
|
|
322
471
|
# Start bootstrap server
|
323
472
|
self.thread = threading.Thread(target=self._run_server, daemon=True)
|
@@ -328,6 +477,10 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
328
477
|
|
329
478
|
def _setup_routes(self):
|
330
479
|
self.app.router.add_route("*", "/route", self._handle_route)
|
480
|
+
self.app.router.add_get("/health", self._handle_health_check)
|
481
|
+
|
482
|
+
async def _handle_health_check(self, request):
|
483
|
+
return web.Response(text="OK", status=200)
|
331
484
|
|
332
485
|
async def _handle_route(self, request: web.Request):
|
333
486
|
method = request.method
|
@@ -343,37 +496,45 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
343
496
|
async def _handle_route_put(self, request: web.Request):
|
344
497
|
data = await request.json()
|
345
498
|
role = data["role"]
|
346
|
-
|
347
|
-
|
499
|
+
attn_tp_size = data["attn_tp_size"]
|
500
|
+
attn_tp_rank = data["attn_tp_rank"]
|
501
|
+
attn_dp_size = data["attn_dp_size"]
|
502
|
+
attn_dp_rank = data["attn_dp_rank"]
|
503
|
+
pp_size = data["pp_size"]
|
504
|
+
pp_rank = data["pp_rank"]
|
505
|
+
system_dp_size = data["system_dp_size"]
|
506
|
+
system_dp_rank = data["system_dp_rank"]
|
348
507
|
rank_ip = data["rank_ip"]
|
349
508
|
rank_port = int(data["rank_port"])
|
350
|
-
engine_rank = int(data["engine_rank"])
|
351
509
|
|
352
|
-
if self.
|
353
|
-
self.
|
510
|
+
if self.attn_tp_size is None:
|
511
|
+
self.attn_tp_size = attn_tp_size
|
354
512
|
|
355
513
|
if self.dp_size is None:
|
356
|
-
self.dp_size =
|
514
|
+
self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
|
357
515
|
|
358
|
-
|
359
|
-
|
360
|
-
self.tp_size_per_dp_rank = tp_size_per_dp_rank
|
516
|
+
if self.pp_size is None:
|
517
|
+
self.pp_size = pp_size
|
361
518
|
|
362
|
-
# Add lock to make sure thread-safe
|
363
519
|
if role == "Prefill":
|
364
|
-
|
365
|
-
|
520
|
+
if system_dp_size == 1:
|
521
|
+
dp_group = attn_dp_rank
|
522
|
+
else:
|
523
|
+
dp_group = system_dp_rank
|
366
524
|
|
525
|
+
# Add lock to make sure thread-safe
|
367
526
|
async with self.lock:
|
368
527
|
if dp_group not in self.prefill_port_table:
|
369
528
|
self.prefill_port_table[dp_group] = {}
|
529
|
+
if attn_tp_rank not in self.prefill_port_table[dp_group]:
|
530
|
+
self.prefill_port_table[dp_group][attn_tp_rank] = {}
|
370
531
|
|
371
|
-
self.prefill_port_table[dp_group][
|
532
|
+
self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
|
372
533
|
"rank_ip": rank_ip,
|
373
534
|
"rank_port": rank_port,
|
374
535
|
}
|
375
536
|
logger.debug(
|
376
|
-
f"Register
|
537
|
+
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
377
538
|
)
|
378
539
|
|
379
540
|
return web.Response(text="OK", status=200)
|
@@ -381,14 +542,20 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
381
542
|
async def _handle_route_get(self, request: web.Request):
|
382
543
|
engine_rank = request.query.get("engine_rank")
|
383
544
|
target_dp_group = request.query.get("target_dp_group")
|
384
|
-
|
545
|
+
target_pp_rank = request.query.get("target_pp_rank")
|
546
|
+
if not engine_rank or not target_dp_group or not target_pp_rank:
|
385
547
|
return web.Response(text="Missing inputs for bootstrap server.", status=400)
|
386
548
|
|
387
549
|
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
388
|
-
if
|
550
|
+
if (
|
551
|
+
int(engine_rank) == -1
|
552
|
+
and int(target_dp_group) == -1
|
553
|
+
and int(target_pp_rank) == -1
|
554
|
+
):
|
389
555
|
prefill_parallel_info = {
|
390
|
-
"
|
556
|
+
"prefill_attn_tp_size": self.attn_tp_size,
|
391
557
|
"prefill_dp_size": self.dp_size,
|
558
|
+
"prefill_pp_size": self.pp_size,
|
392
559
|
}
|
393
560
|
return web.json_response(prefill_parallel_info, status=200)
|
394
561
|
|
@@ -396,7 +563,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
396
563
|
async with self.lock:
|
397
564
|
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
398
565
|
int(engine_rank)
|
399
|
-
]
|
566
|
+
][int(target_pp_rank)]
|
400
567
|
|
401
568
|
if bootstrap_info is not None:
|
402
569
|
return web.json_response(bootstrap_info, status=200)
|
@@ -409,10 +576,14 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
409
576
|
self._loop = asyncio.new_event_loop()
|
410
577
|
asyncio.set_event_loop(self._loop)
|
411
578
|
|
412
|
-
|
579
|
+
access_log = None
|
580
|
+
if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
|
581
|
+
access_log = self.app.logger
|
582
|
+
|
583
|
+
self._runner = web.AppRunner(self.app, access_log=access_log)
|
413
584
|
self._loop.run_until_complete(self._runner.setup())
|
414
585
|
|
415
|
-
site = web.TCPSite(self._runner, port=self.port)
|
586
|
+
site = web.TCPSite(self._runner, host=self.host, port=self.port)
|
416
587
|
self._loop.run_until_complete(site.start())
|
417
588
|
self._loop.run_forever()
|
418
589
|
except Exception as e:
|