sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/common/conn.py

@@ -22,12 +22,18 @@ from sglang.srt.disaggregation.base.conn import (
     KVPoll,
 )
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,
     get_free_port,
-    …
-    get_local_ip_by_remote,
+    get_local_ip_auto,
     is_valid_ipv6_address,
     maybe_wrap_ipv6_address,
 )
@@ -50,30 +56,49 @@ class CommonKVManager(BaseKVManager):
         self.bootstrap_host = server_args.host
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
-        self.…
-        self.…
-        self.…
-        …
-        …
-        …
-        …
-        …
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_dp_size = get_attention_dp_size()
+        self.attn_dp_rank = get_attention_dp_rank()
+        self.system_dp_size = (
+            1 if server_args.enable_dp_attention else server_args.dp_size
+        )
+        self.system_dp_rank = (
+            self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
+        )
+        self.pp_size = server_args.pp_size
+        self.pp_rank = self.kv_args.pp_rank
         self.rank_port = get_free_port()
+        self.local_ip = get_local_ip_auto()
+        self.server_socket = zmq.Context().socket(zmq.PULL)
+        if is_valid_ipv6_address(self.local_ip):
+            self.server_socket.setsockopt(zmq.IPV6, 1)
+        self.request_status: Dict[int, KVPoll] = {}
+
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self._register_to_bootstrap()
+            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
+            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
+            self.pp_group = get_pp_group()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
-            self.…
+            self.connection_lock = threading.Lock()
+            self.required_prefill_response_num_table: Dict[int, int] = {}
+            self.prefill_attn_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
+            self.prefill_pp_size_table: Dict[str, int] = {}
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
             )

+    def _bind_server_socket(self):
+        self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
+
     def _register_to_bootstrap(self):
         """Register KVSender to bootstrap server via HTTP POST."""
         if self.dist_init_addr:
-            # …
+            # Multi-node case: bootstrap server's host is dist_init_addr
             if self.dist_init_addr.startswith("["):  # [ipv6]:port or [ipv6]
                 if self.dist_init_addr.endswith("]"):
                     host = self.dist_init_addr
@@ -82,7 +107,7 @@ class CommonKVManager(BaseKVManager):
             else:
                 host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
         else:
-            # …
+            # Single-node case: bootstrap server's host is the same as http server's host
             host = self.bootstrap_host
         host = maybe_wrap_ipv6_address(host)

@@ -90,23 +115,30 @@ class CommonKVManager(BaseKVManager):
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
-            "…
-            "…
-            "…
+            "attn_tp_size": self.attn_tp_size,
+            "attn_tp_rank": self.attn_tp_rank,
+            "attn_dp_size": self.attn_dp_size,
+            "attn_dp_rank": self.attn_dp_rank,
+            "pp_size": self.pp_size,
+            "pp_rank": self.pp_rank,
+            "system_dp_size": self.system_dp_size,
+            "system_dp_rank": self.system_dp_rank,
+            "rank_ip": self.local_ip,
             "rank_port": self.rank_port,
-            "engine_rank": self.kv_args.engine_rank,
         }

         try:
-            response = requests.put(url, json=payload)
+            response = requests.put(url, json=payload, timeout=5)
             if response.status_code == 200:
                 logger.debug("Prefill successfully registered to bootstrap server.")
             else:
                 logger.error(
-                    f"Prefill …
+                    f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
                 )
         except Exception as e:
-            logger.error(…
+            logger.error(
+                f"Prefill instance failed to register to bootstrap server: {e}"
+            )

     @cache
     def _connect(self, endpoint: str, is_ipv6: bool = False):
@@ -116,6 +148,68 @@ class CommonKVManager(BaseKVManager):
             socket.connect(endpoint)
             return socket

+    def get_mha_kv_ptrs_with_pp(
+        self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
+    ) -> Tuple[List[int], List[int], List[int], List[int], int]:
+        # pp is not supported on the decode side yet
+        start_layer = self.kv_args.prefill_start_layer
+        num_kv_layers = len(src_kv_ptrs) // 2
+        end_layer = start_layer + num_kv_layers
+        dst_num_total_layers = len(dst_kv_ptrs) // 2
+        src_k_ptrs = src_kv_ptrs[:num_kv_layers]
+        src_v_ptrs = src_kv_ptrs[num_kv_layers:]
+        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+        dst_v_ptrs = dst_kv_ptrs[
+            dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
+        ]
+        layers_current_pp_stage = len(src_k_ptrs)
+        return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
+
+    def get_mla_kv_ptrs_with_pp(
+        self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
+    ) -> Tuple[List[int], List[int], int]:
+        # pp is not supported on the decode side yet
+        start_layer = self.kv_args.prefill_start_layer
+        end_layer = start_layer + len(src_kv_ptrs)
+        sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
+        layers_current_pp_stage = len(src_kv_ptrs)
+        return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
+
+
+class CommonKVSender(BaseKVSender):
+
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
+    ):
+        self.kv_mgr = mgr
+        self.bootstrap_room = bootstrap_room
+        self.aux_index = None
+        self.bootstrap_server_url = bootstrap_addr
+        # inner state
+        self.curr_idx = 0
+        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
+
+    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
+        self.num_kv_indices = num_kv_indices
+        self.aux_index = aux_index
+
+    def send(
+        self,
+        kv_indices: npt.NDArray[np.int32],
+    ):
+        pass
+
+    def poll(self) -> KVPoll:
+        pass
+
+    def failure_exception(self):
+        raise Exception("Fake KVReceiver Exception")
+

 class CommonKVReceiver(BaseKVReceiver):
     _ctx = zmq.Context()
@@ -133,61 +227,88 @@ class CommonKVReceiver(BaseKVReceiver):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
         self.kv_mgr = mgr
+        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)

         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
-            …
-            self.…
-            …
-            …
-            …
-            …
+            (
+                self.prefill_attn_tp_size,
+                self.prefill_dp_size,
+                self.prefill_pp_size,
+            ) = self._get_prefill_parallel_info_from_server()
+            if (
+                self.prefill_attn_tp_size is None
+                or self.prefill_dp_size is None
+                or self.prefill_pp_size is None
+            ):
+                self.kv_mgr.record_failure(
+                    self.bootstrap_room,
+                    f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
                 )
+                self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+                return
             else:
-                …
-                self.…
+                logger.debug(
+                    f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
+                )
+                self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
+                    self.prefill_attn_tp_size
                 )
                 self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                     self.prefill_dp_size
                 )
+                self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
+                    self.prefill_pp_size
+                )
         else:
-            self.…
+            self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
                 self.bootstrap_addr
             ]
             self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                 self.bootstrap_addr
             ]
+            self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
+                self.bootstrap_addr
+            ]

         # Currently, we don't allow prefill instance and decode instance to
         # have different TP sizes per DP rank, except for models using MLA.
-        …
-        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
-        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
+        if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % …
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
             )
             self.required_dst_info_num = 1
+            self.required_prefill_response_num = 1 * (
+                self.prefill_pp_size // self.kv_mgr.pp_size
+            )
             self.target_tp_ranks = [self.target_tp_rank]
-        elif …
+        elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % …
-            ) // (…
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
+            ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
             self.required_dst_info_num = (
-                …
+                self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
+            )
+            self.required_prefill_response_num = 1 * (
+                self.prefill_pp_size // self.kv_mgr.pp_size
             )
             self.target_tp_ranks = [self.target_tp_rank]
         else:
-            …
-            …
-            …
-            …
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
             self.target_tp_ranks = [
                 rank
                 for rank in range(
-                    (self.kv_mgr.kv_args.engine_rank % …
-                    * (…
-                    (self.kv_mgr.kv_args.engine_rank % …
-                    * (…
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
                 )
             ]

@@ -196,6 +317,14 @@ class CommonKVReceiver(BaseKVReceiver):
             # or the KVPoll will never be set correctly
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
+            if self.kv_mgr.is_mla_backend:
+                self.required_prefill_response_num = (
+                    self.prefill_pp_size // self.kv_mgr.pp_size
+                )
+            else:
+                self.required_prefill_response_num = (
+                    self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
+                ) * (self.prefill_pp_size // self.kv_mgr.pp_size)

         if prefill_dp_rank is not None:
             logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
@@ -206,6 +335,9 @@ class CommonKVReceiver(BaseKVReceiver):
         # FIXME: alias here: target_dp_group -> prefill_dp_rank
         self.target_dp_group = self.prefill_dp_rank

+        self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
+            self.required_prefill_response_num
+        )
         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
             f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
@@ -214,41 +346,49 @@ class CommonKVReceiver(BaseKVReceiver):
         if bootstrap_key not in self.kv_mgr.connection_pool:
             bootstrap_infos = []
             for target_tp_rank in self.target_tp_ranks:
-                …
-                …
-                …
-                )
-                if bootstrap_info is not None:
-                    # NOTE: only support MLA for now: select one prefill rank as real rank
-                    bootstrap_info["is_dummy"] = not bool(
-                        target_tp_rank == self.target_tp_rank
-                        or self.target_tp_rank is None
-                    )
-                    bootstrap_infos.append(bootstrap_info)
-                else:
-                    logger.error(
-                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
+                for target_pp_rank in range(self.prefill_pp_size):
+                    bootstrap_info = self._get_bootstrap_info_from_server(
+                        target_tp_rank, self.target_dp_group, target_pp_rank
                     )
+                    if bootstrap_info is not None:
+                        if self.kv_mgr.is_mla_backend:
+                            # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
+                            bootstrap_info["is_dummy"] = not bool(
+                                target_tp_rank == self.target_tp_rank
+                                or self.target_tp_rank is None
+                            )
+                        else:
+                            # For non-MLA: all target_tp_ranks are selected real ranks
+                            bootstrap_info["is_dummy"] = False
+                        logger.debug(
+                            f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
+                        )
+                        bootstrap_infos.append(bootstrap_info)
+                    else:
+                        self.kv_mgr.record_failure(
+                            self.bootstrap_room,
+                            f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
+                        )
+                        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+                        return
+
             self.bootstrap_infos = bootstrap_infos
+            self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos

-            …
-            …
-                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
-                )
-            else:
-                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
-                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
-                self._register_kv_args()
+            # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
+            self._register_kv_args()
         else:
             self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]

         assert len(self.bootstrap_infos) > 0

-    def _get_bootstrap_info_from_server(…
+    def _get_bootstrap_info_from_server(
+        self, engine_rank, target_dp_group, target_pp_rank
+    ):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
-            response = requests.get(url)
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
+            response = requests.get(url, timeout=5)
             if response.status_code == 200:
                 bootstrap_info = response.json()
                 return bootstrap_info
@@ -261,24 +401,28 @@ class CommonKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

-    def …
+    def _get_prefill_parallel_info_from_server(
+        self,
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
         """Fetch the prefill parallel info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
             response = requests.get(url)
             if response.status_code == 200:
                 prefill_parallel_info = response.json()
-                return …
-                    prefill_parallel_info["…
+                return (
+                    int(prefill_parallel_info["prefill_attn_tp_size"]),
+                    int(prefill_parallel_info["prefill_dp_size"]),
+                    int(prefill_parallel_info["prefill_pp_size"]),
                 )
             else:
                 logger.error(
                     f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                 )
-                return None
+                return None, None, None
         except Exception as e:
             logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
-            return None
+            return None, None, None

     @classmethod
     def _connect(cls, endpoint: str, is_ipv6: bool = False):
@@ -317,10 +461,12 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
-        self.…
+        self.pp_size = None
+        self.attn_tp_size = None
         self.dp_size = None
-        self.…
-        …
+        self.prefill_port_table: Dict[
+            int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
+        ] = {}

         # Start bootstrap server
         self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -331,6 +477,10 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):

     def _setup_routes(self):
         self.app.router.add_route("*", "/route", self._handle_route)
+        self.app.router.add_get("/health", self._handle_health_check)
+
+    async def _handle_health_check(self, request):
+        return web.Response(text="OK", status=200)

     async def _handle_route(self, request: web.Request):
         method = request.method
@@ -346,37 +496,45 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_put(self, request: web.Request):
         data = await request.json()
         role = data["role"]
-        …
-        …
+        attn_tp_size = data["attn_tp_size"]
+        attn_tp_rank = data["attn_tp_rank"]
+        attn_dp_size = data["attn_dp_size"]
+        attn_dp_rank = data["attn_dp_rank"]
+        pp_size = data["pp_size"]
+        pp_rank = data["pp_rank"]
+        system_dp_size = data["system_dp_size"]
+        system_dp_rank = data["system_dp_rank"]
         rank_ip = data["rank_ip"]
         rank_port = int(data["rank_port"])
-        engine_rank = int(data["engine_rank"])

-        if self.…
-            self.…
+        if self.attn_tp_size is None:
+            self.attn_tp_size = attn_tp_size

         if self.dp_size is None:
-            self.dp_size = …
+            self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size

-        …
-        …
-            self.tp_size_per_dp_rank = tp_size_per_dp_rank
+        if self.pp_size is None:
+            self.pp_size = pp_size

-        # Add lock to make sure thread-safe
         if role == "Prefill":
-            …
-            …
+            if system_dp_size == 1:
+                dp_group = attn_dp_rank
+            else:
+                dp_group = system_dp_rank

+            # Add lock to make sure thread-safe
             async with self.lock:
                 if dp_group not in self.prefill_port_table:
                     self.prefill_port_table[dp_group] = {}
+                if attn_tp_rank not in self.prefill_port_table[dp_group]:
+                    self.prefill_port_table[dp_group][attn_tp_rank] = {}

-                self.prefill_port_table[dp_group][…
+                self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
                     "rank_ip": rank_ip,
                     "rank_port": rank_port,
                 }
             logger.debug(
-                f"Register …
+                f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
             )

         return web.Response(text="OK", status=200)
@@ -384,14 +542,20 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         target_dp_group = request.query.get("target_dp_group")
-        …
+        target_pp_rank = request.query.get("target_pp_rank")
+        if not engine_rank or not target_dp_group or not target_pp_rank:
             return web.Response(text="Missing inputs for bootstrap server.", status=400)

         # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
-        if …
+        if (
+            int(engine_rank) == -1
+            and int(target_dp_group) == -1
+            and int(target_pp_rank) == -1
+        ):
             prefill_parallel_info = {
-                "…
+                "prefill_attn_tp_size": self.attn_tp_size,
                 "prefill_dp_size": self.dp_size,
+                "prefill_pp_size": self.pp_size,
             }
             return web.json_response(prefill_parallel_info, status=200)

@@ -399,7 +563,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
         async with self.lock:
             bootstrap_info = self.prefill_port_table[int(target_dp_group)][
                 int(engine_rank)
-            ]
+            ][int(target_pp_rank)]

         if bootstrap_info is not None:
             return web.json_response(bootstrap_info, status=200)
@@ -412,7 +576,11 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
         self._loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self._loop)

-        …
+        access_log = None
+        if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
+            access_log = self.app.logger
+
+        self._runner = web.AppRunner(self.app, access_log=access_log)
         self._loop.run_until_complete(self._runner.setup())

         site = web.TCPSite(self._runner, host=self.host, port=self.port)