sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  import concurrent.futures
5
4
  import ctypes
6
5
  import dataclasses
7
6
  import logging
8
7
  import os
9
- import queue
10
- import socket
11
8
  import struct
12
9
  import threading
13
10
  import time
14
11
  from collections import defaultdict
15
- from functools import cache
16
- from typing import Dict, List, Optional, Tuple, Union
12
+ from typing import Dict, List, Optional, Tuple
17
13
 
18
14
  import numpy as np
19
15
  import numpy.typing as npt
20
16
  import requests
21
17
  import zmq
22
- from aiohttp import web
23
-
24
- from sglang.srt.disaggregation.base.conn import (
25
- BaseKVBootstrapServer,
26
- BaseKVManager,
27
- BaseKVReceiver,
28
- BaseKVSender,
29
- KVArgs,
30
- KVPoll,
18
+
19
+ from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
20
+ from sglang.srt.disaggregation.common.conn import (
21
+ CommonKVBootstrapServer,
22
+ CommonKVManager,
23
+ CommonKVReceiver,
24
+ CommonKVSender,
31
25
  )
32
26
  from sglang.srt.disaggregation.common.utils import (
33
27
  FastQueue,
@@ -35,23 +29,12 @@ from sglang.srt.disaggregation.common.utils import (
35
29
  )
36
30
  from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
37
31
  from sglang.srt.disaggregation.utils import DisaggregationMode
38
- from sglang.srt.distributed import get_pp_group
39
- from sglang.srt.layers.dp_attention import (
40
- get_attention_dp_rank,
41
- get_attention_dp_size,
42
- get_attention_tp_rank,
43
- get_attention_tp_size,
44
- )
45
32
  from sglang.srt.server_args import ServerArgs
46
33
  from sglang.srt.utils import (
47
34
  format_tcp_address,
48
35
  get_bool_env_var,
49
- get_free_port,
50
36
  get_int_env_var,
51
- get_ip,
52
- get_local_ip_auto,
53
37
  is_valid_ipv6_address,
54
- maybe_wrap_ipv6_address,
55
38
  )
56
39
 
57
40
  logger = logging.getLogger(__name__)
@@ -159,7 +142,7 @@ class AuxDataCodec:
159
142
  return
160
143
 
161
144
 
162
- class MooncakeKVManager(BaseKVManager):
145
+ class MooncakeKVManager(CommonKVManager):
163
146
  AUX_DATA_HEADER = b"AUX_DATA"
164
147
 
165
148
  def __init__(
@@ -169,43 +152,14 @@ class MooncakeKVManager(BaseKVManager):
169
152
  server_args: ServerArgs,
170
153
  is_mla_backend: Optional[bool] = False,
171
154
  ):
172
- self.kv_args = args
173
- self.local_ip = get_local_ip_auto()
174
- self.is_mla_backend = is_mla_backend
175
- self.disaggregation_mode = disaggregation_mode
155
+ super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
176
156
  self.init_engine()
177
- # for p/d multi node infer
178
- self.bootstrap_host = server_args.host
179
- self.bootstrap_port = server_args.disaggregation_bootstrap_port
180
- self.dist_init_addr = server_args.dist_init_addr
181
- self.attn_tp_size = get_attention_tp_size()
182
- self.attn_tp_rank = get_attention_tp_rank()
183
- self.attn_dp_size = get_attention_dp_size()
184
- self.attn_dp_rank = get_attention_dp_rank()
185
- self.system_dp_size = (
186
- 1 if server_args.enable_dp_attention else server_args.dp_size
187
- )
188
- self.system_dp_rank = (
189
- self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
190
- )
191
- self.pp_size = server_args.pp_size
192
- self.pp_rank = self.kv_args.pp_rank
193
- self.request_status: Dict[int, KVPoll] = {}
194
- self.rank_port = None
195
- self.server_socket = zmq.Context().socket(zmq.PULL)
196
- if is_valid_ipv6_address(self.local_ip):
197
- self.server_socket.setsockopt(zmq.IPV6, 1)
198
-
199
157
  self.register_buffer_to_engine()
200
158
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
201
- self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
202
- self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
203
159
  self.start_prefill_thread()
204
- self._register_to_bootstrap()
205
160
  self.session_failures = defaultdict(int)
206
161
  self.failed_sessions = set()
207
162
  self.session_lock = threading.Lock()
208
- self.pp_group = get_pp_group()
209
163
  # Determine the number of threads to use for kv sender
210
164
  cpu_count = os.cpu_count()
211
165
  transfer_thread_pool_size = get_int_env_var(
@@ -245,8 +199,6 @@ class MooncakeKVManager(BaseKVManager):
245
199
  self.session_pool = defaultdict(requests.Session)
246
200
  self.session_pool_lock = threading.Lock()
247
201
  self.addr_to_rooms_tracker = defaultdict(set)
248
- self.connection_lock = threading.Lock()
249
- self.required_prefill_response_num_table: Dict[int, int] = {}
250
202
  self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
251
203
  # Heartbeat interval should be at least 2 seconds
252
204
  self.heartbeat_interval = max(
@@ -257,20 +209,12 @@ class MooncakeKVManager(BaseKVManager):
257
209
  get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
258
210
  )
259
211
  self.start_decode_thread()
260
- self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
261
- self.prefill_attn_tp_size_table: Dict[str, int] = {}
262
- self.prefill_dp_size_table: Dict[str, int] = {}
263
- self.prefill_pp_size_table: Dict[str, int] = {}
264
212
  # If a timeout happens on the decode side, it means decode instances
265
213
  # fail to receive the KV Cache transfer done signal after bootstrapping.
266
214
  # These timeout requests should be aborted to release the tree cache.
267
215
  self.waiting_timeout = get_int_env_var(
268
216
  "SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
269
217
  )
270
- else:
271
- raise ValueError(
272
- f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
273
- )
274
218
 
275
219
  self.failure_records: Dict[int, str] = {}
276
220
  self.failure_lock = threading.Lock()
@@ -295,14 +239,6 @@ class MooncakeKVManager(BaseKVManager):
295
239
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
296
240
  )
297
241
 
298
- @cache
299
- def _connect(self, endpoint: str, is_ipv6: bool = False):
300
- socket = zmq.Context().socket(zmq.PUSH)
301
- if is_ipv6:
302
- socket.setsockopt(zmq.IPV6, 1)
303
- socket.connect(endpoint)
304
- return socket
305
-
306
242
  def _transfer_data(self, mooncake_session_id, transfer_blocks):
307
243
  if not transfer_blocks:
308
244
  return 0
@@ -328,12 +264,10 @@ class MooncakeKVManager(BaseKVManager):
328
264
  layers_params = None
329
265
 
330
266
  # pp is not supported on the decode side yet
331
- start_layer = self.kv_args.prefill_start_layer
332
- end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
333
267
  if self.is_mla_backend:
334
- src_kv_ptrs = self.kv_args.kv_data_ptrs
335
- layers_per_pp_stage = len(src_kv_ptrs)
336
- dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
268
+ src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
269
+ self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
270
+ )
337
271
  kv_item_len = self.kv_args.kv_item_lens[0]
338
272
  layers_params = [
339
273
  (
@@ -341,18 +275,12 @@ class MooncakeKVManager(BaseKVManager):
341
275
  dst_kv_ptrs[layer_id],
342
276
  kv_item_len,
343
277
  )
344
- for layer_id in range(layers_per_pp_stage)
278
+ for layer_id in range(layers_current_pp_stage)
345
279
  ]
346
280
  else:
347
- num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
348
- dst_num_total_layers = num_kv_layers * self.pp_size
349
- src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
350
- src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
351
- layers_per_pp_stage = len(src_k_ptrs)
352
- dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
353
- dst_v_ptrs = dst_kv_ptrs[
354
- dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
355
- ]
281
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
282
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
283
+ )
356
284
  kv_item_len = self.kv_args.kv_item_lens[0]
357
285
  layers_params = [
358
286
  (
@@ -360,14 +288,14 @@ class MooncakeKVManager(BaseKVManager):
360
288
  dst_k_ptrs[layer_id],
361
289
  kv_item_len,
362
290
  )
363
- for layer_id in range(layers_per_pp_stage)
291
+ for layer_id in range(layers_current_pp_stage)
364
292
  ] + [
365
293
  (
366
294
  src_v_ptrs[layer_id],
367
295
  dst_v_ptrs[layer_id],
368
296
  kv_item_len,
369
297
  )
370
- for layer_id in range(layers_per_pp_stage)
298
+ for layer_id in range(layers_current_pp_stage)
371
299
  ]
372
300
  assert layers_params is not None
373
301
 
@@ -465,18 +393,9 @@ class MooncakeKVManager(BaseKVManager):
465
393
  num_heads_to_send = dst_heads_per_rank
466
394
  dst_head_start_offset = 0
467
395
 
468
- # pp is not supported on the decode side yet
469
- num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
470
- dst_num_total_layers = num_kv_layers * self.pp_size
471
- src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
472
- src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
473
- layers_per_pp_stage = len(src_k_ptrs)
474
- start_layer = self.pp_rank * layers_per_pp_stage
475
- end_layer = start_layer + layers_per_pp_stage
476
- dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
477
- dst_v_ptrs = dst_kv_ptrs[
478
- dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
479
- ]
396
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
397
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
398
+ )
480
399
 
481
400
  # Calculate precise byte offset and length for the sub-slice within the token
482
401
  src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
@@ -502,7 +421,7 @@ class MooncakeKVManager(BaseKVManager):
502
421
  dst_head_slice_offset,
503
422
  heads_bytes_per_token_to_send,
504
423
  )
505
- for layer_id in range(layers_per_pp_stage)
424
+ for layer_id in range(layers_current_pp_stage)
506
425
  ] + [
507
426
  (
508
427
  src_v_ptrs[layer_id],
@@ -513,7 +432,7 @@ class MooncakeKVManager(BaseKVManager):
513
432
  dst_head_slice_offset,
514
433
  heads_bytes_per_token_to_send,
515
434
  )
516
- for layer_id in range(layers_per_pp_stage)
435
+ for layer_id in range(layers_current_pp_stage)
517
436
  ]
518
437
 
519
438
  def process_layer_tp_aware(layer_params):
@@ -654,6 +573,26 @@ class MooncakeKVManager(BaseKVManager):
654
573
  ]
655
574
  )
656
575
 
576
+ def _handle_aux_data(self, msg: List[bytes]):
577
+ """Handle AUX_DATA messages received by the decode thread."""
578
+ room = int(msg[1].decode("ascii"))
579
+ buffer_index = int(msg[2].decode("ascii"))
580
+ aux_index = int(msg[3].decode("ascii"))
581
+ data_length = struct.unpack(">I", msg[4])[0]
582
+ data = msg[5]
583
+
584
+ if len(data) != data_length:
585
+ logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
586
+ return
587
+
588
+ AuxDataCodec.deserialize_data_to_buffer(
589
+ self.kv_args, buffer_index, aux_index, data
590
+ )
591
+
592
+ logger.debug(
593
+ f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
594
+ )
595
+
657
596
  def sync_status_to_decode_endpoint(
658
597
  self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
659
598
  ):
@@ -802,11 +741,7 @@ class MooncakeKVManager(BaseKVManager):
802
741
  f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
803
742
  )
804
743
 
805
- def _bind_server_socket(self):
806
- self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
807
-
808
744
  def start_prefill_thread(self):
809
- self.rank_port = get_free_port()
810
745
  self._bind_server_socket()
811
746
 
812
747
  def bootstrap_thread():
@@ -844,28 +779,7 @@ class MooncakeKVManager(BaseKVManager):
844
779
 
845
780
  threading.Thread(target=bootstrap_thread).start()
846
781
 
847
- def _handle_aux_data(self, msg: List[bytes]):
848
- """Handle AUX_DATA messages received by the decode thread."""
849
- room = int(msg[1].decode("ascii"))
850
- buffer_index = int(msg[2].decode("ascii"))
851
- aux_index = int(msg[3].decode("ascii"))
852
- data_length = struct.unpack(">I", msg[4])[0]
853
- data = msg[5]
854
-
855
- if len(data) != data_length:
856
- logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
857
- return
858
-
859
- AuxDataCodec.deserialize_data_to_buffer(
860
- self.kv_args, buffer_index, aux_index, data
861
- )
862
-
863
- logger.debug(
864
- f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
865
- )
866
-
867
782
  def start_decode_thread(self):
868
- self.rank_port = get_free_port()
869
783
  self._bind_server_socket()
870
784
 
871
785
  def decode_thread():
@@ -1020,51 +934,6 @@ class MooncakeKVManager(BaseKVManager):
1020
934
  def get_session_id(self):
1021
935
  return self.engine.get_session_id()
1022
936
 
1023
- def _register_to_bootstrap(self):
1024
- """Register KVSender to bootstrap server via HTTP POST."""
1025
- if self.dist_init_addr:
1026
- # multi node case: bootstrap server's host is dist_init_addr
1027
- if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
1028
- if self.dist_init_addr.endswith("]"):
1029
- host = self.dist_init_addr
1030
- else:
1031
- host, _ = self.dist_init_addr.rsplit(":", 1)
1032
- else:
1033
- host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
1034
- else:
1035
- # single node case: bootstrap server's host is same as http server's host
1036
- host = self.bootstrap_host
1037
- host = maybe_wrap_ipv6_address(host)
1038
-
1039
- bootstrap_server_url = f"{host}:{self.bootstrap_port}"
1040
- url = f"http://{bootstrap_server_url}/route"
1041
- payload = {
1042
- "role": "Prefill",
1043
- "attn_tp_size": self.attn_tp_size,
1044
- "attn_tp_rank": self.attn_tp_rank,
1045
- "attn_dp_size": self.attn_dp_size,
1046
- "attn_dp_rank": self.attn_dp_rank,
1047
- "pp_size": self.pp_size,
1048
- "pp_rank": self.pp_rank,
1049
- "system_dp_size": self.system_dp_size,
1050
- "system_dp_rank": self.system_dp_rank,
1051
- "rank_ip": self.local_ip,
1052
- "rank_port": self.rank_port,
1053
- }
1054
-
1055
- try:
1056
- response = requests.put(url, json=payload, timeout=5)
1057
- if response.status_code == 200:
1058
- logger.debug("Prefill successfully registered to bootstrap server.")
1059
- else:
1060
- logger.error(
1061
- f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
1062
- )
1063
- except Exception as e:
1064
- logger.error(
1065
- f"Prefill instance failed to register to bootstrap server: {e}"
1066
- )
1067
-
1068
937
  def _handle_node_failure(self, failed_bootstrap_addr):
1069
938
  with self.connection_lock:
1070
939
  keys_to_remove = [
@@ -1103,7 +972,7 @@ class MooncakeKVManager(BaseKVManager):
1103
972
  )
1104
973
 
1105
974
 
1106
- class MooncakeKVSender(BaseKVSender):
975
+ class MooncakeKVSender(CommonKVSender):
1107
976
 
1108
977
  def __init__(
1109
978
  self,
@@ -1113,19 +982,9 @@ class MooncakeKVSender(BaseKVSender):
1113
982
  dest_tp_ranks: List[int],
1114
983
  pp_rank: int,
1115
984
  ):
1116
- self.kv_mgr = mgr
1117
- self.bootstrap_room = bootstrap_room
1118
- self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
1119
- self.aux_index = None
1120
- self.bootstrap_server_url = bootstrap_addr
985
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
1121
986
  self.conclude_state = None
1122
987
  self.init_time = time.time()
1123
- # inner state
1124
- self.curr_idx = 0
1125
-
1126
- def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
1127
- self.num_kv_indices = num_kv_indices
1128
- self.aux_index = aux_index
1129
988
 
1130
989
  def send(
1131
990
  self,
@@ -1203,7 +1062,7 @@ class MooncakeKVSender(BaseKVSender):
1203
1062
  self.conclude_state = KVPoll.Failed
1204
1063
 
1205
1064
 
1206
- class MooncakeKVReceiver(BaseKVReceiver):
1065
+ class MooncakeKVReceiver(CommonKVReceiver):
1207
1066
  _ctx = zmq.Context()
1208
1067
  _socket_cache = {}
1209
1068
  _socket_locks = {}
@@ -1216,166 +1075,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
1216
1075
  bootstrap_room: Optional[int] = None,
1217
1076
  prefill_dp_rank: Optional[int] = None,
1218
1077
  ):
1219
- self.bootstrap_room = bootstrap_room
1220
- self.bootstrap_addr = bootstrap_addr
1221
- self.kv_mgr = mgr
1222
- self.session_id = self.kv_mgr.get_session_id()
1223
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
1078
+ self.session_id = mgr.get_session_id()
1224
1079
  self.conclude_state = None
1225
1080
  self.init_time = None
1081
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
1226
1082
 
1227
- if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
1228
- (
1229
- self.prefill_attn_tp_size,
1230
- self.prefill_dp_size,
1231
- self.prefill_pp_size,
1232
- ) = self._get_prefill_parallel_info_from_server()
1233
- if (
1234
- self.prefill_attn_tp_size is None
1235
- or self.prefill_dp_size is None
1236
- or self.prefill_pp_size is None
1237
- ):
1238
- self.kv_mgr.record_failure(
1239
- self.bootstrap_room,
1240
- f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
1241
- )
1242
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1243
- return
1244
- else:
1245
- logger.debug(
1246
- f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
1247
- )
1248
- self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
1249
- self.prefill_attn_tp_size
1250
- )
1251
- self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
1252
- self.prefill_dp_size
1253
- )
1254
- self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
1255
- self.prefill_pp_size
1256
- )
1257
- else:
1258
- self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
1259
- self.bootstrap_addr
1260
- ]
1261
- self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
1262
- self.bootstrap_addr
1263
- ]
1264
- self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
1265
- self.bootstrap_addr
1266
- ]
1267
-
1268
- # Currently, we don't allow prefill instance and decode instance to
1269
- # have different TP sizes per DP rank, except for models using MLA.
1270
- if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
1271
- self.target_tp_rank = (
1272
- self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
1273
- )
1274
- self.required_dst_info_num = 1
1275
- self.required_prefill_response_num = 1 * (
1276
- self.prefill_pp_size // self.kv_mgr.pp_size
1277
- )
1278
- self.target_tp_ranks = [self.target_tp_rank]
1279
- elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
1280
- if not self.kv_mgr.is_mla_backend:
1281
- logger.warning_once(
1282
- "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
1283
- )
1284
- self.target_tp_rank = (
1285
- self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
1286
- ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
1287
- self.required_dst_info_num = (
1288
- self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
1289
- )
1290
- self.required_prefill_response_num = 1 * (
1291
- self.prefill_pp_size // self.kv_mgr.pp_size
1292
- )
1293
- self.target_tp_ranks = [self.target_tp_rank]
1294
- else:
1295
- if not self.kv_mgr.is_mla_backend:
1296
- logger.warning_once(
1297
- "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
1298
- )
1299
- # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
1300
- self.target_tp_ranks = [
1301
- rank
1302
- for rank in range(
1303
- (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
1304
- * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
1305
- (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
1306
- * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
1307
- )
1308
- ]
1309
-
1310
- # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
1311
- # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
1312
- # or the KVPoll will never be set correctly
1313
- self.target_tp_rank = self.target_tp_ranks[0]
1314
- self.required_dst_info_num = 1
1315
- if self.kv_mgr.is_mla_backend:
1316
- self.required_prefill_response_num = (
1317
- self.prefill_pp_size // self.kv_mgr.pp_size
1318
- )
1319
- else:
1320
- self.required_prefill_response_num = (
1321
- self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
1322
- ) * (self.prefill_pp_size // self.kv_mgr.pp_size)
1323
-
1324
- if prefill_dp_rank is not None:
1325
- logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
1326
- self.prefill_dp_rank = prefill_dp_rank
1327
- else:
1328
- self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
1329
-
1330
- # FIXME: alias here: target_dp_group -> prefill_dp_rank
1331
- self.target_dp_group = self.prefill_dp_rank
1332
-
1333
- self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
1334
- self.required_prefill_response_num
1335
- )
1336
- # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
1337
- bootstrap_key = (
1338
- f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
1339
- )
1340
-
1341
- if bootstrap_key not in self.kv_mgr.connection_pool:
1342
- bootstrap_infos = []
1343
- for target_tp_rank in self.target_tp_ranks:
1344
- for target_pp_rank in range(self.prefill_pp_size):
1345
- bootstrap_info = self._get_bootstrap_info_from_server(
1346
- target_tp_rank, self.target_dp_group, target_pp_rank
1347
- )
1348
- if bootstrap_info is not None:
1349
- if self.kv_mgr.is_mla_backend:
1350
- # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
1351
- bootstrap_info["is_dummy"] = not bool(
1352
- target_tp_rank == self.target_tp_rank
1353
- or self.target_tp_rank is None
1354
- )
1355
- else:
1356
- # For non-MLA: all target_tp_ranks are selected real ranks
1357
- bootstrap_info["is_dummy"] = False
1358
- logger.debug(
1359
- f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
1360
- )
1361
- bootstrap_infos.append(bootstrap_info)
1362
- else:
1363
- self.kv_mgr.record_failure(
1364
- self.bootstrap_room,
1365
- f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
1366
- )
1367
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
1368
- return
1369
-
1370
- self.bootstrap_infos = bootstrap_infos
1371
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
1372
-
1373
- # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
1374
- self._register_kv_args()
1375
- else:
1376
- self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
1377
-
1378
- assert len(self.bootstrap_infos) > 0
1379
1083
  self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
1380
1084
  self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
1381
1085
 
@@ -1398,29 +1102,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
1398
1102
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
1399
1103
  return None
1400
1104
 
1401
- def _get_prefill_parallel_info_from_server(
1402
- self,
1403
- ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
1404
- """Fetch the prefill parallel info from the bootstrap server."""
1405
- try:
1406
- url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
1407
- response = requests.get(url)
1408
- if response.status_code == 200:
1409
- prefill_parallel_info = response.json()
1410
- return (
1411
- int(prefill_parallel_info["prefill_attn_tp_size"]),
1412
- int(prefill_parallel_info["prefill_dp_size"]),
1413
- int(prefill_parallel_info["prefill_pp_size"]),
1414
- )
1415
- else:
1416
- logger.error(
1417
- f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
1418
- )
1419
- return None, None, None
1420
- except Exception as e:
1421
- logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
1422
- return None, None, None
1423
-
1424
1105
  def _register_kv_args(self):
1425
1106
  for bootstrap_info in self.bootstrap_infos:
1426
1107
  packed_kv_data_ptrs = b"".join(
@@ -1452,28 +1133,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
1452
1133
  ]
1453
1134
  )
1454
1135
 
1455
- @classmethod
1456
- def _connect(cls, endpoint: str, is_ipv6: bool = False):
1457
- with cls._global_lock:
1458
- if endpoint not in cls._socket_cache:
1459
- sock = cls._ctx.socket(zmq.PUSH)
1460
- if is_ipv6:
1461
- sock.setsockopt(zmq.IPV6, 1)
1462
- sock.connect(endpoint)
1463
- cls._socket_cache[endpoint] = sock
1464
- cls._socket_locks[endpoint] = threading.Lock()
1465
- return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
1466
-
1467
- @classmethod
1468
- def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
1469
- ip_address = bootstrap_info["rank_ip"]
1470
- port = bootstrap_info["rank_port"]
1471
- is_ipv6_address = is_valid_ipv6_address(ip_address)
1472
- sock, lock = cls._connect(
1473
- format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
1474
- )
1475
- return sock, lock
1476
-
1477
1136
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
1478
1137
  for bootstrap_info in self.bootstrap_infos:
1479
1138
  sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
@@ -1551,154 +1210,5 @@ class MooncakeKVReceiver(BaseKVReceiver):
1551
1210
  self.conclude_state = KVPoll.Failed
1552
1211
 
1553
1212
 
1554
- class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1555
- def __init__(self, host: str, port: int):
1556
- self.host = host
1557
- self.port = port
1558
- self.app = web.Application()
1559
- self.store = dict()
1560
- self.lock = asyncio.Lock()
1561
- self._setup_routes()
1562
- self.pp_size = None
1563
- self.attn_tp_size = None
1564
- self.dp_size = None
1565
- self.prefill_port_table: Dict[
1566
- int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
1567
- ] = {}
1568
-
1569
- # Start bootstrap server
1570
- self.thread = threading.Thread(target=self._run_server, daemon=True)
1571
- self.run()
1572
-
1573
- def run(self):
1574
- self.thread.start()
1575
-
1576
- def _setup_routes(self):
1577
- self.app.router.add_route("*", "/route", self._handle_route)
1578
- self.app.router.add_get("/health", self._handle_health_check)
1579
-
1580
- async def _handle_health_check(self, request):
1581
- return web.Response(text="OK", status=200)
1582
-
1583
- async def _handle_route(self, request: web.Request):
1584
- method = request.method
1585
- if method == "PUT":
1586
- return await self._handle_route_put(request)
1587
- elif method == "GET":
1588
- return await self._handle_route_get(request)
1589
- else:
1590
- return web.Response(
1591
- text="Method not allowed", status=405, content_type="application/json"
1592
- )
1593
-
1594
- async def _handle_route_put(self, request: web.Request):
1595
- data = await request.json()
1596
- role = data["role"]
1597
- attn_tp_size = data["attn_tp_size"]
1598
- attn_tp_rank = data["attn_tp_rank"]
1599
- attn_dp_size = data["attn_dp_size"]
1600
- attn_dp_rank = data["attn_dp_rank"]
1601
- pp_size = data["pp_size"]
1602
- pp_rank = data["pp_rank"]
1603
- system_dp_size = data["system_dp_size"]
1604
- system_dp_rank = data["system_dp_rank"]
1605
- rank_ip = data["rank_ip"]
1606
- rank_port = int(data["rank_port"])
1607
-
1608
- if self.attn_tp_size is None:
1609
- self.attn_tp_size = attn_tp_size
1610
-
1611
- if self.dp_size is None:
1612
- self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
1613
-
1614
- if self.pp_size is None:
1615
- self.pp_size = pp_size
1616
-
1617
- if role == "Prefill":
1618
- if system_dp_size == 1:
1619
- dp_group = attn_dp_rank
1620
- else:
1621
- dp_group = system_dp_rank
1622
-
1623
- # Add lock to make sure thread-safe
1624
- async with self.lock:
1625
- if dp_group not in self.prefill_port_table:
1626
- self.prefill_port_table[dp_group] = {}
1627
- if attn_tp_rank not in self.prefill_port_table[dp_group]:
1628
- self.prefill_port_table[dp_group][attn_tp_rank] = {}
1629
-
1630
- self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
1631
- "rank_ip": rank_ip,
1632
- "rank_port": rank_port,
1633
- }
1634
- logger.debug(
1635
- f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
1636
- )
1637
-
1638
- return web.Response(text="OK", status=200)
1639
-
1640
- async def _handle_route_get(self, request: web.Request):
1641
- engine_rank = request.query.get("engine_rank")
1642
- target_dp_group = request.query.get("target_dp_group")
1643
- target_pp_rank = request.query.get("target_pp_rank")
1644
- if not engine_rank or not target_dp_group or not target_pp_rank:
1645
- return web.Response(text="Missing inputs for bootstrap server.", status=400)
1646
-
1647
- # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
1648
- if (
1649
- int(engine_rank) == -1
1650
- and int(target_dp_group) == -1
1651
- and int(target_pp_rank) == -1
1652
- ):
1653
- prefill_parallel_info = {
1654
- "prefill_attn_tp_size": self.attn_tp_size,
1655
- "prefill_dp_size": self.dp_size,
1656
- "prefill_pp_size": self.pp_size,
1657
- }
1658
- return web.json_response(prefill_parallel_info, status=200)
1659
-
1660
- # Find corresponding prefill info
1661
- async with self.lock:
1662
- bootstrap_info = self.prefill_port_table[int(target_dp_group)][
1663
- int(engine_rank)
1664
- ][int(target_pp_rank)]
1665
-
1666
- if bootstrap_info is not None:
1667
- return web.json_response(bootstrap_info, status=200)
1668
- else:
1669
- return web.Response(text="Bootstrap info not Found", status=404)
1670
-
1671
- def _run_server(self):
1672
- try:
1673
- # Event Loop
1674
- self._loop = asyncio.new_event_loop()
1675
- asyncio.set_event_loop(self._loop)
1676
-
1677
- access_log = None
1678
- if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
1679
- access_log = self.app.logger
1680
-
1681
- self._runner = web.AppRunner(self.app, access_log=access_log)
1682
- self._loop.run_until_complete(self._runner.setup())
1683
-
1684
- site = web.TCPSite(self._runner, host=self.host, port=self.port)
1685
- self._loop.run_until_complete(site.start())
1686
- self._loop.run_forever()
1687
- except Exception as e:
1688
- logger.error(f"Server error: {str(e)}")
1689
- finally:
1690
- # Cleanup
1691
- self._loop.run_until_complete(self._runner.cleanup())
1692
- self._loop.close()
1693
-
1694
- def close(self):
1695
- """Shutdown"""
1696
- if self._loop is not None and self._loop.is_running():
1697
- self._loop.call_soon_threadsafe(self._loop.stop)
1698
- logger.info("Stopping server loop...")
1699
-
1700
- if self.thread.is_alive():
1701
- self.thread.join(timeout=2)
1702
- logger.info("Server thread stopped")
1703
-
1704
- def poll(self) -> KVPoll: ...
1213
+ class MooncakeKVBootstrapServer(CommonKVBootstrapServer):
1214
+ pass