sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,27 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import asyncio
|
4
3
|
import concurrent.futures
|
5
4
|
import ctypes
|
6
5
|
import dataclasses
|
7
6
|
import logging
|
8
7
|
import os
|
9
|
-
import queue
|
10
|
-
import socket
|
11
8
|
import struct
|
12
9
|
import threading
|
13
10
|
import time
|
14
11
|
from collections import defaultdict
|
15
|
-
from
|
16
|
-
from typing import Dict, List, Optional, Tuple, Union
|
12
|
+
from typing import Dict, List, Optional, Tuple
|
17
13
|
|
18
14
|
import numpy as np
|
19
15
|
import numpy.typing as npt
|
20
16
|
import requests
|
21
17
|
import zmq
|
22
|
-
|
23
|
-
|
24
|
-
from sglang.srt.disaggregation.
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
KVArgs,
|
30
|
-
KVPoll,
|
18
|
+
|
19
|
+
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
|
20
|
+
from sglang.srt.disaggregation.common.conn import (
|
21
|
+
CommonKVBootstrapServer,
|
22
|
+
CommonKVManager,
|
23
|
+
CommonKVReceiver,
|
24
|
+
CommonKVSender,
|
31
25
|
)
|
32
26
|
from sglang.srt.disaggregation.common.utils import (
|
33
27
|
FastQueue,
|
@@ -35,23 +29,12 @@ from sglang.srt.disaggregation.common.utils import (
|
|
35
29
|
)
|
36
30
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
37
31
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
38
|
-
from sglang.srt.distributed import get_pp_group
|
39
|
-
from sglang.srt.layers.dp_attention import (
|
40
|
-
get_attention_dp_rank,
|
41
|
-
get_attention_dp_size,
|
42
|
-
get_attention_tp_rank,
|
43
|
-
get_attention_tp_size,
|
44
|
-
)
|
45
32
|
from sglang.srt.server_args import ServerArgs
|
46
33
|
from sglang.srt.utils import (
|
47
34
|
format_tcp_address,
|
48
35
|
get_bool_env_var,
|
49
|
-
get_free_port,
|
50
36
|
get_int_env_var,
|
51
|
-
get_ip,
|
52
|
-
get_local_ip_auto,
|
53
37
|
is_valid_ipv6_address,
|
54
|
-
maybe_wrap_ipv6_address,
|
55
38
|
)
|
56
39
|
|
57
40
|
logger = logging.getLogger(__name__)
|
@@ -159,7 +142,7 @@ class AuxDataCodec:
|
|
159
142
|
return
|
160
143
|
|
161
144
|
|
162
|
-
class MooncakeKVManager(
|
145
|
+
class MooncakeKVManager(CommonKVManager):
|
163
146
|
AUX_DATA_HEADER = b"AUX_DATA"
|
164
147
|
|
165
148
|
def __init__(
|
@@ -169,43 +152,14 @@ class MooncakeKVManager(BaseKVManager):
|
|
169
152
|
server_args: ServerArgs,
|
170
153
|
is_mla_backend: Optional[bool] = False,
|
171
154
|
):
|
172
|
-
|
173
|
-
self.local_ip = get_local_ip_auto()
|
174
|
-
self.is_mla_backend = is_mla_backend
|
175
|
-
self.disaggregation_mode = disaggregation_mode
|
155
|
+
super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
|
176
156
|
self.init_engine()
|
177
|
-
# for p/d multi node infer
|
178
|
-
self.bootstrap_host = server_args.host
|
179
|
-
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
180
|
-
self.dist_init_addr = server_args.dist_init_addr
|
181
|
-
self.attn_tp_size = get_attention_tp_size()
|
182
|
-
self.attn_tp_rank = get_attention_tp_rank()
|
183
|
-
self.attn_dp_size = get_attention_dp_size()
|
184
|
-
self.attn_dp_rank = get_attention_dp_rank()
|
185
|
-
self.system_dp_size = (
|
186
|
-
1 if server_args.enable_dp_attention else server_args.dp_size
|
187
|
-
)
|
188
|
-
self.system_dp_rank = (
|
189
|
-
self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
|
190
|
-
)
|
191
|
-
self.pp_size = server_args.pp_size
|
192
|
-
self.pp_rank = self.kv_args.pp_rank
|
193
|
-
self.request_status: Dict[int, KVPoll] = {}
|
194
|
-
self.rank_port = None
|
195
|
-
self.server_socket = zmq.Context().socket(zmq.PULL)
|
196
|
-
if is_valid_ipv6_address(self.local_ip):
|
197
|
-
self.server_socket.setsockopt(zmq.IPV6, 1)
|
198
|
-
|
199
157
|
self.register_buffer_to_engine()
|
200
158
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
201
|
-
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
202
|
-
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
203
159
|
self.start_prefill_thread()
|
204
|
-
self._register_to_bootstrap()
|
205
160
|
self.session_failures = defaultdict(int)
|
206
161
|
self.failed_sessions = set()
|
207
162
|
self.session_lock = threading.Lock()
|
208
|
-
self.pp_group = get_pp_group()
|
209
163
|
# Determine the number of threads to use for kv sender
|
210
164
|
cpu_count = os.cpu_count()
|
211
165
|
transfer_thread_pool_size = get_int_env_var(
|
@@ -245,8 +199,6 @@ class MooncakeKVManager(BaseKVManager):
|
|
245
199
|
self.session_pool = defaultdict(requests.Session)
|
246
200
|
self.session_pool_lock = threading.Lock()
|
247
201
|
self.addr_to_rooms_tracker = defaultdict(set)
|
248
|
-
self.connection_lock = threading.Lock()
|
249
|
-
self.required_prefill_response_num_table: Dict[int, int] = {}
|
250
202
|
self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
|
251
203
|
# Heartbeat interval should be at least 2 seconds
|
252
204
|
self.heartbeat_interval = max(
|
@@ -257,20 +209,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
257
209
|
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
|
258
210
|
)
|
259
211
|
self.start_decode_thread()
|
260
|
-
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
261
|
-
self.prefill_attn_tp_size_table: Dict[str, int] = {}
|
262
|
-
self.prefill_dp_size_table: Dict[str, int] = {}
|
263
|
-
self.prefill_pp_size_table: Dict[str, int] = {}
|
264
212
|
# If a timeout happens on the decode side, it means decode instances
|
265
213
|
# fail to receive the KV Cache transfer done signal after bootstrapping.
|
266
214
|
# These timeout requests should be aborted to release the tree cache.
|
267
215
|
self.waiting_timeout = get_int_env_var(
|
268
216
|
"SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
|
269
217
|
)
|
270
|
-
else:
|
271
|
-
raise ValueError(
|
272
|
-
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
273
|
-
)
|
274
218
|
|
275
219
|
self.failure_records: Dict[int, str] = {}
|
276
220
|
self.failure_lock = threading.Lock()
|
@@ -295,14 +239,6 @@ class MooncakeKVManager(BaseKVManager):
|
|
295
239
|
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
296
240
|
)
|
297
241
|
|
298
|
-
@cache
|
299
|
-
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
300
|
-
socket = zmq.Context().socket(zmq.PUSH)
|
301
|
-
if is_ipv6:
|
302
|
-
socket.setsockopt(zmq.IPV6, 1)
|
303
|
-
socket.connect(endpoint)
|
304
|
-
return socket
|
305
|
-
|
306
242
|
def _transfer_data(self, mooncake_session_id, transfer_blocks):
|
307
243
|
if not transfer_blocks:
|
308
244
|
return 0
|
@@ -328,12 +264,10 @@ class MooncakeKVManager(BaseKVManager):
|
|
328
264
|
layers_params = None
|
329
265
|
|
330
266
|
# pp is not supported on the decode side yet
|
331
|
-
start_layer = self.kv_args.prefill_start_layer
|
332
|
-
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
|
333
267
|
if self.is_mla_backend:
|
334
|
-
src_kv_ptrs =
|
335
|
-
|
336
|
-
|
268
|
+
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
269
|
+
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
270
|
+
)
|
337
271
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
338
272
|
layers_params = [
|
339
273
|
(
|
@@ -341,18 +275,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
341
275
|
dst_kv_ptrs[layer_id],
|
342
276
|
kv_item_len,
|
343
277
|
)
|
344
|
-
for layer_id in range(
|
278
|
+
for layer_id in range(layers_current_pp_stage)
|
345
279
|
]
|
346
280
|
else:
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
351
|
-
layers_per_pp_stage = len(src_k_ptrs)
|
352
|
-
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
353
|
-
dst_v_ptrs = dst_kv_ptrs[
|
354
|
-
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
355
|
-
]
|
281
|
+
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
282
|
+
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
283
|
+
)
|
356
284
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
357
285
|
layers_params = [
|
358
286
|
(
|
@@ -360,14 +288,14 @@ class MooncakeKVManager(BaseKVManager):
|
|
360
288
|
dst_k_ptrs[layer_id],
|
361
289
|
kv_item_len,
|
362
290
|
)
|
363
|
-
for layer_id in range(
|
291
|
+
for layer_id in range(layers_current_pp_stage)
|
364
292
|
] + [
|
365
293
|
(
|
366
294
|
src_v_ptrs[layer_id],
|
367
295
|
dst_v_ptrs[layer_id],
|
368
296
|
kv_item_len,
|
369
297
|
)
|
370
|
-
for layer_id in range(
|
298
|
+
for layer_id in range(layers_current_pp_stage)
|
371
299
|
]
|
372
300
|
assert layers_params is not None
|
373
301
|
|
@@ -465,18 +393,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
465
393
|
num_heads_to_send = dst_heads_per_rank
|
466
394
|
dst_head_start_offset = 0
|
467
395
|
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
472
|
-
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
473
|
-
layers_per_pp_stage = len(src_k_ptrs)
|
474
|
-
start_layer = self.pp_rank * layers_per_pp_stage
|
475
|
-
end_layer = start_layer + layers_per_pp_stage
|
476
|
-
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
477
|
-
dst_v_ptrs = dst_kv_ptrs[
|
478
|
-
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
479
|
-
]
|
396
|
+
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
397
|
+
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
398
|
+
)
|
480
399
|
|
481
400
|
# Calculate precise byte offset and length for the sub-slice within the token
|
482
401
|
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
@@ -502,7 +421,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
502
421
|
dst_head_slice_offset,
|
503
422
|
heads_bytes_per_token_to_send,
|
504
423
|
)
|
505
|
-
for layer_id in range(
|
424
|
+
for layer_id in range(layers_current_pp_stage)
|
506
425
|
] + [
|
507
426
|
(
|
508
427
|
src_v_ptrs[layer_id],
|
@@ -513,7 +432,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
513
432
|
dst_head_slice_offset,
|
514
433
|
heads_bytes_per_token_to_send,
|
515
434
|
)
|
516
|
-
for layer_id in range(
|
435
|
+
for layer_id in range(layers_current_pp_stage)
|
517
436
|
]
|
518
437
|
|
519
438
|
def process_layer_tp_aware(layer_params):
|
@@ -654,6 +573,26 @@ class MooncakeKVManager(BaseKVManager):
|
|
654
573
|
]
|
655
574
|
)
|
656
575
|
|
576
|
+
def _handle_aux_data(self, msg: List[bytes]):
|
577
|
+
"""Handle AUX_DATA messages received by the decode thread."""
|
578
|
+
room = int(msg[1].decode("ascii"))
|
579
|
+
buffer_index = int(msg[2].decode("ascii"))
|
580
|
+
aux_index = int(msg[3].decode("ascii"))
|
581
|
+
data_length = struct.unpack(">I", msg[4])[0]
|
582
|
+
data = msg[5]
|
583
|
+
|
584
|
+
if len(data) != data_length:
|
585
|
+
logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
|
586
|
+
return
|
587
|
+
|
588
|
+
AuxDataCodec.deserialize_data_to_buffer(
|
589
|
+
self.kv_args, buffer_index, aux_index, data
|
590
|
+
)
|
591
|
+
|
592
|
+
logger.debug(
|
593
|
+
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
|
594
|
+
)
|
595
|
+
|
657
596
|
def sync_status_to_decode_endpoint(
|
658
597
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
659
598
|
):
|
@@ -802,11 +741,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
802
741
|
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
803
742
|
)
|
804
743
|
|
805
|
-
def _bind_server_socket(self):
|
806
|
-
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
807
|
-
|
808
744
|
def start_prefill_thread(self):
|
809
|
-
self.rank_port = get_free_port()
|
810
745
|
self._bind_server_socket()
|
811
746
|
|
812
747
|
def bootstrap_thread():
|
@@ -844,28 +779,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
844
779
|
|
845
780
|
threading.Thread(target=bootstrap_thread).start()
|
846
781
|
|
847
|
-
def _handle_aux_data(self, msg: List[bytes]):
|
848
|
-
"""Handle AUX_DATA messages received by the decode thread."""
|
849
|
-
room = int(msg[1].decode("ascii"))
|
850
|
-
buffer_index = int(msg[2].decode("ascii"))
|
851
|
-
aux_index = int(msg[3].decode("ascii"))
|
852
|
-
data_length = struct.unpack(">I", msg[4])[0]
|
853
|
-
data = msg[5]
|
854
|
-
|
855
|
-
if len(data) != data_length:
|
856
|
-
logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
|
857
|
-
return
|
858
|
-
|
859
|
-
AuxDataCodec.deserialize_data_to_buffer(
|
860
|
-
self.kv_args, buffer_index, aux_index, data
|
861
|
-
)
|
862
|
-
|
863
|
-
logger.debug(
|
864
|
-
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
|
865
|
-
)
|
866
|
-
|
867
782
|
def start_decode_thread(self):
|
868
|
-
self.rank_port = get_free_port()
|
869
783
|
self._bind_server_socket()
|
870
784
|
|
871
785
|
def decode_thread():
|
@@ -1020,51 +934,6 @@ class MooncakeKVManager(BaseKVManager):
|
|
1020
934
|
def get_session_id(self):
|
1021
935
|
return self.engine.get_session_id()
|
1022
936
|
|
1023
|
-
def _register_to_bootstrap(self):
|
1024
|
-
"""Register KVSender to bootstrap server via HTTP POST."""
|
1025
|
-
if self.dist_init_addr:
|
1026
|
-
# multi node case: bootstrap server's host is dist_init_addr
|
1027
|
-
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
1028
|
-
if self.dist_init_addr.endswith("]"):
|
1029
|
-
host = self.dist_init_addr
|
1030
|
-
else:
|
1031
|
-
host, _ = self.dist_init_addr.rsplit(":", 1)
|
1032
|
-
else:
|
1033
|
-
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
1034
|
-
else:
|
1035
|
-
# single node case: bootstrap server's host is same as http server's host
|
1036
|
-
host = self.bootstrap_host
|
1037
|
-
host = maybe_wrap_ipv6_address(host)
|
1038
|
-
|
1039
|
-
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
1040
|
-
url = f"http://{bootstrap_server_url}/route"
|
1041
|
-
payload = {
|
1042
|
-
"role": "Prefill",
|
1043
|
-
"attn_tp_size": self.attn_tp_size,
|
1044
|
-
"attn_tp_rank": self.attn_tp_rank,
|
1045
|
-
"attn_dp_size": self.attn_dp_size,
|
1046
|
-
"attn_dp_rank": self.attn_dp_rank,
|
1047
|
-
"pp_size": self.pp_size,
|
1048
|
-
"pp_rank": self.pp_rank,
|
1049
|
-
"system_dp_size": self.system_dp_size,
|
1050
|
-
"system_dp_rank": self.system_dp_rank,
|
1051
|
-
"rank_ip": self.local_ip,
|
1052
|
-
"rank_port": self.rank_port,
|
1053
|
-
}
|
1054
|
-
|
1055
|
-
try:
|
1056
|
-
response = requests.put(url, json=payload, timeout=5)
|
1057
|
-
if response.status_code == 200:
|
1058
|
-
logger.debug("Prefill successfully registered to bootstrap server.")
|
1059
|
-
else:
|
1060
|
-
logger.error(
|
1061
|
-
f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
|
1062
|
-
)
|
1063
|
-
except Exception as e:
|
1064
|
-
logger.error(
|
1065
|
-
f"Prefill instance failed to register to bootstrap server: {e}"
|
1066
|
-
)
|
1067
|
-
|
1068
937
|
def _handle_node_failure(self, failed_bootstrap_addr):
|
1069
938
|
with self.connection_lock:
|
1070
939
|
keys_to_remove = [
|
@@ -1103,7 +972,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
1103
972
|
)
|
1104
973
|
|
1105
974
|
|
1106
|
-
class MooncakeKVSender(
|
975
|
+
class MooncakeKVSender(CommonKVSender):
|
1107
976
|
|
1108
977
|
def __init__(
|
1109
978
|
self,
|
@@ -1113,19 +982,9 @@ class MooncakeKVSender(BaseKVSender):
|
|
1113
982
|
dest_tp_ranks: List[int],
|
1114
983
|
pp_rank: int,
|
1115
984
|
):
|
1116
|
-
|
1117
|
-
self.bootstrap_room = bootstrap_room
|
1118
|
-
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
|
1119
|
-
self.aux_index = None
|
1120
|
-
self.bootstrap_server_url = bootstrap_addr
|
985
|
+
super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
|
1121
986
|
self.conclude_state = None
|
1122
987
|
self.init_time = time.time()
|
1123
|
-
# inner state
|
1124
|
-
self.curr_idx = 0
|
1125
|
-
|
1126
|
-
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
1127
|
-
self.num_kv_indices = num_kv_indices
|
1128
|
-
self.aux_index = aux_index
|
1129
988
|
|
1130
989
|
def send(
|
1131
990
|
self,
|
@@ -1203,7 +1062,7 @@ class MooncakeKVSender(BaseKVSender):
|
|
1203
1062
|
self.conclude_state = KVPoll.Failed
|
1204
1063
|
|
1205
1064
|
|
1206
|
-
class MooncakeKVReceiver(
|
1065
|
+
class MooncakeKVReceiver(CommonKVReceiver):
|
1207
1066
|
_ctx = zmq.Context()
|
1208
1067
|
_socket_cache = {}
|
1209
1068
|
_socket_locks = {}
|
@@ -1216,166 +1075,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1216
1075
|
bootstrap_room: Optional[int] = None,
|
1217
1076
|
prefill_dp_rank: Optional[int] = None,
|
1218
1077
|
):
|
1219
|
-
self.
|
1220
|
-
self.bootstrap_addr = bootstrap_addr
|
1221
|
-
self.kv_mgr = mgr
|
1222
|
-
self.session_id = self.kv_mgr.get_session_id()
|
1223
|
-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
1078
|
+
self.session_id = mgr.get_session_id()
|
1224
1079
|
self.conclude_state = None
|
1225
1080
|
self.init_time = None
|
1081
|
+
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
1226
1082
|
|
1227
|
-
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
1228
|
-
(
|
1229
|
-
self.prefill_attn_tp_size,
|
1230
|
-
self.prefill_dp_size,
|
1231
|
-
self.prefill_pp_size,
|
1232
|
-
) = self._get_prefill_parallel_info_from_server()
|
1233
|
-
if (
|
1234
|
-
self.prefill_attn_tp_size is None
|
1235
|
-
or self.prefill_dp_size is None
|
1236
|
-
or self.prefill_pp_size is None
|
1237
|
-
):
|
1238
|
-
self.kv_mgr.record_failure(
|
1239
|
-
self.bootstrap_room,
|
1240
|
-
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
|
1241
|
-
)
|
1242
|
-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
|
1243
|
-
return
|
1244
|
-
else:
|
1245
|
-
logger.debug(
|
1246
|
-
f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
|
1247
|
-
)
|
1248
|
-
self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
|
1249
|
-
self.prefill_attn_tp_size
|
1250
|
-
)
|
1251
|
-
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
1252
|
-
self.prefill_dp_size
|
1253
|
-
)
|
1254
|
-
self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
|
1255
|
-
self.prefill_pp_size
|
1256
|
-
)
|
1257
|
-
else:
|
1258
|
-
self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
|
1259
|
-
self.bootstrap_addr
|
1260
|
-
]
|
1261
|
-
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
1262
|
-
self.bootstrap_addr
|
1263
|
-
]
|
1264
|
-
self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
|
1265
|
-
self.bootstrap_addr
|
1266
|
-
]
|
1267
|
-
|
1268
|
-
# Currently, we don't allow prefill instance and decode instance to
|
1269
|
-
# have different TP sizes per DP rank, except for models using MLA.
|
1270
|
-
if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
|
1271
|
-
self.target_tp_rank = (
|
1272
|
-
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
1273
|
-
)
|
1274
|
-
self.required_dst_info_num = 1
|
1275
|
-
self.required_prefill_response_num = 1 * (
|
1276
|
-
self.prefill_pp_size // self.kv_mgr.pp_size
|
1277
|
-
)
|
1278
|
-
self.target_tp_ranks = [self.target_tp_rank]
|
1279
|
-
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
|
1280
|
-
if not self.kv_mgr.is_mla_backend:
|
1281
|
-
logger.warning_once(
|
1282
|
-
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
|
1283
|
-
)
|
1284
|
-
self.target_tp_rank = (
|
1285
|
-
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
1286
|
-
) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
|
1287
|
-
self.required_dst_info_num = (
|
1288
|
-
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
|
1289
|
-
)
|
1290
|
-
self.required_prefill_response_num = 1 * (
|
1291
|
-
self.prefill_pp_size // self.kv_mgr.pp_size
|
1292
|
-
)
|
1293
|
-
self.target_tp_ranks = [self.target_tp_rank]
|
1294
|
-
else:
|
1295
|
-
if not self.kv_mgr.is_mla_backend:
|
1296
|
-
logger.warning_once(
|
1297
|
-
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
|
1298
|
-
)
|
1299
|
-
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
|
1300
|
-
self.target_tp_ranks = [
|
1301
|
-
rank
|
1302
|
-
for rank in range(
|
1303
|
-
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
|
1304
|
-
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
|
1305
|
-
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
|
1306
|
-
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
|
1307
|
-
)
|
1308
|
-
]
|
1309
|
-
|
1310
|
-
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
|
1311
|
-
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
|
1312
|
-
# or the KVPoll will never be set correctly
|
1313
|
-
self.target_tp_rank = self.target_tp_ranks[0]
|
1314
|
-
self.required_dst_info_num = 1
|
1315
|
-
if self.kv_mgr.is_mla_backend:
|
1316
|
-
self.required_prefill_response_num = (
|
1317
|
-
self.prefill_pp_size // self.kv_mgr.pp_size
|
1318
|
-
)
|
1319
|
-
else:
|
1320
|
-
self.required_prefill_response_num = (
|
1321
|
-
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
1322
|
-
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
1323
|
-
|
1324
|
-
if prefill_dp_rank is not None:
|
1325
|
-
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
1326
|
-
self.prefill_dp_rank = prefill_dp_rank
|
1327
|
-
else:
|
1328
|
-
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
|
1329
|
-
|
1330
|
-
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
1331
|
-
self.target_dp_group = self.prefill_dp_rank
|
1332
|
-
|
1333
|
-
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
1334
|
-
self.required_prefill_response_num
|
1335
|
-
)
|
1336
|
-
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
1337
|
-
bootstrap_key = (
|
1338
|
-
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
|
1339
|
-
)
|
1340
|
-
|
1341
|
-
if bootstrap_key not in self.kv_mgr.connection_pool:
|
1342
|
-
bootstrap_infos = []
|
1343
|
-
for target_tp_rank in self.target_tp_ranks:
|
1344
|
-
for target_pp_rank in range(self.prefill_pp_size):
|
1345
|
-
bootstrap_info = self._get_bootstrap_info_from_server(
|
1346
|
-
target_tp_rank, self.target_dp_group, target_pp_rank
|
1347
|
-
)
|
1348
|
-
if bootstrap_info is not None:
|
1349
|
-
if self.kv_mgr.is_mla_backend:
|
1350
|
-
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
|
1351
|
-
bootstrap_info["is_dummy"] = not bool(
|
1352
|
-
target_tp_rank == self.target_tp_rank
|
1353
|
-
or self.target_tp_rank is None
|
1354
|
-
)
|
1355
|
-
else:
|
1356
|
-
# For non-MLA: all target_tp_ranks are selected real ranks
|
1357
|
-
bootstrap_info["is_dummy"] = False
|
1358
|
-
logger.debug(
|
1359
|
-
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
|
1360
|
-
)
|
1361
|
-
bootstrap_infos.append(bootstrap_info)
|
1362
|
-
else:
|
1363
|
-
self.kv_mgr.record_failure(
|
1364
|
-
self.bootstrap_room,
|
1365
|
-
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
|
1366
|
-
)
|
1367
|
-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
|
1368
|
-
return
|
1369
|
-
|
1370
|
-
self.bootstrap_infos = bootstrap_infos
|
1371
|
-
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
1372
|
-
|
1373
|
-
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
1374
|
-
self._register_kv_args()
|
1375
|
-
else:
|
1376
|
-
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
|
1377
|
-
|
1378
|
-
assert len(self.bootstrap_infos) > 0
|
1379
1083
|
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
|
1380
1084
|
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
|
1381
1085
|
|
@@ -1398,29 +1102,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1398
1102
|
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
1399
1103
|
return None
|
1400
1104
|
|
1401
|
-
def _get_prefill_parallel_info_from_server(
|
1402
|
-
self,
|
1403
|
-
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
|
1404
|
-
"""Fetch the prefill parallel info from the bootstrap server."""
|
1405
|
-
try:
|
1406
|
-
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
|
1407
|
-
response = requests.get(url)
|
1408
|
-
if response.status_code == 200:
|
1409
|
-
prefill_parallel_info = response.json()
|
1410
|
-
return (
|
1411
|
-
int(prefill_parallel_info["prefill_attn_tp_size"]),
|
1412
|
-
int(prefill_parallel_info["prefill_dp_size"]),
|
1413
|
-
int(prefill_parallel_info["prefill_pp_size"]),
|
1414
|
-
)
|
1415
|
-
else:
|
1416
|
-
logger.error(
|
1417
|
-
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
|
1418
|
-
)
|
1419
|
-
return None, None, None
|
1420
|
-
except Exception as e:
|
1421
|
-
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
|
1422
|
-
return None, None, None
|
1423
|
-
|
1424
1105
|
def _register_kv_args(self):
|
1425
1106
|
for bootstrap_info in self.bootstrap_infos:
|
1426
1107
|
packed_kv_data_ptrs = b"".join(
|
@@ -1452,28 +1133,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1452
1133
|
]
|
1453
1134
|
)
|
1454
1135
|
|
1455
|
-
@classmethod
|
1456
|
-
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
1457
|
-
with cls._global_lock:
|
1458
|
-
if endpoint not in cls._socket_cache:
|
1459
|
-
sock = cls._ctx.socket(zmq.PUSH)
|
1460
|
-
if is_ipv6:
|
1461
|
-
sock.setsockopt(zmq.IPV6, 1)
|
1462
|
-
sock.connect(endpoint)
|
1463
|
-
cls._socket_cache[endpoint] = sock
|
1464
|
-
cls._socket_locks[endpoint] = threading.Lock()
|
1465
|
-
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
1466
|
-
|
1467
|
-
@classmethod
|
1468
|
-
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
1469
|
-
ip_address = bootstrap_info["rank_ip"]
|
1470
|
-
port = bootstrap_info["rank_port"]
|
1471
|
-
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
1472
|
-
sock, lock = cls._connect(
|
1473
|
-
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
1474
|
-
)
|
1475
|
-
return sock, lock
|
1476
|
-
|
1477
1136
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
1478
1137
|
for bootstrap_info in self.bootstrap_infos:
|
1479
1138
|
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
@@ -1551,154 +1210,5 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1551
1210
|
self.conclude_state = KVPoll.Failed
|
1552
1211
|
|
1553
1212
|
|
1554
|
-
class MooncakeKVBootstrapServer(
|
1555
|
-
|
1556
|
-
self.host = host
|
1557
|
-
self.port = port
|
1558
|
-
self.app = web.Application()
|
1559
|
-
self.store = dict()
|
1560
|
-
self.lock = asyncio.Lock()
|
1561
|
-
self._setup_routes()
|
1562
|
-
self.pp_size = None
|
1563
|
-
self.attn_tp_size = None
|
1564
|
-
self.dp_size = None
|
1565
|
-
self.prefill_port_table: Dict[
|
1566
|
-
int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
|
1567
|
-
] = {}
|
1568
|
-
|
1569
|
-
# Start bootstrap server
|
1570
|
-
self.thread = threading.Thread(target=self._run_server, daemon=True)
|
1571
|
-
self.run()
|
1572
|
-
|
1573
|
-
def run(self):
|
1574
|
-
self.thread.start()
|
1575
|
-
|
1576
|
-
def _setup_routes(self):
|
1577
|
-
self.app.router.add_route("*", "/route", self._handle_route)
|
1578
|
-
self.app.router.add_get("/health", self._handle_health_check)
|
1579
|
-
|
1580
|
-
async def _handle_health_check(self, request):
|
1581
|
-
return web.Response(text="OK", status=200)
|
1582
|
-
|
1583
|
-
async def _handle_route(self, request: web.Request):
|
1584
|
-
method = request.method
|
1585
|
-
if method == "PUT":
|
1586
|
-
return await self._handle_route_put(request)
|
1587
|
-
elif method == "GET":
|
1588
|
-
return await self._handle_route_get(request)
|
1589
|
-
else:
|
1590
|
-
return web.Response(
|
1591
|
-
text="Method not allowed", status=405, content_type="application/json"
|
1592
|
-
)
|
1593
|
-
|
1594
|
-
async def _handle_route_put(self, request: web.Request):
|
1595
|
-
data = await request.json()
|
1596
|
-
role = data["role"]
|
1597
|
-
attn_tp_size = data["attn_tp_size"]
|
1598
|
-
attn_tp_rank = data["attn_tp_rank"]
|
1599
|
-
attn_dp_size = data["attn_dp_size"]
|
1600
|
-
attn_dp_rank = data["attn_dp_rank"]
|
1601
|
-
pp_size = data["pp_size"]
|
1602
|
-
pp_rank = data["pp_rank"]
|
1603
|
-
system_dp_size = data["system_dp_size"]
|
1604
|
-
system_dp_rank = data["system_dp_rank"]
|
1605
|
-
rank_ip = data["rank_ip"]
|
1606
|
-
rank_port = int(data["rank_port"])
|
1607
|
-
|
1608
|
-
if self.attn_tp_size is None:
|
1609
|
-
self.attn_tp_size = attn_tp_size
|
1610
|
-
|
1611
|
-
if self.dp_size is None:
|
1612
|
-
self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
|
1613
|
-
|
1614
|
-
if self.pp_size is None:
|
1615
|
-
self.pp_size = pp_size
|
1616
|
-
|
1617
|
-
if role == "Prefill":
|
1618
|
-
if system_dp_size == 1:
|
1619
|
-
dp_group = attn_dp_rank
|
1620
|
-
else:
|
1621
|
-
dp_group = system_dp_rank
|
1622
|
-
|
1623
|
-
# Add lock to make sure thread-safe
|
1624
|
-
async with self.lock:
|
1625
|
-
if dp_group not in self.prefill_port_table:
|
1626
|
-
self.prefill_port_table[dp_group] = {}
|
1627
|
-
if attn_tp_rank not in self.prefill_port_table[dp_group]:
|
1628
|
-
self.prefill_port_table[dp_group][attn_tp_rank] = {}
|
1629
|
-
|
1630
|
-
self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
|
1631
|
-
"rank_ip": rank_ip,
|
1632
|
-
"rank_port": rank_port,
|
1633
|
-
}
|
1634
|
-
logger.debug(
|
1635
|
-
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
1636
|
-
)
|
1637
|
-
|
1638
|
-
return web.Response(text="OK", status=200)
|
1639
|
-
|
1640
|
-
async def _handle_route_get(self, request: web.Request):
|
1641
|
-
engine_rank = request.query.get("engine_rank")
|
1642
|
-
target_dp_group = request.query.get("target_dp_group")
|
1643
|
-
target_pp_rank = request.query.get("target_pp_rank")
|
1644
|
-
if not engine_rank or not target_dp_group or not target_pp_rank:
|
1645
|
-
return web.Response(text="Missing inputs for bootstrap server.", status=400)
|
1646
|
-
|
1647
|
-
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
1648
|
-
if (
|
1649
|
-
int(engine_rank) == -1
|
1650
|
-
and int(target_dp_group) == -1
|
1651
|
-
and int(target_pp_rank) == -1
|
1652
|
-
):
|
1653
|
-
prefill_parallel_info = {
|
1654
|
-
"prefill_attn_tp_size": self.attn_tp_size,
|
1655
|
-
"prefill_dp_size": self.dp_size,
|
1656
|
-
"prefill_pp_size": self.pp_size,
|
1657
|
-
}
|
1658
|
-
return web.json_response(prefill_parallel_info, status=200)
|
1659
|
-
|
1660
|
-
# Find corresponding prefill info
|
1661
|
-
async with self.lock:
|
1662
|
-
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
1663
|
-
int(engine_rank)
|
1664
|
-
][int(target_pp_rank)]
|
1665
|
-
|
1666
|
-
if bootstrap_info is not None:
|
1667
|
-
return web.json_response(bootstrap_info, status=200)
|
1668
|
-
else:
|
1669
|
-
return web.Response(text="Bootstrap info not Found", status=404)
|
1670
|
-
|
1671
|
-
def _run_server(self):
|
1672
|
-
try:
|
1673
|
-
# Event Loop
|
1674
|
-
self._loop = asyncio.new_event_loop()
|
1675
|
-
asyncio.set_event_loop(self._loop)
|
1676
|
-
|
1677
|
-
access_log = None
|
1678
|
-
if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
|
1679
|
-
access_log = self.app.logger
|
1680
|
-
|
1681
|
-
self._runner = web.AppRunner(self.app, access_log=access_log)
|
1682
|
-
self._loop.run_until_complete(self._runner.setup())
|
1683
|
-
|
1684
|
-
site = web.TCPSite(self._runner, host=self.host, port=self.port)
|
1685
|
-
self._loop.run_until_complete(site.start())
|
1686
|
-
self._loop.run_forever()
|
1687
|
-
except Exception as e:
|
1688
|
-
logger.error(f"Server error: {str(e)}")
|
1689
|
-
finally:
|
1690
|
-
# Cleanup
|
1691
|
-
self._loop.run_until_complete(self._runner.cleanup())
|
1692
|
-
self._loop.close()
|
1693
|
-
|
1694
|
-
def close(self):
|
1695
|
-
"""Shutdown"""
|
1696
|
-
if self._loop is not None and self._loop.is_running():
|
1697
|
-
self._loop.call_soon_threadsafe(self._loop.stop)
|
1698
|
-
logger.info("Stopping server loop...")
|
1699
|
-
|
1700
|
-
if self.thread.is_alive():
|
1701
|
-
self.thread.join(timeout=2)
|
1702
|
-
logger.info("Server thread stopped")
|
1703
|
-
|
1704
|
-
def poll(self) -> KVPoll: ...
|
1213
|
+
class MooncakeKVBootstrapServer(CommonKVBootstrapServer):
|
1214
|
+
pass
|