sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff compares two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py

@@ -37,25 +37,16 @@ logger = logging.getLogger(__name__)
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    src_groups = []
-    dst_groups = []
-    current_src = [src_indices[0]]
-    current_dst = [dst_indices[0]]
-
-    for i in range(1, len(src_indices)):
-        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
-        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
-        if src_contiguous and dst_contiguous:
-            current_src.append(src_indices[i])
-            current_dst.append(dst_indices[i])
-        else:
-            src_groups.append(current_src)
-            dst_groups.append(current_dst)
-            current_src = [src_indices[i]]
-            current_dst = [dst_indices[i]]
+    """Vectorised NumPy implementation."""
+    if src_indices.size == 0:
+        return [], []
+
+    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
+    src_groups = np.split(src_indices, brk)
+    dst_groups = np.split(dst_indices, brk)
 
-    src_groups.append(current_src)
-    dst_groups.append(current_dst)
+    src_groups = [g.tolist() for g in src_groups]
+    dst_groups = [g.tolist() for g in dst_groups]
 
     return src_groups, dst_groups
 
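Note: the rewrite above finds every position where either index sequence stops being consecutive and splits both arrays there in one shot. A minimal standalone sketch of the same break-point trick, using made-up indices:

```python
import numpy as np

# Illustrative inputs, not taken from the package.
src = np.array([3, 4, 5, 9, 10], dtype=np.int64)
dst = np.array([0, 1, 2, 3, 4], dtype=np.int64)

# A new group starts wherever either sequence breaks contiguity.
brk = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1
print([g.tolist() for g in np.split(src, brk)])  # [[3, 4, 5], [9, 10]]
print([g.tolist() for g in np.split(dst, brk)])  # [[0, 1, 2], [3, 4]]
```

Even though dst is contiguous throughout, it is still split at the third position because src breaks there: both sides must stay contiguous within a group.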
@@ -77,16 +68,28 @@ class TransferInfo:
     mooncake_session_id: str
     dst_kv_indices: npt.NDArray[np.int64]
     dst_aux_index: int
+    required_dst_info_num: int
+    is_dummy: bool
 
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
+        if msg[4] == b"" and msg[5] == b"":
+            is_dummy = True
+            dst_kv_indices = np.array([], dtype=np.int64)
+            dst_aux_index = None
+        else:
+            dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
+            dst_aux_index = int(msg[5].decode("ascii"))
+            is_dummy = False
         return cls(
             room=int(msg[0].decode("ascii")),
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
             mooncake_session_id=msg[3].decode("ascii"),
-            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
-            dst_aux_index=int(msg[5].decode("ascii")),
+            dst_kv_indices=dst_kv_indices,
+            dst_aux_index=dst_aux_index,
+            required_dst_info_num=int(msg[6].decode("ascii")),
+            is_dummy=is_dummy,
         )
 
 
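Note: dummy transfers are now signalled in-band as empty msg[4]/msg[5] frames. A sketch of the seven-frame ZMQ message that from_zmq() parses; the builder function is hypothetical, but the frame order comes directly from the indices above:

```python
import numpy as np

def build_transfer_msg(room: int, ip: str, port: int, session_id: str,
                       kv_indices: np.ndarray, aux_index: int,
                       required_dst_info_num: int, is_dummy: bool) -> list:
    # Hypothetical sender-side counterpart of TransferInfo.from_zmq().
    return [
        str(room).encode("ascii"),                                # msg[0]: bootstrap room id
        ip.encode("ascii"),                                       # msg[1]: decode endpoint ip
        str(port).encode("ascii"),                                # msg[2]: decode rank port
        session_id.encode("ascii"),                               # msg[3]: mooncake session id
        kv_indices.tobytes() if not is_dummy else b"",            # msg[4]: empty marks a dummy request
        str(aux_index).encode("ascii") if not is_dummy else b"",  # msg[5]: empty marks a dummy request
        str(required_dst_info_num).encode("ascii"),               # msg[6]: new field
    ]
```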
@@ -117,6 +120,7 @@ class MooncakeKVManager(BaseKVManager):
         args: KVArgs,
         disaggregation_mode: DisaggregationMode,
         server_args: ServerArgs,
+        is_mla_backend: Optional[bool] = False,
     ):
         self.kv_args = args
         self.engine = MooncakeTransferEngine(
@@ -124,6 +128,7 @@ class MooncakeKVManager(BaseKVManager):
             gpu_id=self.kv_args.gpu_id,
             ib_device=self.kv_args.ib_device,
         )
+        self.is_mla_backend = is_mla_backend
         self.disaggregation_mode = disaggregation_mode
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
@@ -141,7 +146,7 @@ class MooncakeKVManager(BaseKVManager):
         self.register_buffer_to_engine()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.transfer_queue = queue.Queue()
-            self.transfer_infos: Dict[int, TransferInfo] = {}
+            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
             self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self.start_prefill_thread()
             self._register_to_bootstrap()
@@ -154,6 +159,7 @@ class MooncakeKVManager(BaseKVManager):
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
+            self.prefill_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
         else:
             raise ValueError(
@@ -227,7 +233,7 @@ class MooncakeKVManager(BaseKVManager):
                 status = future.result()
                 if status != 0:
                     # Immediate shutdown on first error (existing tasks will finish)
-                    executor.shutdown(wait=False)
+                    self.executor.shutdown(wait=False)
                     for f in futures:
                         f.cancel()
                     return status
@@ -259,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
         self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
             [
                 str(room).encode("ascii"),
-                str(self.request_status[room]).encode("ascii"),
+                str(self.check_status(room)).encode("ascii"),
             ]
         )
 
@@ -273,8 +279,8 @@ class MooncakeKVManager(BaseKVManager):
             while True:
                 waiting_req_bytes = self.server_socket.recv_multipart()
                 room = waiting_req_bytes[0].decode("ascii")
+                mooncake_session_id = waiting_req_bytes[3].decode("ascii")
                 if room == "None":
-                    mooncake_session_id = waiting_req_bytes[3].decode("ascii")
                     self.decode_kv_args_table[mooncake_session_id] = (
                         KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
                     )
@@ -282,53 +288,84 @@ class MooncakeKVManager(BaseKVManager):
                         f"Register KVArgs from {mooncake_session_id} successfully"
                     )
                     continue
-                room = int(room)
-                self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
-
-                # NOTE: after bootstrapping we can mark the req as waiting for input
-                self.update_status(room, KVPoll.WaitingForInput)
+                else:
+                    required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
+                    room = int(room)
+                    if room not in self.transfer_infos:
+                        self.transfer_infos[room] = {}
+
+                    self.transfer_infos[room][mooncake_session_id] = (
+                        TransferInfo.from_zmq(waiting_req_bytes)
+                    )
+                    # NOTE: after bootstrapping we can mark the req as waiting for input
+                    if len(self.transfer_infos[room]) == required_dst_info_num:
+                        self.update_status(room, KVPoll.WaitingForInput)
 
         def transfer_thread():
             # TODO: Shall we use KVPoll.Transferring state?
             while True:
                 try:
                     kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
-                    req = self.transfer_infos[kv_chunk.room]
-                    chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
-                    assert len(chunked_dst_kv_indice) == len(
-                        kv_chunk.prefill_kv_indices
-                    ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
-
-                    ret = self.send_kvcache(
-                        req.mooncake_session_id,
-                        kv_chunk.prefill_kv_indices,
-                        self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
-                        chunked_dst_kv_indice,
-                    )
-                    if ret != 0:
-                        self.update_status(kv_chunk.room, KVPoll.Failed)
-                        self.sync_status_to_decode_endpoint(
-                            req.endpoint, req.dst_port, req.room
-                        )
-                        continue
-
-                    if kv_chunk.is_last:
-                        # Only the last chunk we need to send the aux data
-                        ret = self.send_aux(
-                            req.mooncake_session_id,
-                            kv_chunk.prefill_aux_index,
-                            self.decode_kv_args_table[
-                                req.mooncake_session_id
-                            ].dst_aux_ptrs,
-                            req.dst_aux_index,
-                        )
-                        self.update_status(
-                            kv_chunk.room,
-                            KVPoll.Success if ret == 0 else KVPoll.Failed,
-                        )
-                        self.sync_status_to_decode_endpoint(
-                            req.endpoint, req.dst_port, req.room
-                        )
+                    reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
+                    polls = []
+                    dst_ranks_infos = []
+                    for req in reqs_to_be_processed:
+                        if not req.is_dummy:
+                            chunked_dst_kv_indice = req.dst_kv_indices[
+                                kv_chunk.index_slice
+                            ]
+                            assert len(chunked_dst_kv_indice) == len(
+                                kv_chunk.prefill_kv_indices
+                            ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
+
+                            ret = self.send_kvcache(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_kv_indices,
+                                self.decode_kv_args_table[
+                                    req.mooncake_session_id
+                                ].dst_kv_ptrs,
+                                chunked_dst_kv_indice,
+                            )
+                            if ret != 0:
+                                self.update_status(kv_chunk.room, KVPoll.Failed)
+                                self.sync_status_to_decode_endpoint(
+                                    req.endpoint, req.dst_port, req.room
+                                )
+                                continue
+
+                            if kv_chunk.is_last:
+                                # Only the last chunk we need to send the aux data
+                                ret = self.send_aux(
+                                    req.mooncake_session_id,
+                                    kv_chunk.prefill_aux_index,
+                                    self.decode_kv_args_table[
+                                        req.mooncake_session_id
+                                    ].dst_aux_ptrs,
+                                    req.dst_aux_index,
+                                )
+                                polls.append(True if ret == 0 else False)
+                                dst_ranks_infos.append(
+                                    (req.endpoint, req.dst_port, req.room)
+                                )
+
+                                # Only sync status when all the dst ranks have received the kvcache
+                                if len(polls) == req.required_dst_info_num:
+                                    self.update_status(
+                                        req.room,
+                                        KVPoll.Success if all(polls) else KVPoll.Failed,
+                                    )
+                                    for endpoint, dst_port, room in dst_ranks_infos:
+                                        self.sync_status_to_decode_endpoint(
+                                            endpoint, dst_port, room
+                                        )
+                        else:
+                            # Dummy request means the decode instance is not used, so its status can be marked as success directly
+                            # Dummy request does not need to sync status to decode endpoint
+                            if kv_chunk.is_last:
+                                self.update_status(req.room, KVPoll.Success)
+
+                    if self.check_status(kv_chunk.room) == KVPoll.Success:
+                        self.transfer_infos.pop(kv_chunk.room)
 
                 except queue.Empty:
                     continue
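Note: transfer_infos is now a two-level map (room -> session -> TransferInfo), and a room is acknowledged only once every required destination has reported in; dummy destinations are marked successful without any transfer. The aggregation rule, as a minimal sketch with illustrative values:

```python
# polls collects one boolean per non-dummy destination rank of a room.
polls = [True, True, False]
required_dst_info_num = 3

if len(polls) == required_dst_info_num:
    # KVPoll.Success only if every destination acked its chunks,
    # otherwise KVPoll.Failed; here one rank failed.
    room_succeeded = all(polls)  # False
```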
@@ -345,7 +382,7 @@ class MooncakeKVManager(BaseKVManager):
                 (bootstrap_room, status) = self.server_socket.recv_multipart()
                 status = int(status.decode("ascii"))
                 bootstrap_room = int(bootstrap_room.decode("ascii"))
-                self.request_status[bootstrap_room] = status
+                self.update_status(bootstrap_room, status)
 
         threading.Thread(target=decode_thread).start()
 
@@ -369,11 +406,9 @@ class MooncakeKVManager(BaseKVManager):
                 prefill_aux_index=aux_index,
             )
         )
-        self.request_status[bootstrap_room] = KVPoll.WaitingForInput
+        self.update_status(bootstrap_room, KVPoll.WaitingForInput)
 
     def check_status(self, bootstrap_room: int):
-        # TOOD: do we really need the poll()?
-
         return self.request_status[bootstrap_room]
 
     def update_status(self, bootstrap_room: int, status: KVPoll):
@@ -478,54 +513,111 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
 
-        if not self.kv_mgr.enable_dp_attention:
-            # We assume dp_attention is enabled simultaneously for
-            # both prefill role and decode role. If the decode instance does
-            # not enable dp_attention, then dp_attention is not enabled on the
-            # prefill instance as well. Therefore, we should skip questioning
-            # the prefill dp size to reduce bootstrap overhead.
-            self.prefill_dp_size = 1
-        elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
-            self.prefill_dp_size, tp_size_per_dp_rank = (
+        if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
+            self.prefill_tp_size, self.prefill_dp_size = (
                 self._get_prefill_dp_size_from_server()
             )
-            # Currently, we don't allow prefill instance and decode instance to
-            # have different TP sizes per DP rank.
-            assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
-            if self.prefill_dp_size is None:
+            if self.prefill_tp_size is None or self.prefill_dp_size is None:
                 logger.error(
-                    f"Could not fetch prefill DP size for bootstrap_addr: {self.bootstrap_addr}"
+                    f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
                 )
             else:
+                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
+                    self.prefill_tp_size
+                )
                 self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                     self.prefill_dp_size
                 )
         else:
+            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
+                self.bootstrap_addr
+            ]
             self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                 self.bootstrap_addr
             ]
 
-        # NOTE: key distinguished by bootstrap_addr and target_dp_group
+        # Currently, we don't allow prefill instance and decode instance to
+        # have different TP sizes per DP rank, except for models using MLA.
+        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
+        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
+        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
+            self.target_tp_rank = (
+                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+            )
+            self.required_dst_info_num = 1
+            self.target_tp_ranks = [self.target_tp_rank]
+        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
+            assert (
+                self.kv_mgr.is_mla_backend
+            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
+            self.target_tp_rank = (
+                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
+            self.required_dst_info_num = (
+                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
+            )
+            self.target_tp_ranks = [self.target_tp_rank]
+        else:
+            assert (
+                self.kv_mgr.is_mla_backend
+            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
+
+            # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
+            self.target_tp_ranks = [
+                rank
+                for rank in range(
+                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
+                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
+                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                )
+            ]
+
+            # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
+            # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
+            # or the KVPoll will never be set correctly
+            self.target_tp_rank = self.target_tp_ranks[0]
+            self.required_dst_info_num = 1
+
         self.target_dp_group = bootstrap_room % self.prefill_dp_size
-        bootstrap_key = f"{self.bootstrap_addr}_{self.target_dp_group}"
+
+        # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
+        bootstrap_key = (
+            f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
+        )
 
         if bootstrap_key not in self.kv_mgr.connection_pool:
-            self.bootstrap_info = self._get_bootstrap_info_from_server(
-                self.kv_mgr.kv_args.engine_rank,
-                self.target_dp_group,
-            )
-            if self.bootstrap_info is None:
+            bootstrap_infos = []
+            for target_tp_rank in self.target_tp_ranks:
+                bootstrap_info = self._get_bootstrap_info_from_server(
+                    target_tp_rank,
+                    self.target_dp_group,
+                )
+                if bootstrap_info is not None:
+                    # NOTE: only support MLA for now: select one prefill rank as real rank
+                    bootstrap_info["is_dummy"] = not bool(
+                        target_tp_rank == self.target_tp_rank
+                        or self.target_tp_rank is None
+                    )
+                    bootstrap_infos.append(bootstrap_info)
+                else:
+                    logger.error(
+                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
+                    )
+            self.bootstrap_infos = bootstrap_infos
+
+            if len(self.bootstrap_infos) == 0:
                 logger.error(
                     f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                 )
             else:
-                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
+                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
                 # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
                 self._register_kv_args()
         else:
-            self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
+            self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
 
-        assert self.bootstrap_info is not None
+        assert len(self.bootstrap_infos) > 0
         self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
 
     def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
@@ -552,8 +644,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
        response = requests.get(url)
        if response.status_code == 200:
            prefill_parallel_info = response.json()
-            return int(prefill_parallel_info["prefill_dp_size"]), int(
-                prefill_parallel_info["tp_size_per_dp_rank"]
+            return int(prefill_parallel_info["prefill_tp_size"]), int(
+                prefill_parallel_info["prefill_dp_size"]
             )
        else:
            logger.error(
@@ -565,29 +657,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
         return None
 
     def _register_kv_args(self):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
-
-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
-        sock, lock = self._connect("tcp://" + self.prefill_server_url)
-        with lock:
-            sock.send_multipart(
-                [
-                    "None".encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
-                    str(self.kv_mgr.rank_port).encode("ascii"),
-                    self.session_id.encode("ascii"),
-                    packed_kv_data_ptrs,
-                    packed_aux_data_ptrs,
-                ]
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
+            )
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )
 
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        "None".encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.session_id.encode("ascii"),
+                        packed_kv_data_ptrs,
+                        packed_aux_data_ptrs,
+                    ]
+                )
+
     @classmethod
     def _connect(cls, endpoint: str):
         with cls._global_lock:
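Note: the registration frames pack each buffer pointer as a native unsigned 64-bit integer ("Q"). Round-tripping the packing, with made-up pointer values:

```python
import struct

ptrs = [0x7F0000010000, 0x7F0000020000]  # made-up device pointers
packed = b"".join(struct.pack("Q", p) for p in ptrs)

# The receiving side can recover the pointer list with iter_unpack.
unpacked = [p for (p,) in struct.iter_unpack("Q", packed)]
assert unpacked == ptrs
```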
@@ -599,25 +692,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
             return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
 
     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
-        logger.debug(
-            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
-        )
-
-        sock, lock = self._connect("tcp://" + self.prefill_server_url)
-        with lock:
-            sock.send_multipart(
-                [
-                    str(self.bootstrap_room).encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
-                    str(self.kv_mgr.rank_port).encode("ascii"),
-                    self.session_id.encode("ascii"),
-                    kv_indices.tobytes(),
-                    str(aux_index).encode("ascii"),
-                ]
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
             )
+            logger.debug(
+                f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
+            )
+            is_dummy = bootstrap_info["is_dummy"]
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        str(self.bootstrap_room).encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.session_id.encode("ascii"),
+                        kv_indices.tobytes() if not is_dummy else b"",
+                        str(aux_index).encode("ascii") if not is_dummy else b"",
+                        str(self.required_dst_info_num).encode("ascii"),
+                    ]
+                )
 
     def poll(self) -> KVPoll:
         return self.kv_mgr.check_status(self.bootstrap_room)
@@ -633,6 +729,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
+        self.tp_size = None
         self.dp_size = None
         self.tp_size_per_dp_rank = None
         self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
@@ -667,6 +764,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         rank_port = int(data["rank_port"])
         engine_rank = int(data["engine_rank"])
 
+        if self.tp_size is None:
+            self.tp_size = tp_size
+
         if self.dp_size is None:
             self.dp_size = dp_size
 
@@ -702,17 +802,15 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
         if int(engine_rank) == -1 and int(target_dp_group) == -1:
             prefill_parallel_info = {
+                "prefill_tp_size": self.tp_size,
                 "prefill_dp_size": self.dp_size,
-                "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
             }
             return web.json_response(prefill_parallel_info, status=200)
 
         # Find corresponding prefill info
-        tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
-
         async with self.lock:
             bootstrap_info = self.prefill_port_table[int(target_dp_group)][
-                tp_rank_in_dp_group
+                int(engine_rank)
             ]
 
         if bootstrap_info is not None:
|