sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py

@@ -1,8 +1,10 @@
 from __future__ import annotations

 import asyncio
+import concurrent.futures
 import dataclasses
 import logging
+import os
 import queue
 import socket
 import struct

@@ -73,9 +75,7 @@ class TransferInfo:
     endpoint: str
     dst_port: int
     mooncake_session_id: str
-    dst_kv_ptrs: list[int]
     dst_kv_indices: npt.NDArray[np.int64]
-    dst_aux_ptrs: list[int]
     dst_aux_index: int

     @classmethod

@@ -85,10 +85,29 @@ class TransferInfo:
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
             mooncake_session_id=msg[3].decode("ascii"),
+            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
+            dst_aux_index=int(msg[5].decode("ascii")),
+        )
+
+
+@dataclasses.dataclass
+class KVArgsRegisterInfo:
+    room: str
+    endpoint: str
+    dst_port: int
+    mooncake_session_id: str
+    dst_kv_ptrs: list[int]
+    dst_aux_ptrs: list[int]
+
+    @classmethod
+    def from_zmq(cls, msg: List[bytes]):
+        return cls(
+            room=str(msg[0].decode("ascii")),
+            endpoint=msg[1].decode("ascii"),
+            dst_port=int(msg[2].decode("ascii")),
+            mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
-            dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
-            dst_aux_index=int(msg[7].decode("ascii")),
+            dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
         )

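This hunk splits the old eight-frame per-request handshake in two: `KVArgsRegisterInfo` now carries the pointer tables once per session, while `TransferInfo` carries only the per-request indices. Pointer lists travel as packed unsigned 64-bit words. A minimal round-trip sketch of that encoding (standalone; the pointer and index values are made up):

```python
import struct

import numpy as np

# Made-up pointer table and KV page indices, standing in for
# kv_args.kv_data_ptrs and a request's dst_kv_indices.
kv_data_ptrs = [0x7F0000000000, 0x7F0000100000]
kv_indices = np.array([3, 4, 5, 9], dtype=np.int64)

# Sender side: one 8-byte unsigned word ("Q") per pointer.
packed = b"".join(struct.pack("Q", ptr) for ptr in kv_data_ptrs)

# Receiver side: frame length divided by 8 gives the count for the unpack format.
assert list(struct.unpack(f"{len(packed)//8}Q", packed)) == kv_data_ptrs

# Indices travel as a raw numpy buffer and come back via frombuffer.
assert (np.frombuffer(kv_indices.tobytes(), dtype=np.int64) == kv_indices).all()
```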
@@ -109,6 +128,13 @@ class MooncakeKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
+        self.tp_size = server_args.tp_size
+        self.dp_size = server_args.dp_size
+        self.enable_dp_attention = server_args.enable_dp_attention
+        if not server_args.enable_dp_attention and server_args.dp_size != 1:
+            raise ValueError(
+                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
+            )
         self.request_status: Dict[int, KVPoll] = {}
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)

@@ -116,11 +142,19 @@ class MooncakeKVManager(BaseKVManager):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.transfer_queue = queue.Queue()
             self.transfer_infos: Dict[int, TransferInfo] = {}
+            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self.start_prefill_thread()
             self._register_to_bootstrap()
+
+            # Determine the number of threads to use for kv sender
+            cpu_count = os.cpu_count()
+            self.executor = concurrent.futures.ThreadPoolExecutor(
+                min(cpu_count // 4, 16)
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
+            self.prefill_dp_size_table: Dict[str, int] = {}
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"

@@ -150,28 +184,53 @@
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
     ):
-        #
+        # Group by indices
         prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
             prefill_kv_indices, dst_kv_indices
         )

         num_layers = len(self.kv_args.kv_data_ptrs)
-        for layer_id in range(num_layers):
-            src_ptr = self.kv_args.kv_data_ptrs[layer_id]
-            dst_ptr = dst_kv_ptrs[layer_id]
-            item_len = self.kv_args.kv_item_lens[layer_id]
+        layers_params = [
+            (
+                self.kv_args.kv_data_ptrs[layer_id],
+                dst_kv_ptrs[layer_id],
+                self.kv_args.kv_item_lens[layer_id],
+            )
+            for layer_id in range(num_layers)
+        ]

+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)

-                # TODO: make async later
                 status = self.engine.transfer_sync(
                     mooncake_session_id, src_addr, dst_addr, length
                 )
                 if status != 0:
                     return status
+            return 0
+
+        futures = [
+            self.executor.submit(
+                process_layer,
+                src_ptr,
+                dst_ptr,
+                item_len,
+            )
+            for (src_ptr, dst_ptr, item_len) in layers_params
+        ]
+
+        for future in concurrent.futures.as_completed(futures):
+            status = future.result()
+            if status != 0:
+                # Immediate shutdown on first error (existing tasks will finish)
+                executor.shutdown(wait=False)
+                for f in futures:
+                    f.cancel()
+                return status

         return 0

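The sequential per-layer loop becomes a fan-out: one task per layer is submitted to the shared `ThreadPoolExecutor`, and the first non-zero status aborts the remaining work. (Note that the shipped error path calls `executor.shutdown` although the pool is stored as `self.executor`.) Below is a self-contained sketch of the same fail-fast pattern, with a stub in place of `engine.transfer_sync` and a `max(1, ...)` guard added so the pool size stays positive on small machines:

```python
import concurrent.futures
import os


def transfer_layer(layer_id: int) -> int:
    """Stub for a per-layer synchronous transfer; 0 means success."""
    return 0 if layer_id != 2 else -1  # simulate a failure on layer 2


# Sizing heuristic from the diff (a quarter of the cores, capped at 16),
# guarded so ThreadPoolExecutor never receives max_workers == 0.
executor = concurrent.futures.ThreadPoolExecutor(
    max(1, min((os.cpu_count() or 1) // 4, 16))
)

futures = [executor.submit(transfer_layer, layer_id) for layer_id in range(8)]

status = 0
for future in concurrent.futures.as_completed(futures):
    status = future.result()
    if status != 0:
        # Fail fast: stop accepting work and cancel tasks not yet started;
        # tasks already running are not interrupted.
        executor.shutdown(wait=False)
        for f in futures:
            f.cancel()
        break

print("transfer status:", status)
```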
@@ -215,6 +274,13 @@ class MooncakeKVManager(BaseKVManager):
                 waiting_req_bytes = self.server_socket.recv_multipart()
                 room = waiting_req_bytes[0].decode("ascii")
                 if room == "None":
+                    mooncake_session_id = waiting_req_bytes[3].decode("ascii")
+                    self.decode_kv_args_table[mooncake_session_id] = (
+                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
+                    )
+                    logger.debug(
+                        f"Register KVArgs from {mooncake_session_id} successfully"
+                    )
                     continue
                 room = int(room)
                 self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)

@@ -231,12 +297,12 @@ class MooncakeKVManager(BaseKVManager):
                 chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
                 assert len(chunked_dst_kv_indice) == len(
                     kv_chunk.prefill_kv_indices
-                )
+                ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"

                 ret = self.send_kvcache(
                     req.mooncake_session_id,
                     kv_chunk.prefill_kv_indices,
-                    req.dst_kv_ptrs,
+                    self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
                     chunked_dst_kv_indice,
                 )
                 if ret != 0:

@@ -251,7 +317,9 @@ class MooncakeKVManager(BaseKVManager):
                     ret = self.send_aux(
                         req.mooncake_session_id,
                         kv_chunk.prefill_aux_index,
-                        req.dst_aux_ptrs,
+                        self.decode_kv_args_table[
+                            req.mooncake_session_id
+                        ].dst_aux_ptrs,
                         req.dst_aux_index,
                     )
                     self.request_status[req.room] = (

@@ -331,6 +399,8 @@ class MooncakeKVManager(BaseKVManager):
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
+            "tp_size": self.tp_size,
+            "dp_size": self.dp_size,
             "rank_ip": get_local_ip_by_remote(),
             "rank_port": self.rank_port,
             "engine_rank": self.kv_args.engine_rank,

@@ -408,12 +478,41 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)

+        if not self.kv_mgr.enable_dp_attention:
+            # We assume dp_attention should be activated simultaneously for
+            # both prefill role and decode role. If the decode instance does
+            # not enable dp_attention, then dp_attention is not enabled on the
+            # prefill instance as well. Therefore, we should skip questioning
+            # the prefill dp size to reduce bootstrap overhead.
+            self.prefill_dp_size = 1
+        elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
+            self.prefill_dp_size, tp_size_per_dp_rank = (
+                self._get_prefill_dp_size_from_server()
+            )
+            # Currently, we don't allow prefill instance and decode instance to
+            # have different TP sizes per DP rank.
+            assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
+            if self.prefill_dp_size is None:
+                logger.error(
+                    f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
+                )
+            else:
+                self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
+                    self.prefill_dp_size
+                )
+        else:
+            self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
+                self.bootstrap_addr
+            ]
+
         # NOTE: key distinguished by bootstrap_addr and engine_rank
+        self.target_dp_group = bootstrap_room % self.prefill_dp_size
         bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"

         if bootstrap_key not in self.kv_mgr.connection_pool:
             self.bootstrap_info = self._get_bootstrap_info_from_server(
-                self.kv_mgr.kv_args.engine_rank
+                self.kv_mgr.kv_args.engine_rank,
+                self.target_dp_group,
             )
             if self.bootstrap_info is None:
                 logger.error(

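On the decode side, the target prefill DP group is derived deterministically from the bootstrap room, while the bootstrap server buckets each registering prefill rank into a DP group by integer division. A worked example of that arithmetic, under assumed sizes (prefill tp_size=8, dp_size=4; the values are illustrative only):

```python
# Assumed prefill parallel layout, for illustration.
prefill_tp_size = 8
prefill_dp_size = 4
tp_size_per_dp_rank = prefill_tp_size // prefill_dp_size  # 2 ranks per DP group

# Decode side: the room id (one per request) picks a target DP group.
bootstrap_room = 1234567
target_dp_group = bootstrap_room % prefill_dp_size  # -> 3

# Bootstrap-server side: engine ranks 0..7 bucket into groups of 2.
for engine_rank in range(prefill_tp_size):
    dp_group = engine_rank // tp_size_per_dp_rank
    tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
    print(f"engine_rank {engine_rank} -> dp_group {dp_group}, tp_rank {tp_rank_in_dp_group}")
# engine_rank 0,1 -> dp_group 0; 2,3 -> 1; 4,5 -> 2; 6,7 -> 3
```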
@@ -421,16 +520,18 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 )
             else:
                 self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
+                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
+                self._register_kv_args()
         else:
             self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]

         assert self.bootstrap_info is not None
         self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)

-    def _get_bootstrap_info_from_server(self, engine_rank):
+    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
             response = requests.get(url)
             if response.status_code == 200:
                 bootstrap_info = response.json()

@@ -444,6 +545,49 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

+    def _get_prefill_dp_size_from_server(self) -> int:
+        """Fetch the prefill parallel info from the bootstrap server."""
+        try:
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+            response = requests.get(url)
+            if response.status_code == 200:
+                prefill_parallel_info = response.json()
+                return int(prefill_parallel_info["prefill_dp_size"]), int(
+                    prefill_parallel_info["tp_size_per_dp_rank"]
+                )
+            else:
+                logger.error(
+                    f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
+                )
+                return None
+        except Exception as e:
+            logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
+            return None
+
+    def _register_kv_args(self):
+        self.prefill_server_url = (
+            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
+        )
+
+        packed_kv_data_ptrs = b"".join(
+            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+        )
+        packed_aux_data_ptrs = b"".join(
+            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+        )
+        sock, lock = self._connect("tcp://" + self.prefill_server_url)
+        with lock:
+            sock.send_multipart(
+                [
+                    "None".encode("ascii"),
+                    get_local_ip_by_remote().encode("ascii"),
+                    str(self.kv_mgr.rank_port).encode("ascii"),
+                    self.session_id.encode("ascii"),
+                    packed_kv_data_ptrs,
+                    packed_aux_data_ptrs,
+                ]
+            )
+
     @classmethod
     def _connect(cls, endpoint: str):
         with cls._global_lock:

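`_register_kv_args` reuses the transfer socket but sets the room frame to the literal string "None", which `start_prefill_thread` uses to tell a one-time registration apart from a per-request transfer. The two multipart layouts, summarized as a sketch:

```python
# Registration message (sent once per decode session by _register_kv_args):
#   [0] b"None"                room sentinel: this is a registration
#   [1] local ip (ascii)       decode endpoint
#   [2] rank port (ascii)      decode ZMQ port
#   [3] session id (ascii)     key into decode_kv_args_table
#   [4] packed kv_data_ptrs    one 8-byte uint64 per layer pointer
#   [5] packed aux_data_ptrs   one 8-byte uint64 per aux pointer
#
# Per-request message (parsed by TransferInfo.from_zmq):
#   [0] room id (ascii int)
#   [1] local ip (ascii)
#   [2] rank port (ascii)
#   [3] session id (ascii)
#   [4] kv_indices.tobytes()   raw np.int64 page indices
#   [5] aux index (ascii int)
```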
@@ -462,12 +606,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
             )

-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
         sock, lock = self._connect("tcp://" + self.prefill_server_url)
         with lock:
             sock.send_multipart(

@@ -476,9 +614,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
                     get_local_ip_by_remote().encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
                     self.session_id.encode("ascii"),
-                    packed_kv_data_ptrs,
                     kv_indices.tobytes(),
-                    packed_aux_data_ptrs,
                     str(aux_index).encode("ascii"),
                 ]
             )

@@ -497,7 +633,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
-        self.
+        self.dp_size = None
+        self.tp_size_per_dp_rank = None
+        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}

         # Start bootstrap server
         self.thread = threading.Thread(target=self._run_server, daemon=True)

@@ -523,35 +661,64 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_put(self, request: web.Request):
         data = await request.json()
         role = data["role"]
+        tp_size = data["tp_size"]
+        dp_size = data["dp_size"]
         rank_ip = data["rank_ip"]
         rank_port = int(data["rank_port"])
         engine_rank = int(data["engine_rank"])

+        if self.dp_size is None:
+            self.dp_size = dp_size
+
+        tp_size_per_dp_rank = tp_size // dp_size
+        if self.tp_size_per_dp_rank == None:
+            self.tp_size_per_dp_rank = tp_size_per_dp_rank
+
         # Add lock to make sure thread-safe
         if role == "Prefill":
-            self.prefill_port_table[engine_rank] = {
+            dp_group = engine_rank // tp_size_per_dp_rank
+            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
+
+            async with self.lock:
+                if dp_group not in self.prefill_port_table:
+                    self.prefill_port_table[dp_group] = {}
+
+                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
                     "rank_ip": rank_ip,
                     "rank_port": rank_port,
                 }
             logger.debug(
-                f"
+                f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
             )

         return web.Response(text="OK", status=200)

     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
-        if not engine_rank:
-            return web.Response(text="Missing rank", status=400)
+        target_dp_group = request.query.get("target_dp_group")
+        if not engine_rank or not target_dp_group:
+            return web.Response(text="Missing inputs for bootstrap server.", status=400)
+
+        # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
+        if int(engine_rank) == -1 and int(target_dp_group) == -1:
+            prefill_parallel_info = {
+                "prefill_dp_size": self.dp_size,
+                "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
+            }
+            return web.json_response(prefill_parallel_info, status=200)

         # Find corresponding prefill info
+        tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
+
         async with self.lock:
-            bootstrap_info = self.prefill_port_table
+            bootstrap_info = self.prefill_port_table[int(target_dp_group)][
+                tp_rank_in_dp_group
+            ]

         if bootstrap_info is not None:
             return web.json_response(bootstrap_info, status=200)
         else:
-            return web.Response(text="
+            return web.Response(text="Bootstrap info not Found", status=404)

     def _run_server(self):
         try:

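The bootstrap HTTP protocol gains a sentinel query: `engine_rank=-1&target_dp_group=-1` returns the prefill parallel layout, while a normal query resolves one prefill rank's endpoint. A client sketch of both calls (the address and the commented response values are illustrative, not taken from the diff):

```python
import requests

bootstrap_addr = "10.0.0.1:8998"  # hypothetical bootstrap server address

# Sentinel query: fetch the prefill parallel layout once per decode instance.
parallel_info = requests.get(
    f"http://{bootstrap_addr}/route?engine_rank=-1&target_dp_group=-1"
).json()
# e.g. {"prefill_dp_size": 4, "tp_size_per_dp_rank": 2}

# Normal query: resolve the endpoint of one prefill rank in a DP group.
bootstrap_info = requests.get(
    f"http://{bootstrap_addr}/route?engine_rank=5&target_dp_group=2"
).json()
# e.g. {"rank_ip": "10.0.0.2", "rank_port": 17000}
```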
sglang/srt/disaggregation/nixl/__init__.py

@@ -0,0 +1 @@
+from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender