sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -22,12 +22,18 @@ from sglang.srt.disaggregation.base.conn import (
|
|
|
22
22
|
KVPoll,
|
|
23
23
|
)
|
|
24
24
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
|
25
|
+
from sglang.srt.distributed import get_pp_group
|
|
26
|
+
from sglang.srt.layers.dp_attention import (
|
|
27
|
+
get_attention_dp_rank,
|
|
28
|
+
get_attention_dp_size,
|
|
29
|
+
get_attention_tp_rank,
|
|
30
|
+
get_attention_tp_size,
|
|
31
|
+
)
|
|
25
32
|
from sglang.srt.server_args import ServerArgs
|
|
26
33
|
from sglang.srt.utils import (
|
|
27
34
|
format_tcp_address,
|
|
28
35
|
get_free_port,
|
|
29
|
-
|
|
30
|
-
get_local_ip_by_remote,
|
|
36
|
+
get_local_ip_auto,
|
|
31
37
|
is_valid_ipv6_address,
|
|
32
38
|
maybe_wrap_ipv6_address,
|
|
33
39
|
)
|
|
@@ -50,30 +56,49 @@ class CommonKVManager(BaseKVManager):
|
|
|
50
56
|
self.bootstrap_host = server_args.host
|
|
51
57
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
|
52
58
|
self.dist_init_addr = server_args.dist_init_addr
|
|
53
|
-
self.
|
|
54
|
-
self.
|
|
55
|
-
self.
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
59
|
+
self.attn_tp_size = get_attention_tp_size()
|
|
60
|
+
self.attn_tp_rank = get_attention_tp_rank()
|
|
61
|
+
self.attn_dp_size = get_attention_dp_size()
|
|
62
|
+
self.attn_dp_rank = get_attention_dp_rank()
|
|
63
|
+
self.system_dp_size = (
|
|
64
|
+
1 if server_args.enable_dp_attention else server_args.dp_size
|
|
65
|
+
)
|
|
66
|
+
self.system_dp_rank = (
|
|
67
|
+
self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
|
|
68
|
+
)
|
|
69
|
+
self.pp_size = server_args.pp_size
|
|
70
|
+
self.pp_rank = self.kv_args.pp_rank
|
|
61
71
|
self.rank_port = get_free_port()
|
|
72
|
+
self.local_ip = get_local_ip_auto()
|
|
73
|
+
self.server_socket = zmq.Context().socket(zmq.PULL)
|
|
74
|
+
if is_valid_ipv6_address(self.local_ip):
|
|
75
|
+
self.server_socket.setsockopt(zmq.IPV6, 1)
|
|
76
|
+
self.request_status: Dict[int, KVPoll] = {}
|
|
77
|
+
|
|
62
78
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
|
63
79
|
self._register_to_bootstrap()
|
|
80
|
+
self.transfer_infos = {}
|
|
81
|
+
self.decode_kv_args_table = {}
|
|
82
|
+
self.pp_group = get_pp_group()
|
|
64
83
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
|
65
84
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
|
66
|
-
self.
|
|
85
|
+
self.connection_lock = threading.Lock()
|
|
86
|
+
self.required_prefill_response_num_table: Dict[int, int] = {}
|
|
87
|
+
self.prefill_attn_tp_size_table: Dict[str, int] = {}
|
|
67
88
|
self.prefill_dp_size_table: Dict[str, int] = {}
|
|
89
|
+
self.prefill_pp_size_table: Dict[str, int] = {}
|
|
68
90
|
else:
|
|
69
91
|
raise ValueError(
|
|
70
92
|
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
|
71
93
|
)
|
|
72
94
|
|
|
95
|
+
def _bind_server_socket(self):
|
|
96
|
+
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
|
97
|
+
|
|
73
98
|
def _register_to_bootstrap(self):
|
|
74
99
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
|
75
100
|
if self.dist_init_addr:
|
|
76
|
-
#
|
|
101
|
+
# Multi-node case: bootstrap server's host is dist_init_addr
|
|
77
102
|
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
|
78
103
|
if self.dist_init_addr.endswith("]"):
|
|
79
104
|
host = self.dist_init_addr
|
|
@@ -82,7 +107,7 @@ class CommonKVManager(BaseKVManager):
|
|
|
82
107
|
else:
|
|
83
108
|
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
|
84
109
|
else:
|
|
85
|
-
#
|
|
110
|
+
# Single-node case: bootstrap server's host is the same as http server's host
|
|
86
111
|
host = self.bootstrap_host
|
|
87
112
|
host = maybe_wrap_ipv6_address(host)
|
|
88
113
|
|
|
@@ -90,23 +115,30 @@ class CommonKVManager(BaseKVManager):
|
|
|
90
115
|
url = f"http://{bootstrap_server_url}/route"
|
|
91
116
|
payload = {
|
|
92
117
|
"role": "Prefill",
|
|
93
|
-
"
|
|
94
|
-
"
|
|
95
|
-
"
|
|
118
|
+
"attn_tp_size": self.attn_tp_size,
|
|
119
|
+
"attn_tp_rank": self.attn_tp_rank,
|
|
120
|
+
"attn_dp_size": self.attn_dp_size,
|
|
121
|
+
"attn_dp_rank": self.attn_dp_rank,
|
|
122
|
+
"pp_size": self.pp_size,
|
|
123
|
+
"pp_rank": self.pp_rank,
|
|
124
|
+
"system_dp_size": self.system_dp_size,
|
|
125
|
+
"system_dp_rank": self.system_dp_rank,
|
|
126
|
+
"rank_ip": self.local_ip,
|
|
96
127
|
"rank_port": self.rank_port,
|
|
97
|
-
"engine_rank": self.kv_args.engine_rank,
|
|
98
128
|
}
|
|
99
129
|
|
|
100
130
|
try:
|
|
101
|
-
response = requests.put(url, json=payload)
|
|
131
|
+
response = requests.put(url, json=payload, timeout=5)
|
|
102
132
|
if response.status_code == 200:
|
|
103
133
|
logger.debug("Prefill successfully registered to bootstrap server.")
|
|
104
134
|
else:
|
|
105
135
|
logger.error(
|
|
106
|
-
f"Prefill
|
|
136
|
+
f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
|
|
107
137
|
)
|
|
108
138
|
except Exception as e:
|
|
109
|
-
logger.error(
|
|
139
|
+
logger.error(
|
|
140
|
+
f"Prefill instance failed to register to bootstrap server: {e}"
|
|
141
|
+
)
|
|
110
142
|
|
|
111
143
|
@cache
|
|
112
144
|
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
|
@@ -116,6 +148,69 @@ class CommonKVManager(BaseKVManager):
|
|
|
116
148
|
socket.connect(endpoint)
|
|
117
149
|
return socket
|
|
118
150
|
|
|
151
|
+
def get_mha_kv_ptrs_with_pp(
|
|
152
|
+
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
|
|
153
|
+
) -> Tuple[List[int], List[int], List[int], List[int], int]:
|
|
154
|
+
# pp is not supported on the decode side yet
|
|
155
|
+
start_layer = self.kv_args.prefill_start_layer
|
|
156
|
+
num_kv_layers = len(src_kv_ptrs) // 2
|
|
157
|
+
end_layer = start_layer + num_kv_layers
|
|
158
|
+
dst_num_total_layers = len(dst_kv_ptrs) // 2
|
|
159
|
+
src_k_ptrs = src_kv_ptrs[:num_kv_layers]
|
|
160
|
+
src_v_ptrs = src_kv_ptrs[num_kv_layers:]
|
|
161
|
+
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
|
162
|
+
dst_v_ptrs = dst_kv_ptrs[
|
|
163
|
+
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
|
164
|
+
]
|
|
165
|
+
layers_current_pp_stage = len(src_k_ptrs)
|
|
166
|
+
return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
|
|
167
|
+
|
|
168
|
+
def get_mla_kv_ptrs_with_pp(
|
|
169
|
+
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
|
|
170
|
+
) -> Tuple[List[int], List[int], int]:
|
|
171
|
+
# pp is not supported on the decode side yet
|
|
172
|
+
start_layer = self.kv_args.prefill_start_layer
|
|
173
|
+
end_layer = start_layer + len(src_kv_ptrs)
|
|
174
|
+
sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
|
175
|
+
layers_current_pp_stage = len(src_kv_ptrs)
|
|
176
|
+
return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class CommonKVSender(BaseKVSender):
|
|
180
|
+
|
|
181
|
+
def __init__(
|
|
182
|
+
self,
|
|
183
|
+
mgr: BaseKVManager,
|
|
184
|
+
bootstrap_addr: str,
|
|
185
|
+
bootstrap_room: int,
|
|
186
|
+
dest_tp_ranks: List[int],
|
|
187
|
+
pp_rank: int,
|
|
188
|
+
):
|
|
189
|
+
self.kv_mgr = mgr
|
|
190
|
+
self.bootstrap_room = bootstrap_room
|
|
191
|
+
self.aux_index = None
|
|
192
|
+
self.bootstrap_server_url = bootstrap_addr
|
|
193
|
+
# inner state
|
|
194
|
+
self.curr_idx = 0
|
|
195
|
+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
|
196
|
+
|
|
197
|
+
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
|
198
|
+
self.num_kv_indices = num_kv_indices
|
|
199
|
+
self.aux_index = aux_index
|
|
200
|
+
|
|
201
|
+
def send(
|
|
202
|
+
self,
|
|
203
|
+
kv_indices: npt.NDArray[np.int32],
|
|
204
|
+
state_indices: Optional[List[int]] = None,
|
|
205
|
+
):
|
|
206
|
+
pass
|
|
207
|
+
|
|
208
|
+
def poll(self) -> KVPoll:
|
|
209
|
+
pass
|
|
210
|
+
|
|
211
|
+
def failure_exception(self):
|
|
212
|
+
raise Exception("Fake KVReceiver Exception")
|
|
213
|
+
|
|
119
214
|
|
|
120
215
|
class CommonKVReceiver(BaseKVReceiver):
|
|
121
216
|
_ctx = zmq.Context()
|
|
@@ -133,61 +228,89 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
|
133
228
|
self.bootstrap_room = bootstrap_room
|
|
134
229
|
self.bootstrap_addr = bootstrap_addr
|
|
135
230
|
self.kv_mgr = mgr
|
|
231
|
+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
|
136
232
|
|
|
137
233
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
|
138
|
-
|
|
139
|
-
self.
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
234
|
+
(
|
|
235
|
+
self.prefill_attn_tp_size,
|
|
236
|
+
self.prefill_dp_size,
|
|
237
|
+
self.prefill_pp_size,
|
|
238
|
+
) = self._get_prefill_parallel_info_from_server()
|
|
239
|
+
if (
|
|
240
|
+
self.prefill_attn_tp_size is None
|
|
241
|
+
or self.prefill_dp_size is None
|
|
242
|
+
or self.prefill_pp_size is None
|
|
243
|
+
):
|
|
244
|
+
self.kv_mgr.record_failure(
|
|
245
|
+
self.bootstrap_room,
|
|
246
|
+
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
|
|
144
247
|
)
|
|
248
|
+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
|
|
249
|
+
self.bootstrap_infos = None
|
|
250
|
+
return
|
|
145
251
|
else:
|
|
146
|
-
|
|
147
|
-
self.
|
|
252
|
+
logger.debug(
|
|
253
|
+
f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
|
|
254
|
+
)
|
|
255
|
+
self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
|
|
256
|
+
self.prefill_attn_tp_size
|
|
148
257
|
)
|
|
149
258
|
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
|
150
259
|
self.prefill_dp_size
|
|
151
260
|
)
|
|
261
|
+
self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
|
|
262
|
+
self.prefill_pp_size
|
|
263
|
+
)
|
|
152
264
|
else:
|
|
153
|
-
self.
|
|
265
|
+
self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
|
|
154
266
|
self.bootstrap_addr
|
|
155
267
|
]
|
|
156
268
|
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
|
157
269
|
self.bootstrap_addr
|
|
158
270
|
]
|
|
271
|
+
self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
|
|
272
|
+
self.bootstrap_addr
|
|
273
|
+
]
|
|
159
274
|
|
|
160
275
|
# Currently, we don't allow prefill instance and decode instance to
|
|
161
276
|
# have different TP sizes per DP rank, except for models using MLA.
|
|
162
|
-
|
|
163
|
-
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
|
|
164
|
-
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
|
|
277
|
+
if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
|
|
165
278
|
self.target_tp_rank = (
|
|
166
|
-
self.kv_mgr.kv_args.engine_rank %
|
|
279
|
+
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
|
167
280
|
)
|
|
168
281
|
self.required_dst_info_num = 1
|
|
282
|
+
self.required_prefill_response_num = 1 * (
|
|
283
|
+
self.prefill_pp_size // self.kv_mgr.pp_size
|
|
284
|
+
)
|
|
169
285
|
self.target_tp_ranks = [self.target_tp_rank]
|
|
170
|
-
elif
|
|
286
|
+
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
|
|
287
|
+
if not self.kv_mgr.is_mla_backend:
|
|
288
|
+
logger.warning_once(
|
|
289
|
+
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
|
|
290
|
+
)
|
|
171
291
|
self.target_tp_rank = (
|
|
172
|
-
self.kv_mgr.kv_args.engine_rank %
|
|
173
|
-
) // (
|
|
292
|
+
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
|
293
|
+
) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
|
|
174
294
|
self.required_dst_info_num = (
|
|
175
|
-
|
|
295
|
+
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
|
|
296
|
+
)
|
|
297
|
+
self.required_prefill_response_num = 1 * (
|
|
298
|
+
self.prefill_pp_size // self.kv_mgr.pp_size
|
|
176
299
|
)
|
|
177
300
|
self.target_tp_ranks = [self.target_tp_rank]
|
|
178
301
|
else:
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
302
|
+
if not self.kv_mgr.is_mla_backend:
|
|
303
|
+
logger.warning_once(
|
|
304
|
+
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
|
|
305
|
+
)
|
|
183
306
|
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
|
|
184
307
|
self.target_tp_ranks = [
|
|
185
308
|
rank
|
|
186
309
|
for rank in range(
|
|
187
|
-
(self.kv_mgr.kv_args.engine_rank %
|
|
188
|
-
* (
|
|
189
|
-
(self.kv_mgr.kv_args.engine_rank %
|
|
190
|
-
* (
|
|
310
|
+
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
|
|
311
|
+
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
|
|
312
|
+
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
|
|
313
|
+
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
|
|
191
314
|
)
|
|
192
315
|
]
|
|
193
316
|
|
|
@@ -196,6 +319,14 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
|
196
319
|
# or the KVPoll will never be set correctly
|
|
197
320
|
self.target_tp_rank = self.target_tp_ranks[0]
|
|
198
321
|
self.required_dst_info_num = 1
|
|
322
|
+
if self.kv_mgr.is_mla_backend:
|
|
323
|
+
self.required_prefill_response_num = (
|
|
324
|
+
self.prefill_pp_size // self.kv_mgr.pp_size
|
|
325
|
+
)
|
|
326
|
+
else:
|
|
327
|
+
self.required_prefill_response_num = (
|
|
328
|
+
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
|
329
|
+
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
|
199
330
|
|
|
200
331
|
if prefill_dp_rank is not None:
|
|
201
332
|
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
|
@@ -206,6 +337,9 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
|
206
337
|
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
|
207
338
|
self.target_dp_group = self.prefill_dp_rank
|
|
208
339
|
|
|
340
|
+
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
|
341
|
+
self.required_prefill_response_num
|
|
342
|
+
)
|
|
209
343
|
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
|
210
344
|
bootstrap_key = (
|
|
211
345
|
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
|
|
@@ -214,41 +348,49 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
|
214
348
|
if bootstrap_key not in self.kv_mgr.connection_pool:
|
|
215
349
|
bootstrap_infos = []
|
|
216
350
|
for target_tp_rank in self.target_tp_ranks:
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
)
|
|
221
|
-
if bootstrap_info is not None:
|
|
222
|
-
# NOTE: only support MLA for now: select one prefill rank as real rank
|
|
223
|
-
bootstrap_info["is_dummy"] = not bool(
|
|
224
|
-
target_tp_rank == self.target_tp_rank
|
|
225
|
-
or self.target_tp_rank is None
|
|
226
|
-
)
|
|
227
|
-
bootstrap_infos.append(bootstrap_info)
|
|
228
|
-
else:
|
|
229
|
-
logger.error(
|
|
230
|
-
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
|
|
351
|
+
for target_pp_rank in range(self.prefill_pp_size):
|
|
352
|
+
bootstrap_info = self._get_bootstrap_info_from_server(
|
|
353
|
+
target_tp_rank, self.target_dp_group, target_pp_rank
|
|
231
354
|
)
|
|
355
|
+
if bootstrap_info is not None:
|
|
356
|
+
if self.kv_mgr.is_mla_backend:
|
|
357
|
+
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
|
|
358
|
+
bootstrap_info["is_dummy"] = not bool(
|
|
359
|
+
target_tp_rank == self.target_tp_rank
|
|
360
|
+
or self.target_tp_rank is None
|
|
361
|
+
)
|
|
362
|
+
else:
|
|
363
|
+
# For non-MLA: all target_tp_ranks are selected real ranks
|
|
364
|
+
bootstrap_info["is_dummy"] = False
|
|
365
|
+
logger.debug(
|
|
366
|
+
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
|
|
367
|
+
)
|
|
368
|
+
bootstrap_infos.append(bootstrap_info)
|
|
369
|
+
else:
|
|
370
|
+
self.kv_mgr.record_failure(
|
|
371
|
+
self.bootstrap_room,
|
|
372
|
+
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
|
|
373
|
+
)
|
|
374
|
+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
|
|
375
|
+
return
|
|
376
|
+
|
|
232
377
|
self.bootstrap_infos = bootstrap_infos
|
|
378
|
+
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
|
233
379
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
|
237
|
-
)
|
|
238
|
-
else:
|
|
239
|
-
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
|
240
|
-
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
|
241
|
-
self._register_kv_args()
|
|
380
|
+
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
|
381
|
+
self._register_kv_args()
|
|
242
382
|
else:
|
|
243
383
|
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
|
|
244
384
|
|
|
245
385
|
assert len(self.bootstrap_infos) > 0
|
|
246
386
|
|
|
247
|
-
def _get_bootstrap_info_from_server(
|
|
387
|
+
def _get_bootstrap_info_from_server(
|
|
388
|
+
self, engine_rank, target_dp_group, target_pp_rank
|
|
389
|
+
):
|
|
248
390
|
"""Fetch the bootstrap info from the bootstrap server."""
|
|
249
391
|
try:
|
|
250
|
-
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
|
|
251
|
-
response = requests.get(url)
|
|
392
|
+
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
|
|
393
|
+
response = requests.get(url, timeout=5)
|
|
252
394
|
if response.status_code == 200:
|
|
253
395
|
bootstrap_info = response.json()
|
|
254
396
|
return bootstrap_info
|
|
@@ -261,24 +403,28 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
|
261
403
|
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
|
262
404
|
return None
|
|
263
405
|
|
|
264
|
-
def
|
|
406
|
+
def _get_prefill_parallel_info_from_server(
|
|
407
|
+
self,
|
|
408
|
+
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
|
|
265
409
|
"""Fetch the prefill parallel info from the bootstrap server."""
|
|
266
410
|
try:
|
|
267
|
-
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
|
|
411
|
+
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
|
|
268
412
|
response = requests.get(url)
|
|
269
413
|
if response.status_code == 200:
|
|
270
414
|
prefill_parallel_info = response.json()
|
|
271
|
-
return
|
|
272
|
-
prefill_parallel_info["
|
|
415
|
+
return (
|
|
416
|
+
int(prefill_parallel_info["prefill_attn_tp_size"]),
|
|
417
|
+
int(prefill_parallel_info["prefill_dp_size"]),
|
|
418
|
+
int(prefill_parallel_info["prefill_pp_size"]),
|
|
273
419
|
)
|
|
274
420
|
else:
|
|
275
421
|
logger.error(
|
|
276
422
|
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
|
|
277
423
|
)
|
|
278
|
-
return None
|
|
424
|
+
return None, None, None
|
|
279
425
|
except Exception as e:
|
|
280
426
|
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
|
|
281
|
-
return None
|
|
427
|
+
return None, None, None
|
|
282
428
|
|
|
283
429
|
@classmethod
|
|
284
430
|
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
|
@@ -317,10 +463,12 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
|
317
463
|
self.store = dict()
|
|
318
464
|
self.lock = asyncio.Lock()
|
|
319
465
|
self._setup_routes()
|
|
320
|
-
self.
|
|
466
|
+
self.pp_size = None
|
|
467
|
+
self.attn_tp_size = None
|
|
321
468
|
self.dp_size = None
|
|
322
|
-
self.
|
|
323
|
-
|
|
469
|
+
self.prefill_port_table: Dict[
|
|
470
|
+
int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
|
|
471
|
+
] = {}
|
|
324
472
|
|
|
325
473
|
# Start bootstrap server
|
|
326
474
|
self.thread = threading.Thread(target=self._run_server, daemon=True)
|
|
@@ -331,6 +479,10 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
|
331
479
|
|
|
332
480
|
def _setup_routes(self):
|
|
333
481
|
self.app.router.add_route("*", "/route", self._handle_route)
|
|
482
|
+
self.app.router.add_get("/health", self._handle_health_check)
|
|
483
|
+
|
|
484
|
+
async def _handle_health_check(self, request):
|
|
485
|
+
return web.Response(text="OK", status=200)
|
|
334
486
|
|
|
335
487
|
async def _handle_route(self, request: web.Request):
|
|
336
488
|
method = request.method
|
|
@@ -346,37 +498,45 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
|
346
498
|
async def _handle_route_put(self, request: web.Request):
|
|
347
499
|
data = await request.json()
|
|
348
500
|
role = data["role"]
|
|
349
|
-
|
|
350
|
-
|
|
501
|
+
attn_tp_size = data["attn_tp_size"]
|
|
502
|
+
attn_tp_rank = data["attn_tp_rank"]
|
|
503
|
+
attn_dp_size = data["attn_dp_size"]
|
|
504
|
+
attn_dp_rank = data["attn_dp_rank"]
|
|
505
|
+
pp_size = data["pp_size"]
|
|
506
|
+
pp_rank = data["pp_rank"]
|
|
507
|
+
system_dp_size = data["system_dp_size"]
|
|
508
|
+
system_dp_rank = data["system_dp_rank"]
|
|
351
509
|
rank_ip = data["rank_ip"]
|
|
352
510
|
rank_port = int(data["rank_port"])
|
|
353
|
-
engine_rank = int(data["engine_rank"])
|
|
354
511
|
|
|
355
|
-
if self.
|
|
356
|
-
self.
|
|
512
|
+
if self.attn_tp_size is None:
|
|
513
|
+
self.attn_tp_size = attn_tp_size
|
|
357
514
|
|
|
358
515
|
if self.dp_size is None:
|
|
359
|
-
self.dp_size =
|
|
516
|
+
self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
|
|
360
517
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
self.tp_size_per_dp_rank = tp_size_per_dp_rank
|
|
518
|
+
if self.pp_size is None:
|
|
519
|
+
self.pp_size = pp_size
|
|
364
520
|
|
|
365
|
-
# Add lock to make sure thread-safe
|
|
366
521
|
if role == "Prefill":
|
|
367
|
-
|
|
368
|
-
|
|
522
|
+
if system_dp_size == 1:
|
|
523
|
+
dp_group = attn_dp_rank
|
|
524
|
+
else:
|
|
525
|
+
dp_group = system_dp_rank
|
|
369
526
|
|
|
527
|
+
# Add lock to make sure thread-safe
|
|
370
528
|
async with self.lock:
|
|
371
529
|
if dp_group not in self.prefill_port_table:
|
|
372
530
|
self.prefill_port_table[dp_group] = {}
|
|
531
|
+
if attn_tp_rank not in self.prefill_port_table[dp_group]:
|
|
532
|
+
self.prefill_port_table[dp_group][attn_tp_rank] = {}
|
|
373
533
|
|
|
374
|
-
self.prefill_port_table[dp_group][
|
|
534
|
+
self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
|
|
375
535
|
"rank_ip": rank_ip,
|
|
376
536
|
"rank_port": rank_port,
|
|
377
537
|
}
|
|
378
538
|
logger.debug(
|
|
379
|
-
f"Register
|
|
539
|
+
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
|
380
540
|
)
|
|
381
541
|
|
|
382
542
|
return web.Response(text="OK", status=200)
|
|
@@ -384,14 +544,20 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
|
384
544
|
async def _handle_route_get(self, request: web.Request):
|
|
385
545
|
engine_rank = request.query.get("engine_rank")
|
|
386
546
|
target_dp_group = request.query.get("target_dp_group")
|
|
387
|
-
|
|
547
|
+
target_pp_rank = request.query.get("target_pp_rank")
|
|
548
|
+
if not engine_rank or not target_dp_group or not target_pp_rank:
|
|
388
549
|
return web.Response(text="Missing inputs for bootstrap server.", status=400)
|
|
389
550
|
|
|
390
551
|
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
|
391
|
-
if
|
|
552
|
+
if (
|
|
553
|
+
int(engine_rank) == -1
|
|
554
|
+
and int(target_dp_group) == -1
|
|
555
|
+
and int(target_pp_rank) == -1
|
|
556
|
+
):
|
|
392
557
|
prefill_parallel_info = {
|
|
393
|
-
"
|
|
558
|
+
"prefill_attn_tp_size": self.attn_tp_size,
|
|
394
559
|
"prefill_dp_size": self.dp_size,
|
|
560
|
+
"prefill_pp_size": self.pp_size,
|
|
395
561
|
}
|
|
396
562
|
return web.json_response(prefill_parallel_info, status=200)
|
|
397
563
|
|
|
@@ -399,7 +565,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
|
399
565
|
async with self.lock:
|
|
400
566
|
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
|
401
567
|
int(engine_rank)
|
|
402
|
-
]
|
|
568
|
+
][int(target_pp_rank)]
|
|
403
569
|
|
|
404
570
|
if bootstrap_info is not None:
|
|
405
571
|
return web.json_response(bootstrap_info, status=200)
|
|
@@ -412,7 +578,11 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
|
412
578
|
self._loop = asyncio.new_event_loop()
|
|
413
579
|
asyncio.set_event_loop(self._loop)
|
|
414
580
|
|
|
415
|
-
|
|
581
|
+
access_log = None
|
|
582
|
+
if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
|
|
583
|
+
access_log = self.app.logger
|
|
584
|
+
|
|
585
|
+
self._runner = web.AppRunner(self.app, access_log=access_log)
|
|
416
586
|
self._loop.run_until_complete(self._runner.setup())
|
|
417
587
|
|
|
418
588
|
site = web.TCPSite(self._runner, host=self.host, port=self.port)
|