sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +59 -11
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +24 -9
- sglang/srt/entrypoints/http_server.py +8 -2
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +29 -12
- sglang/srt/managers/scheduler.py +31 -20
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +11 -24
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +18 -8
- sglang/srt/server_args.py +15 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +2 -1
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +36 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py

@@ -137,7 +137,7 @@ class DecodePreallocQueue:
         kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
         kv_receiver = kv_receiver_class(
             mgr=self.kv_manager,
-            bootstrap_addr=f"{req.bootstrap_host}:{
+            bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
         )
         self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
@@ -307,7 +307,7 @@ class DecodeTransferQueue:
     def extend(self, req_conns) -> None:
         self.queue.extend(req_conns)

-    def pop_transferred(self) -> List[
+    def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
             return []

@@ -330,7 +330,7 @@ class DecodeTransferQueue:
                 assert len(decode_req.req.output_ids) == 0
                 assert decode_req.req.transferred_output_id is None
                 decode_req.req.transferred_output_id = output_id
-                transferred_reqs.append(decode_req
+                transferred_reqs.append(decode_req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
@@ -444,8 +444,17 @@ class ScheduleBatchDisaggregationDecodeMixin:

 class SchedulerDisaggregationDecodeMixin:

+    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
+        batch, _ = self.prepare_dp_attn_batch(batch)
+        result = None
+        if batch:
+            result = self.run_batch(batch)
+            if not delay_process:
+                self.process_batch_result(batch, result)
+        return batch, result
+
     @torch.no_grad()
-    def event_loop_normal_disagg_decode(self):
+    def event_loop_normal_disagg_decode(self: Scheduler):
         """A normal scheduler loop for decode worker in disaggregation mode."""

         while True:
@@ -456,14 +465,25 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch

+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )
+
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
                     self.stream_output(batch.reqs, False)
+                    if prepare_dp_attn_flag:
+                        self._prepare_idle_batch_and_run(None)
                 else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
+            elif prepare_dp_attn_flag:
+                batch, _ = self._prepare_idle_batch_and_run(None)

             if batch is None and (
                 len(self.disagg_decode_transfer_queue.queue)
@@ -477,10 +497,10 @@ class SchedulerDisaggregationDecodeMixin:
             self.last_batch = batch

     @torch.no_grad()
-    def event_loop_overlap_disagg_decode(self):
+    def event_loop_overlap_disagg_decode(self: Scheduler):
         result_queue = deque()
         self.last_batch: Optional[ScheduleBatch] = None
-        self.
+        self.last_batch_in_queue = False  # last batch is modifed in-place, so we need another variable to track if it's extend

         while True:
             recv_reqs = self.recv_requests()
@@ -489,20 +509,41 @@ class SchedulerDisaggregationDecodeMixin:
             self.process_decode_queue()
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch
-
+            last_batch_in_queue = False
+
+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )

             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
                     self.stream_output(batch.reqs, False)
-
+                    if prepare_dp_attn_flag:
+                        batch_, result = self._prepare_idle_batch_and_run(
+                            None, delay_process=True
+                        )
+                        if batch_:
+                            result_queue.append((batch_.copy(), result))
+                            last_batch_in_queue = True
                 else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True
+            elif prepare_dp_attn_flag:
+                batch, result = self._prepare_idle_batch_and_run(
+                    None, delay_process=True
+                )
+                if batch:
+                    result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True

             # Process the results of the previous batch but skip if the last batch is extend
-            if self.last_batch and
+            if self.last_batch and self.last_batch_in_queue:
                 tmp_batch, tmp_result = result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)

@@ -516,7 +557,7 @@ class SchedulerDisaggregationDecodeMixin:
                 self.new_token_ratio = self.init_new_token_ratio

             self.last_batch = batch
-            self.
+            self.last_batch_in_queue = last_batch_in_queue

     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
@@ -600,8 +641,15 @@ class SchedulerDisaggregationDecodeMixin:

     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
+
+        def _num_pre_alloc(req):
+            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
+
+        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
-        self.
+        self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
+
+        self.waiting_queue.extend([req.req for req in alloc_reqs])
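The `_num_pre_alloc` helper added above counts how many KV-cache slots a pre-allocated decode request currently holds: its prompt length plus `max(generated - 1, 0)`. A toy illustration of that arithmetic (stand-in objects, not sglang code):

```python
# Toy illustration (not sglang code) of the pre-allocation accounting added above.
from types import SimpleNamespace

def _num_pre_alloc(req):
    # slots reserved = prompt length + max(generated tokens - 1, 0), per the diff above
    return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)

# A request with a 10-token prompt and 3 generated tokens holds 12 slots;
# one that has produced nothing yet holds exactly its prompt length.
decode_req = SimpleNamespace(req=SimpleNamespace(origin_input_ids=list(range(10)), output_ids=[1, 2, 3]))
fresh_req = SimpleNamespace(req=SimpleNamespace(origin_input_ids=list(range(10)), output_ids=[]))
print(_num_pre_alloc(decode_req))  # 12
print(_num_pre_alloc(fresh_req))   # 10
```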
sglang/srt/disaggregation/mini_lb.py

@@ -6,6 +6,7 @@ import asyncio
 import random
 import urllib
 from itertools import chain
+from typing import List

 import aiohttp
 import orjson
@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse


+class PrefillConfig:
+    def __init__(self, url: str, bootstrap_port: int):
+        self.url = url
+        self.bootstrap_port = bootstrap_port
+
+
 class MiniLoadBalancer:
-    def __init__(self,
-        self.
+    def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
+        self.prefill_configs = prefill_configs
+        self.prefill_servers = [p.url for p in prefill_configs]
         self.decode_servers = decode_servers

     def select_pair(self):
-
+        prefill_config = random.choice(self.prefill_configs)
+        decode_server = random.choice(self.decode_servers)
+        return prefill_config.url, prefill_config.bootstrap_port, decode_server

     async def generate(
         self, modified_request, prefill_server, decode_server, endpoint
@@ -160,7 +170,7 @@ async def get_model_info():

 @app.post("/generate")
 async def handle_generate_request(request_data: dict):
-    prefill_server, decode_server = load_balancer.select_pair()
+    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

     # Parse and transform prefill_server for bootstrap data
     parsed_url = urllib.parse.urlparse(prefill_server)
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
         modified_request.update(
             {
                 "bootstrap_host": [hostname] * batch_size,
+                "bootstrap_port": [bootstrap_port] * batch_size,
                 "bootstrap_room": [
                     _generate_bootstrap_room() for _ in range(batch_size)
                 ],
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
         modified_request.update(
             {
                 "bootstrap_host": hostname,
+                "bootstrap_port": bootstrap_port,
                 "bootstrap_room": _generate_bootstrap_room(),
             }
         )
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):

 @app.post("/v1/chat/completions")
 async def handle_completion_request(request_data: dict):
-    prefill_server, decode_server = load_balancer.select_pair()
+    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

     # Parse and transform prefill_server for bootstrap data
     parsed_url = urllib.parse.urlparse(prefill_server)
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
     modified_request.update(
         {
             "bootstrap_host": hostname,
+            "bootstrap_port": bootstrap_port,
             "bootstrap_room": random.randint(0, 2**63 - 1),
         }
     )
@@ -255,9 +268,9 @@ async def get_models():
         raise HTTPException(status_code=500, detail=str(e))


-def run(
+def run(prefill_configs, decode_addrs, host, port):
     global load_balancer
-    load_balancer = MiniLoadBalancer(
+    load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
     uvicorn.run(app, host=host, port=port)


@@ -268,6 +281,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--prefill", required=True, help="Comma-separated URLs for prefill servers"
     )
+    parser.add_argument(
+        "--prefill-bootstrap-ports",
+        help="Comma-separated bootstrap ports for prefill servers",
+        default="8998",
+    )
     parser.add_argument(
         "--decode", required=True, help="Comma-separated URLs for decode servers"
     )
@@ -278,4 +296,23 @@ if __name__ == "__main__":
         "--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
     )
     args = parser.parse_args()
-
+
+    prefill_urls = args.prefill.split(",")
+    bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
+
+    if len(bootstrap_ports) == 1:
+        bootstrap_ports = bootstrap_ports * len(prefill_urls)
+    else:
+        if len(bootstrap_ports) != len(prefill_urls):
+            raise ValueError(
+                "Number of prefill URLs must match number of bootstrap ports"
+            )
+            exit(1)
+
+    prefill_configs = []
+    for url, port in zip(prefill_urls, bootstrap_ports):
+        prefill_configs.append(PrefillConfig(url, port))
+
+    decode_addrs = args.decode.split(",")
+
+    run(prefill_configs, decode_addrs, args.host, args.port)
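Taken together, the mini_lb changes attach a bootstrap port to every prefill server and copy it into each proxied request. A minimal usage sketch of the new API, assuming the module path from the file list above, that its server dependencies (fastapi, uvicorn, aiohttp, orjson) are installed, and placeholder URLs and ports:

```python
# Minimal usage sketch of the 0.4.6 mini_lb API shown above (URLs and ports are
# placeholders). select_pair() now returns (prefill_url, bootstrap_port, decode_url).
from sglang.srt.disaggregation.mini_lb import MiniLoadBalancer, PrefillConfig

prefill_configs = [
    PrefillConfig("http://prefill-0:30000", bootstrap_port=8998),
    PrefillConfig("http://prefill-1:30000", bootstrap_port=8999),
]
decode_servers = ["http://decode-0:30000"]

lb = MiniLoadBalancer(prefill_configs, decode_servers)
prefill_url, bootstrap_port, decode_url = lb.select_pair()
```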
sglang/srt/disaggregation/mooncake/conn.py

@@ -1,8 +1,10 @@
 from __future__ import annotations

 import asyncio
+import concurrent.futures
 import dataclasses
 import logging
+import os
 import queue
 import socket
 import struct
@@ -73,9 +75,7 @@ class TransferInfo:
     endpoint: str
     dst_port: int
     mooncake_session_id: str
-    dst_kv_ptrs: list[int]
     dst_kv_indices: npt.NDArray[np.int64]
-    dst_aux_ptrs: list[int]
     dst_aux_index: int

     @classmethod
@@ -85,10 +85,29 @@ class TransferInfo:
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
             mooncake_session_id=msg[3].decode("ascii"),
+            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
+            dst_aux_index=int(msg[5].decode("ascii")),
+        )
+
+
+@dataclasses.dataclass
+class KVArgsRegisterInfo:
+    room: str
+    endpoint: str
+    dst_port: int
+    mooncake_session_id: str
+    dst_kv_ptrs: list[int]
+    dst_aux_ptrs: list[int]
+
+    @classmethod
+    def from_zmq(cls, msg: List[bytes]):
+        return cls(
+            room=str(msg[0].decode("ascii")),
+            endpoint=msg[1].decode("ascii"),
+            dst_port=int(msg[2].decode("ascii")),
+            mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
-
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
-            dst_aux_index=int(msg[7].decode("ascii")),
+            dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
         )


@@ -109,6 +128,13 @@ class MooncakeKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
+        self.tp_size = server_args.tp_size
+        self.dp_size = server_args.dp_size
+        self.enable_dp_attention = server_args.enable_dp_attention
+        if not server_args.enable_dp_attention and server_args.dp_size != 1:
+            raise ValueError(
+                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
+            )
         self.request_status: Dict[int, KVPoll] = {}
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
@@ -116,11 +142,19 @@ class MooncakeKVManager(BaseKVManager):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.transfer_queue = queue.Queue()
             self.transfer_infos: Dict[int, TransferInfo] = {}
+            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self.start_prefill_thread()
             self._register_to_bootstrap()
+
+            # Determine the number of threads to use for kv sender
+            cpu_count = os.cpu_count()
+            self.executor = concurrent.futures.ThreadPoolExecutor(
+                min(cpu_count // 4, 16)
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
+            self.prefill_dp_size_table: Dict[str, int] = {}
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
@@ -150,28 +184,53 @@ class MooncakeKVManager(BaseKVManager):
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
     ):
-        #
+        # Group by indices
         prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
             prefill_kv_indices, dst_kv_indices
         )

         num_layers = len(self.kv_args.kv_data_ptrs)
-
-
-
-
+        layers_params = [
+            (
+                self.kv_args.kv_data_ptrs[layer_id],
+                dst_kv_ptrs[layer_id],
+                self.kv_args.kv_item_lens[layer_id],
+            )
+            for layer_id in range(num_layers)
+        ]

+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)

-                # TODO: make async later
                 status = self.engine.transfer_sync(
                     mooncake_session_id, src_addr, dst_addr, length
                 )
                 if status != 0:
                     return status
+            return 0
+
+        futures = [
+            self.executor.submit(
+                process_layer,
+                src_ptr,
+                dst_ptr,
+                item_len,
+            )
+            for (src_ptr, dst_ptr, item_len) in layers_params
+        ]
+
+        for future in concurrent.futures.as_completed(futures):
+            status = future.result()
+            if status != 0:
+                # Immediate shutdown on first error (existing tasks will finish)
+                executor.shutdown(wait=False)
+                for f in futures:
+                    f.cancel()
+                return status

         return 0

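The `send_kvcache` hunk above replaces the sequential per-layer copy loop with a thread pool sized `min(os.cpu_count() // 4, 16)`: one job per layer, returning the first non-zero transfer status. Below is a stripped-down sketch of that fan-out/fail-fast pattern with a stand-in transfer function; the `max(1, ...)` guard is my addition for hosts with fewer than four CPUs and is not in the package code:

```python
# Stripped-down sketch of the fan-out / fail-fast pattern introduced above.
# fake_transfer stands in for engine.transfer_sync(); everything here is illustrative.
import concurrent.futures
import os

executor = concurrent.futures.ThreadPoolExecutor(max(1, min(os.cpu_count() // 4, 16)))

def fake_transfer(src_ptr: int, dst_ptr: int, length: int) -> int:
    return 0  # pretend every copy succeeds

def send_all_layers(layers_params) -> int:
    futures = [
        executor.submit(fake_transfer, src, dst, length)
        for (src, dst, length) in layers_params
    ]
    for future in concurrent.futures.as_completed(futures):
        status = future.result()
        if status != 0:
            # Fail fast: stop taking new work and cancel anything not yet started.
            executor.shutdown(wait=False)
            for f in futures:
                f.cancel()
            return status
    return 0

print(send_all_layers([(0, 0, 4096) for _ in range(4)]))  # 0
```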
sglang/srt/disaggregation/mooncake/conn.py (continued)

@@ -215,6 +274,13 @@ class MooncakeKVManager(BaseKVManager):
                 waiting_req_bytes = self.server_socket.recv_multipart()
                 room = waiting_req_bytes[0].decode("ascii")
                 if room == "None":
+                    mooncake_session_id = waiting_req_bytes[3].decode("ascii")
+                    self.decode_kv_args_table[mooncake_session_id] = (
+                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
+                    )
+                    logger.debug(
+                        f"Register KVArgs from {mooncake_session_id} successfully"
+                    )
                     continue
                 room = int(room)
                 self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
@@ -236,7 +302,7 @@ class MooncakeKVManager(BaseKVManager):
                 ret = self.send_kvcache(
                     req.mooncake_session_id,
                     kv_chunk.prefill_kv_indices,
-                    req.dst_kv_ptrs,
+                    self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
                     chunked_dst_kv_indice,
                 )
                 if ret != 0:
@@ -251,7 +317,9 @@ class MooncakeKVManager(BaseKVManager):
                     ret = self.send_aux(
                         req.mooncake_session_id,
                         kv_chunk.prefill_aux_index,
-
+                        self.decode_kv_args_table[
+                            req.mooncake_session_id
+                        ].dst_aux_ptrs,
                         req.dst_aux_index,
                     )
                     self.request_status[req.room] = (
@@ -331,6 +399,8 @@ class MooncakeKVManager(BaseKVManager):
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
+            "tp_size": self.tp_size,
+            "dp_size": self.dp_size,
             "rank_ip": get_local_ip_by_remote(),
             "rank_port": self.rank_port,
             "engine_rank": self.kv_args.engine_rank,
@@ -408,12 +478,41 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)

+        if not self.kv_mgr.enable_dp_attention:
+            # We assume dp_attention should be activated simultaneously for
+            # both prefill role and decode role. If the decode instance does
+            # not enable dp_attention, then dp_attention is not enabled on the
+            # prefill instance as well. Therefore, we should skip questioning
+            # the prefill dp size to reduce bootstrap overhead.
+            self.prefill_dp_size = 1
+        elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
+            self.prefill_dp_size, tp_size_per_dp_rank = (
+                self._get_prefill_dp_size_from_server()
+            )
+            # Currently, we don't allow prefill instance and decode instance to
+            # have different TP sizes per DP rank.
+            assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
+            if self.prefill_dp_size is None:
+                logger.error(
+                    f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
+                )
+            else:
+                self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
+                    self.prefill_dp_size
+                )
+        else:
+            self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
+                self.bootstrap_addr
+            ]
+
         # NOTE: key distinguished by bootstrap_addr and engine_rank
+        self.target_dp_group = bootstrap_room % self.prefill_dp_size
         bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"

         if bootstrap_key not in self.kv_mgr.connection_pool:
             self.bootstrap_info = self._get_bootstrap_info_from_server(
-                self.kv_mgr.kv_args.engine_rank
+                self.kv_mgr.kv_args.engine_rank,
+                self.target_dp_group,
             )
             if self.bootstrap_info is None:
                 logger.error(
@@ -421,16 +520,18 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 )
             else:
                 self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
+                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
+                self._register_kv_args()
         else:
             self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]

         assert self.bootstrap_info is not None
         self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)

-    def _get_bootstrap_info_from_server(self, engine_rank):
+    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
             response = requests.get(url)
             if response.status_code == 200:
                 bootstrap_info = response.json()
@@ -444,6 +545,49 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

+    def _get_prefill_dp_size_from_server(self) -> int:
+        """Fetch the prefill parallel info from the bootstrap server."""
+        try:
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+            response = requests.get(url)
+            if response.status_code == 200:
+                prefill_parallel_info = response.json()
+                return int(prefill_parallel_info["prefill_dp_size"]), int(
+                    prefill_parallel_info["tp_size_per_dp_rank"]
+                )
+            else:
+                logger.error(
+                    f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
+                )
+                return None
+        except Exception as e:
+            logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
+            return None
+
+    def _register_kv_args(self):
+        self.prefill_server_url = (
+            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
+        )
+
+        packed_kv_data_ptrs = b"".join(
+            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+        )
+        packed_aux_data_ptrs = b"".join(
+            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+        )
+        sock, lock = self._connect("tcp://" + self.prefill_server_url)
+        with lock:
+            sock.send_multipart(
+                [
+                    "None".encode("ascii"),
+                    get_local_ip_by_remote().encode("ascii"),
+                    str(self.kv_mgr.rank_port).encode("ascii"),
+                    self.session_id.encode("ascii"),
+                    packed_kv_data_ptrs,
+                    packed_aux_data_ptrs,
+                ]
+            )
+
     @classmethod
     def _connect(cls, endpoint: str):
         with cls._global_lock:
@@ -462,12 +606,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
             f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
         )

-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
         sock, lock = self._connect("tcp://" + self.prefill_server_url)
         with lock:
             sock.send_multipart(
@@ -476,9 +614,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
                     get_local_ip_by_remote().encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
                     self.session_id.encode("ascii"),
-                    packed_kv_data_ptrs,
                     kv_indices.tobytes(),
-                    packed_aux_data_ptrs,
                     str(aux_index).encode("ascii"),
                 ]
             )
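With the hunks above, the decode side now sends its KV and aux pointer tables once, in a dedicated registration message, instead of attaching them to every transfer request; each 64-bit pointer is packed with `struct.pack("Q", ...)` and unpacked on the prefill side with the matching count. A small round-trip check of that wire format (pointer values are made up):

```python
# Round-trip check of the pointer packing used by the registration message above
# (the pointer values are made up).
import struct

kv_data_ptrs = [0x7F0000000000, 0x7F0000100000, 0x7F0000200000]

packed = b"".join(struct.pack("Q", ptr) for ptr in kv_data_ptrs)   # sender side
unpacked = list(struct.unpack(f"{len(packed)//8}Q", packed))       # receiver side

assert unpacked == kv_data_ptrs
print(len(packed))  # 24 bytes: three unsigned 64-bit integers
```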
sglang/srt/disaggregation/mooncake/conn.py (continued)

@@ -497,7 +633,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
-        self.
+        self.dp_size = None
+        self.tp_size_per_dp_rank = None
+        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}

         # Start bootstrap server
         self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -523,35 +661,64 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_put(self, request: web.Request):
         data = await request.json()
         role = data["role"]
+        tp_size = data["tp_size"]
+        dp_size = data["dp_size"]
         rank_ip = data["rank_ip"]
         rank_port = int(data["rank_port"])
         engine_rank = int(data["engine_rank"])

+        if self.dp_size is None:
+            self.dp_size = dp_size
+
+        tp_size_per_dp_rank = tp_size // dp_size
+        if self.tp_size_per_dp_rank == None:
+            self.tp_size_per_dp_rank = tp_size_per_dp_rank
+
         # Add lock to make sure thread-safe
         if role == "Prefill":
-
+            dp_group = engine_rank // tp_size_per_dp_rank
+            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
+
+            async with self.lock:
+                if dp_group not in self.prefill_port_table:
+                    self.prefill_port_table[dp_group] = {}
+
+                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
                     "rank_ip": rank_ip,
                     "rank_port": rank_port,
                 }
             logger.debug(
-                f"
+                f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
             )

         return web.Response(text="OK", status=200)

     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
-
-
+        target_dp_group = request.query.get("target_dp_group")
+        if not engine_rank or not target_dp_group:
+            return web.Response(text="Missing inputs for bootstrap server.", status=400)
+
+        # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
+        if int(engine_rank) == -1 and int(target_dp_group) == -1:
+            prefill_parallel_info = {
+                "prefill_dp_size": self.dp_size,
+                "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
+            }
+            return web.json_response(prefill_parallel_info, status=200)

         # Find corresponding prefill info
+        tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
+
         async with self.lock:
-            bootstrap_info = self.prefill_port_table
+            bootstrap_info = self.prefill_port_table[int(target_dp_group)][
+                tp_rank_in_dp_group
+            ]

         if bootstrap_info is not None:
             return web.json_response(bootstrap_info, status=200)
         else:
-            return web.Response(text="
+            return web.Response(text="Bootstrap info not Found", status=404)

     def _run_server(self):
         try:
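The bootstrap server now records prefill workers under `(dp_group, tp_rank_in_dp_group)` and answers a special `engine_rank=-1&target_dp_group=-1` query with the prefill DP size and per-DP-rank TP size, which the receiver uses to pick a target DP group from its bootstrap room. A worked example of that mapping, using illustrative sizes rather than values from the package:

```python
# Worked example of the parallel-info mapping used above (sizes are illustrative).
tp_size, dp_size = 8, 2
tp_size_per_dp_rank = tp_size // dp_size                   # 4

# Prefill side: engine_rank 5 registers under dp_group 1, tp rank 1 within that group.
engine_rank = 5
dp_group = engine_rank // tp_size_per_dp_rank              # 1
tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank    # 1

# Decode side: the bootstrap room spreads requests across prefill DP groups.
bootstrap_room = 12345
target_dp_group = bootstrap_room % dp_size                 # 1

print(dp_group, tp_rank_in_dp_group, target_dp_group)
```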