sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +67 -13
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +36 -12
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +9 -0
- sglang/srt/entrypoints/http_server.py +35 -4
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +38 -12
- sglang/srt/managers/scheduler.py +41 -28
- sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +3 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +19 -25
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +50 -11
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +31 -24
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +5 -1
- sglang/test/runners.py +6 -13
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +74 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py
CHANGED

@@ -1,8 +1,10 @@
 from __future__ import annotations
 
 import asyncio
+import concurrent.futures
 import dataclasses
 import logging
+import os
 import queue
 import socket
 import struct
@@ -73,9 +75,7 @@ class TransferInfo:
     endpoint: str
     dst_port: int
     mooncake_session_id: str
-    dst_kv_ptrs: list[int]
     dst_kv_indices: npt.NDArray[np.int64]
-    dst_aux_ptrs: list[int]
     dst_aux_index: int
 
     @classmethod
@@ -85,10 +85,29 @@ class TransferInfo:
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
             mooncake_session_id=msg[3].decode("ascii"),
+            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
+            dst_aux_index=int(msg[5].decode("ascii")),
+        )
+
+
+@dataclasses.dataclass
+class KVArgsRegisterInfo:
+    room: str
+    endpoint: str
+    dst_port: int
+    mooncake_session_id: str
+    dst_kv_ptrs: list[int]
+    dst_aux_ptrs: list[int]
+
+    @classmethod
+    def from_zmq(cls, msg: List[bytes]):
+        return cls(
+            room=str(msg[0].decode("ascii")),
+            endpoint=msg[1].decode("ascii"),
+            dst_port=int(msg[2].decode("ascii")),
+            mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
-            dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
-            dst_aux_index=int(msg[7].decode("ascii")),
+            dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
         )
 
 
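The new `KVArgsRegisterInfo` splits the one-time pointer registration out of the per-request `TransferInfo`, so KV and aux buffer pointers travel once per decode rank instead of once per request. A minimal sketch of the pointer round-trip, with made-up addresses standing in for real KV buffer pointers:

```python
import struct

# Hypothetical stand-ins for real KV buffer base addresses.
kv_data_ptrs = [0x7F0000000000, 0x7F0000100000]

# Sender side: pack each pointer as an unsigned 64-bit integer ("Q").
packed = b"".join(struct.pack("Q", ptr) for ptr in kv_data_ptrs)

# Receiver side: the payload length determines the pointer count,
# mirroring what KVArgsRegisterInfo.from_zmq does with msg[4].
unpacked = list(struct.unpack(f"{len(packed) // 8}Q", packed))
assert unpacked == kv_data_ptrs
```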
@@ -109,6 +128,13 @@ class MooncakeKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
+        self.tp_size = server_args.tp_size
+        self.dp_size = server_args.dp_size
+        self.enable_dp_attention = server_args.enable_dp_attention
+        if not server_args.enable_dp_attention and server_args.dp_size != 1:
+            raise ValueError(
+                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
+            )
         self.request_status: Dict[int, KVPoll] = {}
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
@@ -116,11 +142,19 @@ class MooncakeKVManager(BaseKVManager):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.transfer_queue = queue.Queue()
             self.transfer_infos: Dict[int, TransferInfo] = {}
+            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self.start_prefill_thread()
             self._register_to_bootstrap()
+
+            # Determine the number of threads to use for kv sender
+            cpu_count = os.cpu_count()
+            self.executor = concurrent.futures.ThreadPoolExecutor(
+                min(cpu_count // 4, 16)
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
+            self.prefill_dp_size_table: Dict[str, int] = {}
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
@@ -150,28 +184,53 @@ class MooncakeKVManager(BaseKVManager):
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
     ):
-        #
+        # Group by indices
         prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
             prefill_kv_indices, dst_kv_indices
         )
 
         num_layers = len(self.kv_args.kv_data_ptrs)
-        for layer_id in range(num_layers):
-            src_ptr = self.kv_args.kv_data_ptrs[layer_id]
-            dst_ptr = dst_kv_ptrs[layer_id]
-            item_len = self.kv_args.kv_item_lens[layer_id]
+        layers_params = [
+            (
+                self.kv_args.kv_data_ptrs[layer_id],
+                dst_kv_ptrs[layer_id],
+                self.kv_args.kv_item_lens[layer_id],
+            )
+            for layer_id in range(num_layers)
+        ]
 
+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
 
-                # TODO: make async later
                 status = self.engine.transfer_sync(
                     mooncake_session_id, src_addr, dst_addr, length
                 )
                 if status != 0:
                     return status
+            return 0
+
+        futures = [
+            self.executor.submit(
+                process_layer,
+                src_ptr,
+                dst_ptr,
+                item_len,
+            )
+            for (src_ptr, dst_ptr, item_len) in layers_params
+        ]
+
+        for future in concurrent.futures.as_completed(futures):
+            status = future.result()
+            if status != 0:
+                # Immediate shutdown on first error (existing tasks will finish)
+                executor.shutdown(wait=False)
+                for f in futures:
+                    f.cancel()
+                return status
 
         return 0
 
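The rewritten `send_kvcache` replaces the serial per-layer loop with a fan-out over the manager's thread pool (sized `min(os.cpu_count() // 4, 16)` above) and aborts on the first failed layer. Note that the hunk shuts the pool down via the bare name `executor` even though the pool is stored as `self.executor`. A self-contained sketch of the same fan-out/early-abort pattern, with a dummy `transfer_layer` standing in for the real Mooncake transfer and a `max(1, ...)` guard of my own so tiny machines still get one worker:

```python
import concurrent.futures
import os

# Pool sized like the diff: a quarter of the cores, capped at 16.
executor = concurrent.futures.ThreadPoolExecutor(
    max(1, min(os.cpu_count() // 4, 16))
)

def transfer_layer(layer_id: int) -> int:
    # Stand-in for engine.transfer_sync over one layer's contiguous
    # blocks; returns 0 on success, nonzero on the first failure.
    return 0

futures = [executor.submit(transfer_layer, i) for i in range(8)]
for future in concurrent.futures.as_completed(futures):
    if (status := future.result()) != 0:
        # Stop accepting new work and drop tasks that have not started;
        # transfers already in flight still run to completion.
        executor.shutdown(wait=False)
        for f in futures:
            f.cancel()
        break
```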
@@ -215,6 +274,13 @@ class MooncakeKVManager(BaseKVManager):
                 waiting_req_bytes = self.server_socket.recv_multipart()
                 room = waiting_req_bytes[0].decode("ascii")
                 if room == "None":
+                    mooncake_session_id = waiting_req_bytes[3].decode("ascii")
+                    self.decode_kv_args_table[mooncake_session_id] = (
+                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
+                    )
+                    logger.debug(
+                        f"Register KVArgs from {mooncake_session_id} successfully"
+                    )
                     continue
                 room = int(room)
                 self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
@@ -236,7 +302,7 @@ class MooncakeKVManager(BaseKVManager):
                 ret = self.send_kvcache(
                     req.mooncake_session_id,
                     kv_chunk.prefill_kv_indices,
-                    req.dst_kv_ptrs,
+                    self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
                     chunked_dst_kv_indice,
                 )
                 if ret != 0:
@@ -251,7 +317,9 @@ class MooncakeKVManager(BaseKVManager):
                     ret = self.send_aux(
                         req.mooncake_session_id,
                         kv_chunk.prefill_aux_index,
-                        req.dst_aux_ptrs,
+                        self.decode_kv_args_table[
+                            req.mooncake_session_id
+                        ].dst_aux_ptrs,
                         req.dst_aux_index,
                     )
                     self.request_status[req.room] = (
@@ -331,6 +399,8 @@ class MooncakeKVManager(BaseKVManager):
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
+            "tp_size": self.tp_size,
+            "dp_size": self.dp_size,
             "rank_ip": get_local_ip_by_remote(),
             "rank_port": self.rank_port,
             "engine_rank": self.kv_args.engine_rank,
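With `tp_size` and `dp_size` added to the payload, each prefill rank now gives the bootstrap server enough to reconstruct the parallel layout. Roughly what the registration request looks like on the wire; the keys are taken from the hunk above, while the address and values are illustrative:

```python
import requests

# Illustrative values; the real ones come from ServerArgs and the
# local rank's KV args.
payload = {
    "role": "Prefill",
    "tp_size": 8,
    "dp_size": 2,
    "rank_ip": "10.0.0.1",
    "rank_port": 17000,
    "engine_rank": 5,
}
requests.put("http://bootstrap-host:8998/route", json=payload)
```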
@@ -408,12 +478,41 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
 
+        if not self.kv_mgr.enable_dp_attention:
+            # We assume dp_attention should be activated simultaneously for
+            # both prefill role and decode role. If the decode instance does
+            # not enable dp_attention, then dp_attention is not enabled on the
+            # prefill instance as well. Therefore, we should skip questioning
+            # the prefill dp size to reduce bootstrap overhead.
+            self.prefill_dp_size = 1
+        elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
+            self.prefill_dp_size, tp_size_per_dp_rank = (
+                self._get_prefill_dp_size_from_server()
+            )
+            # Currently, we don't allow prefill instance and decode instance to
+            # have different TP sizes per DP rank.
+            assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
+            if self.prefill_dp_size is None:
+                logger.error(
+                    f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
+                )
+            else:
+                self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
+                    self.prefill_dp_size
+                )
+        else:
+            self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
+                self.bootstrap_addr
+            ]
+
         # NOTE: key distinguished by bootstrap_addr and engine_rank
+        self.target_dp_group = bootstrap_room % self.prefill_dp_size
         bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
 
         if bootstrap_key not in self.kv_mgr.connection_pool:
             self.bootstrap_info = self._get_bootstrap_info_from_server(
-                self.kv_mgr.kv_args.engine_rank
+                self.kv_mgr.kv_args.engine_rank,
+                self.target_dp_group,
             )
             if self.bootstrap_info is None:
                 logger.error(
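A worked example of the parallel-layout arithmetic these hunks rely on, assuming a prefill instance with `tp_size=8` and `dp_size=2`:

```python
tp_size, dp_size = 8, 2
tp_size_per_dp_rank = tp_size // dp_size      # 4 TP ranks per DP group

# Receiver side: rooms are spread across prefill DP groups round-robin.
bootstrap_room = 12345
target_dp_group = bootstrap_room % dp_size    # -> group 1

# Bootstrap-server side: a registering engine rank decomposes into
# its DP group and its TP rank within that group.
engine_rank = 5
dp_group = engine_rank // tp_size_per_dp_rank             # -> 1
tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank   # -> 1
```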
@@ -421,16 +520,18 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 )
             else:
                 self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
+                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
+                self._register_kv_args()
         else:
             self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
 
         assert self.bootstrap_info is not None
         self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
 
-    def _get_bootstrap_info_from_server(self, engine_rank):
+    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
             response = requests.get(url)
             if response.status_code == 200:
                 bootstrap_info = response.json()
@@ -444,6 +545,49 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None
 
+    def _get_prefill_dp_size_from_server(self) -> int:
+        """Fetch the prefill parallel info from the bootstrap server."""
+        try:
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+            response = requests.get(url)
+            if response.status_code == 200:
+                prefill_parallel_info = response.json()
+                return int(prefill_parallel_info["prefill_dp_size"]), int(
+                    prefill_parallel_info["tp_size_per_dp_rank"]
+                )
+            else:
+                logger.error(
+                    f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
+                )
+                return None
+        except Exception as e:
+            logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
+            return None
+
+    def _register_kv_args(self):
+        self.prefill_server_url = (
+            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
+        )
+
+        packed_kv_data_ptrs = b"".join(
+            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+        )
+        packed_aux_data_ptrs = b"".join(
+            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+        )
+        sock, lock = self._connect("tcp://" + self.prefill_server_url)
+        with lock:
+            sock.send_multipart(
+                [
+                    "None".encode("ascii"),
+                    get_local_ip_by_remote().encode("ascii"),
+                    str(self.kv_mgr.rank_port).encode("ascii"),
+                    self.session_id.encode("ascii"),
+                    packed_kv_data_ptrs,
+                    packed_aux_data_ptrs,
+                ]
+            )
+
     @classmethod
     def _connect(cls, endpoint: str):
         with cls._global_lock:
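`_register_kv_args` reuses the transfer socket but marks the message with the literal room `"None"`, which the prefill thread (see the `@@ -215,6 +274,13 @@` hunk above) treats as a registration rather than a transfer request. The frame layout, with illustrative byte values:

```python
# Illustrative frame layout for the one-time registration message.
frames = [
    b"None",        # frame 0: room sentinel, not a real room id
    b"10.0.0.2",    # frame 1: decode rank IP (illustrative)
    b"17001",       # frame 2: decode rank ZMQ port (illustrative)
    b"session-42",  # frame 3: mooncake session id (illustrative)
    b"\x00" * 16,   # frame 4: packed kv_data_ptrs (two 8-byte pointers)
    b"\x00" * 8,    # frame 5: packed aux_data_ptrs (one 8-byte pointer)
]
```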
@@ -462,12 +606,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
             f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
         )
 
-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
         sock, lock = self._connect("tcp://" + self.prefill_server_url)
         with lock:
             sock.send_multipart(
@@ -476,9 +614,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
                     get_local_ip_by_remote().encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
                     self.session_id.encode("ascii"),
-                    packed_kv_data_ptrs,
                     kv_indices.tobytes(),
-                    packed_aux_data_ptrs,
                     str(aux_index).encode("ascii"),
                 ]
             )
@@ -497,7 +633,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
-        self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
+        self.dp_size = None
+        self.tp_size_per_dp_rank = None
+        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
 
         # Start bootstrap server
         self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -523,35 +661,64 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_put(self, request: web.Request):
         data = await request.json()
         role = data["role"]
+        tp_size = data["tp_size"]
+        dp_size = data["dp_size"]
         rank_ip = data["rank_ip"]
         rank_port = int(data["rank_port"])
         engine_rank = int(data["engine_rank"])
 
+        if self.dp_size is None:
+            self.dp_size = dp_size
+
+        tp_size_per_dp_rank = tp_size // dp_size
+        if self.tp_size_per_dp_rank == None:
+            self.tp_size_per_dp_rank = tp_size_per_dp_rank
+
         # Add lock to make sure thread-safe
         if role == "Prefill":
-
+            dp_group = engine_rank // tp_size_per_dp_rank
+            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
+
+            async with self.lock:
+                if dp_group not in self.prefill_port_table:
+                    self.prefill_port_table[dp_group] = {}
+
+                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
                     "rank_ip": rank_ip,
                     "rank_port": rank_port,
                 }
             logger.debug(
-                f"
+                f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
             )
 
         return web.Response(text="OK", status=200)
 
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
-
-
+        target_dp_group = request.query.get("target_dp_group")
+        if not engine_rank or not target_dp_group:
+            return web.Response(text="Missing inputs for bootstrap server.", status=400)
+
+        # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
+        if int(engine_rank) == -1 and int(target_dp_group) == -1:
+            prefill_parallel_info = {
+                "prefill_dp_size": self.dp_size,
+                "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
+            }
+            return web.json_response(prefill_parallel_info, status=200)
 
         # Find corresponding prefill info
+        tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
+
         async with self.lock:
-            bootstrap_info = self.prefill_port_table
+            bootstrap_info = self.prefill_port_table[int(target_dp_group)][
+                tp_rank_in_dp_group
+            ]
 
         if bootstrap_info is not None:
             return web.json_response(bootstrap_info, status=200)
         else:
-            return web.Response(text="
+            return web.Response(text="Bootstrap info not Found", status=404)
 
     def _run_server(self):
         try:
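The GET side of `/route` now serves two kinds of queries: the `-1/-1` convention returns the instance-wide parallel info, and a concrete `(engine_rank, target_dp_group)` pair returns one rank's endpoint. A hedged sketch of a client, with a placeholder address and illustrative responses:

```python
import requests

base = "http://bootstrap-host:8998/route"  # placeholder address

# 1) Sync parallel info once per bootstrap address.
info = requests.get(f"{base}?engine_rank=-1&target_dp_group=-1").json()
# e.g. {"prefill_dp_size": 2, "tp_size_per_dp_rank": 4}

# 2) Resolve a specific prefill rank within a DP group.
endpoint = requests.get(f"{base}?engine_rank=5&target_dp_group=1").json()
# e.g. {"rank_ip": "10.0.0.1", "rank_port": 17000}
```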
sglang/srt/disaggregation/prefill.py
CHANGED

@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+import threading
 from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
@@ -28,6 +29,7 @@ import torch
 from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    FakeBootstrapHost,
     KVClassType,
     ReqToMetadataIdxAllocator,
     TransferBackend,
@@ -115,7 +117,11 @@ class PrefillBootstrapQueue:
         return kv_manager
 
     def add(self, req: Req) -> None:
-        kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
+        if req.bootstrap_host == FakeBootstrapHost:
+            # Fake transfer for warmup reqs
+            kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
+        else:
+            kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
         req.disagg_kv_sender = kv_sender_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
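Warmup requests are tagged with the sentinel host `FakeBootstrapHost` (defined in sglang/srt/disaggregation/utils.py, below), so their KV "transfer" is routed to the fake backend and never touches Mooncake. A minimal sketch of the sentinel-based dispatch, with stand-in names:

```python
FakeBootstrapHost = "2.2.2.2"  # sentinel value from the diff

def pick_backend(bootstrap_host: str, configured_backend: str) -> str:
    # Warmup requests carry the fake host and short-circuit to the
    # fake sender; everything else uses the configured backend.
    if bootstrap_host == FakeBootstrapHost:
        return "FAKE"
    return configured_backend

assert pick_backend("2.2.2.2", "MOONCAKE") == "FAKE"
assert pick_backend("10.0.0.1", "MOONCAKE") == "MOONCAKE"
```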
@@ -176,17 +182,25 @@ class SchedulerDisaggregationPrefillMixin:
     """
 
     @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self):
+    def event_loop_normal_disagg_prefill(self: Scheduler):
         """A normal scheduler loop for prefill worker in disaggregation mode."""
 
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
-                self.
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
@@ -206,17 +220,25 @@ class SchedulerDisaggregationPrefillMixin:
                     self.running_batch.batch_is_full = False
 
     @torch.no_grad()
-    def event_loop_overlap_disagg_prefill(self):
+    def event_loop_overlap_disagg_prefill(self: Scheduler):
         self.result_queue = deque()
 
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
-                self.
+                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
@@ -240,7 +262,10 @@ class SchedulerDisaggregationPrefillMixin:
                     self.running_batch.batch_is_full = False
 
     def process_batch_result_disagg_prefill(
-        self: Scheduler,
+        self: Scheduler,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -264,7 +289,7 @@ class SchedulerDisaggregationPrefillMixin:
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.
+            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
 
@@ -310,7 +335,7 @@ class SchedulerDisaggregationPrefillMixin:
                 raise Exception("Transferring failed")
 
         for req in done_reqs:
-            self.
+            self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
                 req.metadata_buffer_index
             )
 
@@ -326,9 +351,8 @@ class SchedulerDisaggregationPrefillMixin:
             # only finished requests to running_batch.
             self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req)
-            if
-
-            ): # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
+            if self.enable_overlap:
+                # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                 self.chunked_req.tmp_end_idx = min(
                     len(self.chunked_req.fill_ids),
                     len(self.chunked_req.origin_input_ids),
@@ -374,7 +398,7 @@ class SchedulerDisaggregationPrefillMixin:
             .numpy()
         )
         if last_chunk is True:
-            self.
+            self.disagg_prefill_bootstrap_queue.store_prefill_results(
                 req.metadata_buffer_index, token_id
             )
         page_indices = kv_to_page_indices(kv_indices, page_size)
sglang/srt/disaggregation/utils.py
CHANGED

@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
     DECODE = "decode"
 
 
+FakeBootstrapHost = "2.2.2.2"
+
+
 def poll_and_all_reduce(pollers, gloo_group):
     polls = [int(poller.poll()) for poller in pollers]
     tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
@@ -59,6 +62,8 @@ class KVClassType(Enum):
 
 
 def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
+    from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
+
     if transfer_backend == TransferBackend.MOONCAKE:
         from sglang.srt.disaggregation.mooncake import (
             MooncakeKVBootstrapServer,
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         class_mapping = {
             KVClassType.MANAGER: MooncakeKVManager,
             KVClassType.SENDER: MooncakeKVSender,
-            KVClassType.RECEIVER: MooncakeKVReceiver,
+            KVClassType.RECEIVER: (MooncakeKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         class_mapping = {
             KVClassType.MANAGER: NixlKVManager,
             KVClassType.SENDER: NixlKVSender,
-            KVClassType.RECEIVER: NixlKVReceiver,
+            KVClassType.RECEIVER: (NixlKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
         }
         return class_mapping.get(class_type)
+    if transfer_backend == TransferBackend.FAKE:
+        from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
+
+        class_mapping = {
+            KVClassType.SENDER: FakeKVSender,
+            KVClassType.RECEIVER: (FakeKVReceiver),
+        }
+        return class_mapping.get(class_type)
+
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
 
 
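A condensed sketch of the backend dispatch `get_kv_class` implements, using stand-in classes instead of the real Mooncake/NIXL/fake imports:

```python
from enum import Enum

class TransferBackend(Enum):
    MOONCAKE = "mooncake"
    FAKE = "fake"

class KVClassType(Enum):
    SENDER = "sender"
    RECEIVER = "receiver"

class FakeKVSender: ...
class FakeKVReceiver: ...

def get_kv_class(backend: TransferBackend, class_type: KVClassType):
    # Each backend contributes its own class mapping; unknown backends raise.
    if backend == TransferBackend.FAKE:
        return {
            KVClassType.SENDER: FakeKVSender,
            KVClassType.RECEIVER: FakeKVReceiver,
        }.get(class_type)
    raise ValueError(f"Unsupported transfer backend: {backend}")

assert get_kv_class(TransferBackend.FAKE, KVClassType.SENDER) is FakeKVSender
```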
sglang/srt/entrypoints/engine.py
CHANGED

@@ -66,6 +66,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     get_zmq_socket,
+    is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
@@ -78,6 +79,8 @@ from sglang.version import __version__
 logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+_is_cuda = is_cuda()
+
 
 class Engine(EngineBase):
     """
@@ -452,6 +455,12 @@ def _set_envs_and_config(server_args: ServerArgs):
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
+    if _is_cuda:
+        assert_pkg_version(
+            "sgl-kernel",
+            "0.1.0",
+            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
+        )
 
     def sigchld_handler(signum, frame):
         pid, exitcode = os.waitpid(0, os.WNOHANG)
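The new check gates startup on a minimum `sgl-kernel` version when running on CUDA. A rough sketch of what a guard like `assert_pkg_version` does; the helper below is illustrative (it relies on the third-party `packaging` library), not sglang's actual implementation:

```python
from importlib.metadata import version
from packaging.version import Version

def require_min_version(pkg: str, minimum: str, hint: str) -> None:
    # Raise early if an incompatible wheel is installed.
    installed = Version(version(pkg))
    if installed < Version(minimum):
        raise RuntimeError(f"{pkg}>={minimum} required, found {installed}. {hint}")

require_min_version(
    "sgl-kernel",
    "0.1.0",
    "Reinstall with `pip install sgl-kernel --force-reinstall`.",
)
```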