sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -68,16 +68,28 @@ class TransferInfo:
|
|
68
68
|
mooncake_session_id: str
|
69
69
|
dst_kv_indices: npt.NDArray[np.int64]
|
70
70
|
dst_aux_index: int
|
71
|
+
required_dst_info_num: int
|
72
|
+
is_dummy: bool
|
71
73
|
|
72
74
|
@classmethod
|
73
75
|
def from_zmq(cls, msg: List[bytes]):
|
76
|
+
if msg[4] == b"" and msg[5] == b"":
|
77
|
+
is_dummy = True
|
78
|
+
dst_kv_indices = np.array([], dtype=np.int64)
|
79
|
+
dst_aux_index = None
|
80
|
+
else:
|
81
|
+
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
|
82
|
+
dst_aux_index = int(msg[5].decode("ascii"))
|
83
|
+
is_dummy = False
|
74
84
|
return cls(
|
75
85
|
room=int(msg[0].decode("ascii")),
|
76
86
|
endpoint=msg[1].decode("ascii"),
|
77
87
|
dst_port=int(msg[2].decode("ascii")),
|
78
88
|
mooncake_session_id=msg[3].decode("ascii"),
|
79
|
-
dst_kv_indices=
|
80
|
-
dst_aux_index=
|
89
|
+
dst_kv_indices=dst_kv_indices,
|
90
|
+
dst_aux_index=dst_aux_index,
|
91
|
+
required_dst_info_num=int(msg[6].decode("ascii")),
|
92
|
+
is_dummy=is_dummy,
|
81
93
|
)
|
82
94
|
|
83
95
|
|
@@ -108,6 +120,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
108
120
|
args: KVArgs,
|
109
121
|
disaggregation_mode: DisaggregationMode,
|
110
122
|
server_args: ServerArgs,
|
123
|
+
is_mla_backend: Optional[bool] = False,
|
111
124
|
):
|
112
125
|
self.kv_args = args
|
113
126
|
self.engine = MooncakeTransferEngine(
|
@@ -115,6 +128,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
115
128
|
gpu_id=self.kv_args.gpu_id,
|
116
129
|
ib_device=self.kv_args.ib_device,
|
117
130
|
)
|
131
|
+
self.is_mla_backend = is_mla_backend
|
118
132
|
self.disaggregation_mode = disaggregation_mode
|
119
133
|
# for p/d multi node infer
|
120
134
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
@@ -132,7 +146,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
132
146
|
self.register_buffer_to_engine()
|
133
147
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
134
148
|
self.transfer_queue = queue.Queue()
|
135
|
-
self.transfer_infos: Dict[int, TransferInfo] = {}
|
149
|
+
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
136
150
|
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
137
151
|
self.start_prefill_thread()
|
138
152
|
self._register_to_bootstrap()
|
@@ -145,6 +159,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
145
159
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
146
160
|
self.start_decode_thread()
|
147
161
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
162
|
+
self.prefill_tp_size_table: Dict[str, int] = {}
|
148
163
|
self.prefill_dp_size_table: Dict[str, int] = {}
|
149
164
|
else:
|
150
165
|
raise ValueError(
|
@@ -218,7 +233,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
218
233
|
status = future.result()
|
219
234
|
if status != 0:
|
220
235
|
# Immediate shutdown on first error (existing tasks will finish)
|
221
|
-
executor.shutdown(wait=False)
|
236
|
+
self.executor.shutdown(wait=False)
|
222
237
|
for f in futures:
|
223
238
|
f.cancel()
|
224
239
|
return status
|
@@ -250,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
250
265
|
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
|
251
266
|
[
|
252
267
|
str(room).encode("ascii"),
|
253
|
-
str(self.
|
268
|
+
str(self.check_status(room)).encode("ascii"),
|
254
269
|
]
|
255
270
|
)
|
256
271
|
|
@@ -264,8 +279,8 @@ class MooncakeKVManager(BaseKVManager):
|
|
264
279
|
while True:
|
265
280
|
waiting_req_bytes = self.server_socket.recv_multipart()
|
266
281
|
room = waiting_req_bytes[0].decode("ascii")
|
282
|
+
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
|
267
283
|
if room == "None":
|
268
|
-
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
|
269
284
|
self.decode_kv_args_table[mooncake_session_id] = (
|
270
285
|
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
|
271
286
|
)
|
@@ -273,53 +288,84 @@ class MooncakeKVManager(BaseKVManager):
|
|
273
288
|
f"Register KVArgs from {mooncake_session_id} successfully"
|
274
289
|
)
|
275
290
|
continue
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
291
|
+
else:
|
292
|
+
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
|
293
|
+
room = int(room)
|
294
|
+
if room not in self.transfer_infos:
|
295
|
+
self.transfer_infos[room] = {}
|
296
|
+
|
297
|
+
self.transfer_infos[room][mooncake_session_id] = (
|
298
|
+
TransferInfo.from_zmq(waiting_req_bytes)
|
299
|
+
)
|
300
|
+
# NOTE: after bootstrapping we can mark the req as waiting for input
|
301
|
+
if len(self.transfer_infos[room]) == required_dst_info_num:
|
302
|
+
self.update_status(room, KVPoll.WaitingForInput)
|
281
303
|
|
282
304
|
def transfer_thread():
|
283
305
|
# TODO: Shall we use KVPoll.Transferring state?
|
284
306
|
while True:
|
285
307
|
try:
|
286
308
|
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
309
|
+
reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
|
310
|
+
polls = []
|
311
|
+
dst_ranks_infos = []
|
312
|
+
for req in reqs_to_be_processed:
|
313
|
+
if not req.is_dummy:
|
314
|
+
chunked_dst_kv_indice = req.dst_kv_indices[
|
315
|
+
kv_chunk.index_slice
|
316
|
+
]
|
317
|
+
assert len(chunked_dst_kv_indice) == len(
|
318
|
+
kv_chunk.prefill_kv_indices
|
319
|
+
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
|
320
|
+
|
321
|
+
ret = self.send_kvcache(
|
322
|
+
req.mooncake_session_id,
|
323
|
+
kv_chunk.prefill_kv_indices,
|
324
|
+
self.decode_kv_args_table[
|
325
|
+
req.mooncake_session_id
|
326
|
+
].dst_kv_ptrs,
|
327
|
+
chunked_dst_kv_indice,
|
328
|
+
)
|
329
|
+
if ret != 0:
|
330
|
+
self.update_status(kv_chunk.room, KVPoll.Failed)
|
331
|
+
self.sync_status_to_decode_endpoint(
|
332
|
+
req.endpoint, req.dst_port, req.room
|
333
|
+
)
|
334
|
+
continue
|
335
|
+
|
336
|
+
if kv_chunk.is_last:
|
337
|
+
# We only need to send the aux data with the last chunk
|
338
|
+
ret = self.send_aux(
|
339
|
+
req.mooncake_session_id,
|
340
|
+
kv_chunk.prefill_aux_index,
|
341
|
+
self.decode_kv_args_table[
|
342
|
+
req.mooncake_session_id
|
343
|
+
].dst_aux_ptrs,
|
344
|
+
req.dst_aux_index,
|
345
|
+
)
|
346
|
+
polls.append(True if ret == 0 else False)
|
347
|
+
dst_ranks_infos.append(
|
348
|
+
(req.endpoint, req.dst_port, req.room)
|
349
|
+
)
|
350
|
+
|
351
|
+
# Only sync status when all the dst ranks have received the kvcache
|
352
|
+
if len(polls) == req.required_dst_info_num:
|
353
|
+
self.update_status(
|
354
|
+
req.room,
|
355
|
+
KVPoll.Success if all(polls) else KVPoll.Failed,
|
356
|
+
)
|
357
|
+
for endpoint, dst_port, room in dst_ranks_infos:
|
358
|
+
self.sync_status_to_decode_endpoint(
|
359
|
+
endpoint, dst_port, room
|
360
|
+
)
|
361
|
+
else:
|
362
|
+
# Dummy request means the decode instance is not used, so its status can be marked as success directly
|
363
|
+
# Dummy request does not need to sync status to decode endpoint
|
364
|
+
if kv_chunk.is_last:
|
365
|
+
self.update_status(req.room, KVPoll.Success)
|
366
|
+
|
367
|
+
if self.check_status(kv_chunk.room) == KVPoll.Success:
|
368
|
+
self.transfer_infos.pop(kv_chunk.room)
|
323
369
|
|
324
370
|
except queue.Empty:
|
325
371
|
continue
|
@@ -336,7 +382,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
336
382
|
(bootstrap_room, status) = self.server_socket.recv_multipart()
|
337
383
|
status = int(status.decode("ascii"))
|
338
384
|
bootstrap_room = int(bootstrap_room.decode("ascii"))
|
339
|
-
self.
|
385
|
+
self.update_status(bootstrap_room, status)
|
340
386
|
|
341
387
|
threading.Thread(target=decode_thread).start()
|
342
388
|
|
@@ -360,11 +406,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
360
406
|
prefill_aux_index=aux_index,
|
361
407
|
)
|
362
408
|
)
|
363
|
-
self.
|
409
|
+
self.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
364
410
|
|
365
411
|
def check_status(self, bootstrap_room: int):
|
366
|
-
# TOOD: do we really need the poll()?
|
367
|
-
|
368
412
|
return self.request_status[bootstrap_room]
|
369
413
|
|
370
414
|
def update_status(self, bootstrap_room: int, status: KVPoll):
|
@@ -469,54 +513,111 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
469
513
|
self.session_id = self.kv_mgr.get_session_id()
|
470
514
|
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
|
471
515
|
|
472
|
-
if not self.kv_mgr.
|
473
|
-
|
474
|
-
# both prefill role and decode role. If the decode instance does
|
475
|
-
# not enable dp_attention, then dp_attention is not enabled on the
|
476
|
-
# prefill instance as well. Therefore, we should skip questioning
|
477
|
-
# the prefill dp size to reduce bootstrap overhead.
|
478
|
-
self.prefill_dp_size = 1
|
479
|
-
elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
480
|
-
self.prefill_dp_size, tp_size_per_dp_rank = (
|
516
|
+
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
517
|
+
self.prefill_tp_size, self.prefill_dp_size = (
|
481
518
|
self._get_prefill_dp_size_from_server()
|
482
519
|
)
|
483
|
-
|
484
|
-
# have different TP sizes per DP rank.
|
485
|
-
assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
|
486
|
-
if self.prefill_dp_size is None:
|
520
|
+
if self.prefill_tp_size is None or self.prefill_dp_size is None:
|
487
521
|
logger.error(
|
488
|
-
f"Could not fetch prefill
|
522
|
+
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
|
489
523
|
)
|
490
524
|
else:
|
525
|
+
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
|
526
|
+
self.prefill_tp_size
|
527
|
+
)
|
491
528
|
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
492
529
|
self.prefill_dp_size
|
493
530
|
)
|
494
531
|
else:
|
532
|
+
self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
|
533
|
+
self.bootstrap_addr
|
534
|
+
]
|
495
535
|
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
496
536
|
self.bootstrap_addr
|
497
537
|
]
|
498
538
|
|
499
|
-
#
|
539
|
+
# Currently, we don't allow prefill instance and decode instance to
|
540
|
+
# have different TP sizes per DP rank, except for models using MLA.
|
541
|
+
local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
|
542
|
+
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
|
543
|
+
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
|
544
|
+
self.target_tp_rank = (
|
545
|
+
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
546
|
+
)
|
547
|
+
self.required_dst_info_num = 1
|
548
|
+
self.target_tp_ranks = [self.target_tp_rank]
|
549
|
+
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
|
550
|
+
assert (
|
551
|
+
self.kv_mgr.is_mla_backend
|
552
|
+
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
553
|
+
self.target_tp_rank = (
|
554
|
+
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
555
|
+
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
|
556
|
+
self.required_dst_info_num = (
|
557
|
+
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
|
558
|
+
)
|
559
|
+
self.target_tp_ranks = [self.target_tp_rank]
|
560
|
+
else:
|
561
|
+
assert (
|
562
|
+
self.kv_mgr.is_mla_backend
|
563
|
+
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
564
|
+
|
565
|
+
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
|
566
|
+
self.target_tp_ranks = [
|
567
|
+
rank
|
568
|
+
for rank in range(
|
569
|
+
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
|
570
|
+
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
|
571
|
+
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
|
572
|
+
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
|
573
|
+
)
|
574
|
+
]
|
575
|
+
|
576
|
+
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
|
577
|
+
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
|
578
|
+
# or the KVPoll will never be set correctly
|
579
|
+
self.target_tp_rank = self.target_tp_ranks[0]
|
580
|
+
self.required_dst_info_num = 1
|
581
|
+
|
500
582
|
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
501
|
-
|
583
|
+
|
584
|
+
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
585
|
+
bootstrap_key = (
|
586
|
+
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
|
587
|
+
)
|
502
588
|
|
503
589
|
if bootstrap_key not in self.kv_mgr.connection_pool:
|
504
|
-
|
505
|
-
|
506
|
-
self.
|
507
|
-
|
508
|
-
|
590
|
+
bootstrap_infos = []
|
591
|
+
for target_tp_rank in self.target_tp_ranks:
|
592
|
+
bootstrap_info = self._get_bootstrap_info_from_server(
|
593
|
+
target_tp_rank,
|
594
|
+
self.target_dp_group,
|
595
|
+
)
|
596
|
+
if bootstrap_info is not None:
|
597
|
+
# NOTE: only support MLA for now: select one prefill rank as real rank
|
598
|
+
bootstrap_info["is_dummy"] = not bool(
|
599
|
+
target_tp_rank == self.target_tp_rank
|
600
|
+
or self.target_tp_rank is None
|
601
|
+
)
|
602
|
+
bootstrap_infos.append(bootstrap_info)
|
603
|
+
else:
|
604
|
+
logger.error(
|
605
|
+
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
|
606
|
+
)
|
607
|
+
self.bootstrap_infos = bootstrap_infos
|
608
|
+
|
609
|
+
if len(self.bootstrap_infos) == 0:
|
509
610
|
logger.error(
|
510
611
|
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
511
612
|
)
|
512
613
|
else:
|
513
|
-
self.kv_mgr.connection_pool[bootstrap_key] = self.
|
614
|
+
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
514
615
|
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
515
616
|
self._register_kv_args()
|
516
617
|
else:
|
517
|
-
self.
|
618
|
+
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
|
518
619
|
|
519
|
-
assert self.
|
620
|
+
assert len(self.bootstrap_infos) > 0
|
520
621
|
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
521
622
|
|
522
623
|
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
|
@@ -543,8 +644,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
543
644
|
response = requests.get(url)
|
544
645
|
if response.status_code == 200:
|
545
646
|
prefill_parallel_info = response.json()
|
546
|
-
return int(prefill_parallel_info["
|
547
|
-
prefill_parallel_info["
|
647
|
+
return int(prefill_parallel_info["prefill_tp_size"]), int(
|
648
|
+
prefill_parallel_info["prefill_dp_size"]
|
548
649
|
)
|
549
650
|
else:
|
550
651
|
logger.error(
|
@@ -556,29 +657,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
556
657
|
return None
|
557
658
|
|
558
659
|
def _register_kv_args(self):
|
559
|
-
self.
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
)
|
569
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
570
|
-
with lock:
|
571
|
-
sock.send_multipart(
|
572
|
-
[
|
573
|
-
"None".encode("ascii"),
|
574
|
-
get_local_ip_by_remote().encode("ascii"),
|
575
|
-
str(self.kv_mgr.rank_port).encode("ascii"),
|
576
|
-
self.session_id.encode("ascii"),
|
577
|
-
packed_kv_data_ptrs,
|
578
|
-
packed_aux_data_ptrs,
|
579
|
-
]
|
660
|
+
for bootstrap_info in self.bootstrap_infos:
|
661
|
+
self.prefill_server_url = (
|
662
|
+
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
663
|
+
)
|
664
|
+
packed_kv_data_ptrs = b"".join(
|
665
|
+
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
666
|
+
)
|
667
|
+
packed_aux_data_ptrs = b"".join(
|
668
|
+
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
580
669
|
)
|
581
670
|
|
671
|
+
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
672
|
+
with lock:
|
673
|
+
sock.send_multipart(
|
674
|
+
[
|
675
|
+
"None".encode("ascii"),
|
676
|
+
get_local_ip_by_remote().encode("ascii"),
|
677
|
+
str(self.kv_mgr.rank_port).encode("ascii"),
|
678
|
+
self.session_id.encode("ascii"),
|
679
|
+
packed_kv_data_ptrs,
|
680
|
+
packed_aux_data_ptrs,
|
681
|
+
]
|
682
|
+
)
|
683
|
+
|
582
684
|
@classmethod
|
583
685
|
def _connect(cls, endpoint: str):
|
584
686
|
with cls._global_lock:
|
@@ -590,25 +692,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
590
692
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
591
693
|
|
592
694
|
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
593
|
-
self.
|
594
|
-
|
595
|
-
|
596
|
-
logger.debug(
|
597
|
-
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
598
|
-
)
|
599
|
-
|
600
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
601
|
-
with lock:
|
602
|
-
sock.send_multipart(
|
603
|
-
[
|
604
|
-
str(self.bootstrap_room).encode("ascii"),
|
605
|
-
get_local_ip_by_remote().encode("ascii"),
|
606
|
-
str(self.kv_mgr.rank_port).encode("ascii"),
|
607
|
-
self.session_id.encode("ascii"),
|
608
|
-
kv_indices.tobytes(),
|
609
|
-
str(aux_index).encode("ascii"),
|
610
|
-
]
|
695
|
+
for bootstrap_info in self.bootstrap_infos:
|
696
|
+
self.prefill_server_url = (
|
697
|
+
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
611
698
|
)
|
699
|
+
logger.debug(
|
700
|
+
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
701
|
+
)
|
702
|
+
is_dummy = bootstrap_info["is_dummy"]
|
703
|
+
|
704
|
+
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
705
|
+
with lock:
|
706
|
+
sock.send_multipart(
|
707
|
+
[
|
708
|
+
str(self.bootstrap_room).encode("ascii"),
|
709
|
+
get_local_ip_by_remote().encode("ascii"),
|
710
|
+
str(self.kv_mgr.rank_port).encode("ascii"),
|
711
|
+
self.session_id.encode("ascii"),
|
712
|
+
kv_indices.tobytes() if not is_dummy else b"",
|
713
|
+
str(aux_index).encode("ascii") if not is_dummy else b"",
|
714
|
+
str(self.required_dst_info_num).encode("ascii"),
|
715
|
+
]
|
716
|
+
)
|
612
717
|
|
613
718
|
def poll(self) -> KVPoll:
|
614
719
|
return self.kv_mgr.check_status(self.bootstrap_room)
|
@@ -624,6 +729,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|
624
729
|
self.store = dict()
|
625
730
|
self.lock = asyncio.Lock()
|
626
731
|
self._setup_routes()
|
732
|
+
self.tp_size = None
|
627
733
|
self.dp_size = None
|
628
734
|
self.tp_size_per_dp_rank = None
|
629
735
|
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
|
@@ -658,6 +764,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|
658
764
|
rank_port = int(data["rank_port"])
|
659
765
|
engine_rank = int(data["engine_rank"])
|
660
766
|
|
767
|
+
if self.tp_size is None:
|
768
|
+
self.tp_size = tp_size
|
769
|
+
|
661
770
|
if self.dp_size is None:
|
662
771
|
self.dp_size = dp_size
|
663
772
|
|
@@ -693,17 +802,15 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|
693
802
|
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
694
803
|
if int(engine_rank) == -1 and int(target_dp_group) == -1:
|
695
804
|
prefill_parallel_info = {
|
805
|
+
"prefill_tp_size": self.tp_size,
|
696
806
|
"prefill_dp_size": self.dp_size,
|
697
|
-
"tp_size_per_dp_rank": self.tp_size_per_dp_rank,
|
698
807
|
}
|
699
808
|
return web.json_response(prefill_parallel_info, status=200)
|
700
809
|
|
701
810
|
# Find corresponding prefill info
|
702
|
-
tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
|
703
|
-
|
704
811
|
async with self.lock:
|
705
812
|
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
706
|
-
|
813
|
+
int(engine_rank)
|
707
814
|
]
|
708
815
|
|
709
816
|
if bootstrap_info is not None:
|
@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.utils import (
|
|
34
34
|
ReqToMetadataIdxAllocator,
|
35
35
|
TransferBackend,
|
36
36
|
get_kv_class,
|
37
|
+
is_mla_backend,
|
37
38
|
kv_to_page_indices,
|
38
39
|
kv_to_page_num,
|
39
40
|
poll_and_all_reduce,
|
@@ -69,6 +70,7 @@ class PrefillBootstrapQueue:
|
|
69
70
|
scheduler: Scheduler,
|
70
71
|
):
|
71
72
|
self.token_to_kv_pool = token_to_kv_pool
|
73
|
+
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
|
72
74
|
self.aux_dtype = aux_dtype
|
73
75
|
|
74
76
|
self.metadata_buffers = metadata_buffers
|
@@ -112,7 +114,10 @@ class PrefillBootstrapQueue:
|
|
112
114
|
kv_args.gpu_id = self.scheduler.gpu_id
|
113
115
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
114
116
|
kv_manager = kv_manager_class(
|
115
|
-
kv_args,
|
117
|
+
kv_args,
|
118
|
+
DisaggregationMode.PREFILL,
|
119
|
+
self.scheduler.server_args,
|
120
|
+
self.is_mla_backend,
|
116
121
|
)
|
117
122
|
return kv_manager
|
118
123
|
|
@@ -277,19 +282,17 @@ class SchedulerDisaggregationPrefillMixin:
|
|
277
282
|
next_token_ids,
|
278
283
|
extend_input_len_per_req,
|
279
284
|
extend_logprob_start_len_per_req,
|
280
|
-
bid,
|
281
285
|
) = (
|
282
286
|
result.logits_output,
|
283
287
|
result.next_token_ids,
|
284
288
|
result.extend_input_len_per_req,
|
285
289
|
result.extend_logprob_start_len_per_req,
|
286
|
-
result.bid,
|
287
290
|
)
|
288
291
|
|
289
292
|
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
290
293
|
if self.enable_overlap:
|
291
294
|
# wait
|
292
|
-
_, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
|
295
|
+
_, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
|
293
296
|
else:
|
294
297
|
next_token_ids = result.next_token_ids.tolist()
|
295
298
|
|
@@ -112,7 +112,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
|
112
112
|
|
113
113
|
|
114
114
|
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
115
|
-
# 1. The page is
|
115
|
+
# 1. The page is guaranteed to be full except the last page.
|
116
116
|
# 2. page index = kv_index // page_size
|
117
117
|
# The return vector is kv_indices[::page_size] // page_size
|
118
118
|
if page_size == 1: # shortcut
|
@@ -162,3 +162,9 @@ def register_disaggregation_server(
|
|
162
162
|
warnings.warn(
|
163
163
|
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
164
164
|
)
|
165
|
+
|
166
|
+
|
167
|
+
def is_mla_backend(target_kv_pool) -> bool:
|
168
|
+
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
169
|
+
|
170
|
+
return isinstance(target_kv_pool, MLATokenToKVPool)
|
sglang/srt/entrypoints/engine.py
CHANGED
@@ -285,6 +285,21 @@ class Engine(EngineBase):
|
|
285
285
|
ret = loop.run_until_complete(generator.__anext__())
|
286
286
|
return ret
|
287
287
|
|
288
|
+
async def async_encode(
|
289
|
+
self,
|
290
|
+
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
291
|
+
image_data: Optional[Union[List[str], str]] = None,
|
292
|
+
) -> Dict:
|
293
|
+
"""
|
294
|
+
Asynchronous version of encode method.
|
295
|
+
|
296
|
+
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
297
|
+
Please refer to `EmbeddingReqInput` for the documentation.
|
298
|
+
"""
|
299
|
+
obj = EmbeddingReqInput(text=prompt, image_data=image_data)
|
300
|
+
generator = self.tokenizer_manager.generate_request(obj, None)
|
301
|
+
return await generator.__anext__()
|
302
|
+
|
288
303
|
def shutdown(self):
|
289
304
|
"""Shutdown the engine"""
|
290
305
|
kill_process_tree(os.getpid(), include_parent=False)
|
@@ -315,7 +330,7 @@ class Engine(EngineBase):
|
|
315
330
|
return {
|
316
331
|
**dataclasses.asdict(self.tokenizer_manager.server_args),
|
317
332
|
**self.scheduler_info,
|
318
|
-
|
333
|
+
"internal_states": internal_states,
|
319
334
|
"version": __version__,
|
320
335
|
}
|
321
336
|
|
@@ -471,7 +486,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|
471
486
|
if _is_cuda:
|
472
487
|
assert_pkg_version(
|
473
488
|
"sgl-kernel",
|
474
|
-
"0.1.
|
489
|
+
"0.1.2.post1",
|
475
490
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
476
491
|
)
|
477
492
|
|