sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/disaggregation/nixl/conn.py
+++ b/sglang/srt/disaggregation/nixl/conn.py
@@ -10,7 +10,7 @@ import threading
 import uuid
 from collections import defaultdict
 from functools import cache
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -32,6 +32,38 @@ from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
 
 logger = logging.getLogger(__name__)
 
+NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
+
+
+# From Mooncake backend.
+def group_concurrent_contiguous(
+    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
+) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
+    src_groups = []
+    dst_groups = []
+    current_src = [src_indices[0]]
+    current_dst = [dst_indices[0]]
+
+    for i in range(1, len(src_indices)):
+        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
+        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
+        if src_contiguous and dst_contiguous:
+            current_src.append(src_indices[i])
+            current_dst.append(dst_indices[i])
+        else:
+            src_groups.append(current_src)
+            dst_groups.append(current_dst)
+            current_src = [src_indices[i]]
+            current_dst = [dst_indices[i]]
+
+    src_groups.append(current_src)
+    dst_groups.append(current_dst)
+
+    return src_groups, dst_groups
+
+
+GUARD = "NixlMsgGuard".encode("ascii")
+
 
 @dataclasses.dataclass
 class TransferInfo:
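A note on the helper above: it merges index pairs that are contiguous on both the prefill and the decode side into blocks, so a run of KV slots can later be covered by a single NIXL descriptor. A minimal sketch of the observable behavior, assuming group_concurrent_contiguous from the hunk above is in scope (the index values are illustrative):

import numpy as np

src = np.array([0, 1, 2, 5, 6, 9], dtype=np.int64)
dst = np.array([10, 11, 12, 20, 21, 30], dtype=np.int64)

src_groups, dst_groups = group_concurrent_contiguous(src, dst)

# Runs contiguous on BOTH sides merge into one block each:
assert [list(map(int, g)) for g in src_groups] == [[0, 1, 2], [5, 6], [9]]
assert [list(map(int, g)) for g in dst_groups] == [[10, 11, 12], [20, 21], [30]]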
@@ -45,19 +77,36 @@ class TransferInfo:
     dst_aux_index: int
     dst_gpu_id: int
 
+    def is_dummy(self):
+        return self.endpoint == ""
+
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
-        return cls(
-            room=int(msg[0].decode("ascii")),
-            endpoint=msg[1].decode("ascii"),
-            dst_port=int(msg[2].decode("ascii")),
-            agent_metadata=msg[3],
-            dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
-            dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
-            dst_aux_index=int(msg[7].decode("ascii")),
-            dst_gpu_id=int(msg[8].decode("ascii")),
-        )
+        if len(msg) == 1:
+            # dummy msg
+            return cls(
+                room=int(msg[0].decode("ascii")),
+                endpoint="",
+                dst_port=0,
+                agent_metadata=b"",
+                dst_kv_ptrs=[],
+                dst_kv_indices=np.array([], dtype=np.int64),
+                dst_aux_ptrs=[],
+                dst_aux_index=0,
+                dst_gpu_id=0,
+            )
+        else:
+            return cls(
+                room=int(msg[0].decode("ascii")),
+                endpoint=msg[1].decode("ascii"),
+                dst_port=int(msg[2].decode("ascii")),
+                agent_metadata=msg[3],
+                dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
+                dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
+                dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+                dst_aux_index=int(msg[7].decode("ascii")),
+                dst_gpu_id=int(msg[8].decode("ascii")),
+            )
 
 
 @dataclasses.dataclass
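The from_zmq change makes the frame count the discriminator: a one-frame message (after the GUARD frame is stripped by the receiving thread) is a dummy notification carrying only the room id, while a nine-frame message is a full transfer request. A hedged sketch of both framings as a sender would build them; the endpoint, pointers, and sizes are made up:

import struct
import numpy as np

GUARD = "NixlMsgGuard".encode("ascii")
room = 42

# Full request: GUARD + nine payload frames, in the order from_zmq expects.
full_msg = [
    GUARD,
    str(room).encode("ascii"),                                 # room
    b"10.0.0.1",                                               # endpoint
    b"5757",                                                   # dst_port
    b"<opaque NIXL agent metadata>",                           # agent_metadata
    b"".join(struct.pack("Q", p) for p in [0x7F00, 0x7F08]),   # dst_kv_ptrs
    np.array([3, 4, 5], dtype=np.int64).tobytes(),             # dst_kv_indices
    struct.pack("Q", 0x8F00),                                  # dst_aux_ptrs
    b"0",                                                      # dst_aux_index
    b"0",                                                      # dst_gpu_id
]

# Dummy notification: GUARD + room only; from_zmq maps it to is_dummy() == True.
dummy_msg = [GUARD, str(room).encode("ascii")]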
@@ -98,6 +147,19 @@ class NixlKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
+        self.tp_size = server_args.tp_size
+
+        self.tp_rank = args.engine_rank
+        self.enable_dp_attention = server_args.enable_dp_attention
+        if self.enable_dp_attention:
+            assert (
+                server_args.dp_size > 1
+            ), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
+            self.dp_size = server_args.dp_size
+            self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
+            self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
+            self.dp_rank = args.engine_rank // self.tp_size_of_dp
+
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
         self.register_buffer_to_engine()
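As a concrete reading of the DP-attention bookkeeping above (sizes are illustrative): with tp_size = 8 and dp_size = 2, each DP replica spans tp_size_of_dp = 4 ranks, so an engine_rank maps to (attn_tp_rank, dp_rank) as sketched here:

tp_size, dp_size = 8, 2             # illustrative sizes
tp_size_of_dp = tp_size // dp_size  # 4 ranks per DP replica

for engine_rank in range(tp_size):
    attn_tp_rank = engine_rank % tp_size_of_dp   # rank within the replica
    dp_rank = engine_rank // tp_size_of_dp       # which replica
    print(f"engine_rank={engine_rank} -> attn_tp_rank={attn_tp_rank}, dp_rank={dp_rank}")
# e.g. engine_rank=6 -> attn_tp_rank=2, dp_rank=1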
@@ -110,7 +172,8 @@ class NixlKVManager(BaseKVManager):
             self._start_bootstrap_thread()
             self._register_to_bootstrap()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-
+            # bootstrap key -> (remote_engine_rank -> possible remote source info)
+            self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {}
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                 TransferStatus
             )
@@ -126,6 +189,7 @@ class NixlKVManager(BaseKVManager):
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
         self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
+        logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
         aux_addrs = []
@@ -134,6 +198,7 @@ class NixlKVManager(BaseKVManager):
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
         self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
+        logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
 
@@ -157,6 +222,12 @@ class NixlKVManager(BaseKVManager):
         dst_gpu_id: int,
         notif: str,
     ):
+        # group by indices
+        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+            prefill_kv_indices, dst_kv_indices
+        )
+
+        logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
         # Make descs
         num_layers = len(self.kv_args.kv_data_ptrs)
         src_addrs = []
@@ -166,12 +237,16 @@ class NixlKVManager(BaseKVManager):
             dst_ptr = dst_kv_ptrs[layer_id]
             item_len = self.kv_args.kv_item_lens[layer_id]
 
-            for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
-                src_addr = src_ptr + int(prefill_index) * item_len
-                dst_addr = dst_ptr + int(decode_index) * item_len
-                length = item_len
+            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+                src_addr = src_ptr + int(prefill_index[0]) * item_len
+                dst_addr = dst_ptr + int(decode_index[0]) * item_len
+                length = item_len * len(prefill_index)
                 src_addrs.append((src_addr, length, self.kv_args.gpu_id))
                 dst_addrs.append((dst_addr, length, dst_gpu_id))
+
+        logger.debug(
+            f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
+        )
         src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
         dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
         # Transfer data
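The addressing change in the hunk above is what turns the grouping into an actual saving: each merged block contributes one (address, length, gpu) descriptor whose length is item_len times the run length, instead of one descriptor per KV slot. A toy recomputation, with made-up pointers and sizes:

item_len = 4096                           # illustrative bytes per KV slot per layer
src_ptr, dst_ptr = 0x10000000, 0x20000000

prefill_kv_blocks = [[0, 1, 2], [7, 8]]   # as produced by group_concurrent_contiguous
dst_kv_blocks = [[10, 11, 12], [40, 41]]

src_addrs, dst_addrs = [], []
for p_blk, d_blk in zip(prefill_kv_blocks, dst_kv_blocks):
    length = item_len * len(p_blk)        # one descriptor covers the whole run
    src_addrs.append((src_ptr + p_blk[0] * item_len, length, 0))
    dst_addrs.append((dst_ptr + d_blk[0] * item_len, length, 1))

# Five KV slots collapse into two transfer descriptors per side.
assert len(src_addrs) == 2 and len(dst_addrs) == 2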
@@ -180,7 +255,7 @@ class NixlKVManager(BaseKVManager):
             src_descs,
             dst_descs,
             peer_name,
-            notif.encode("ascii"),
+            notif.encode("ascii"),  # type: ignore
         )
         if not xfer_handle:
             raise Exception("KVSender failed to create transfer")
@@ -213,7 +288,7 @@ class NixlKVManager(BaseKVManager):
             src_descs,
             dst_descs,
             peer_name,
-            notif.encode("ascii"),
+            notif.encode("ascii"),  # type: ignore
         )
         if not xfer_handle:
             raise Exception("KVSender failed to create transfer")
@@ -240,6 +315,9 @@ class NixlKVManager(BaseKVManager):
         req = self.transfer_infos[bootstrap_room]
         assert bootstrap_room == req.room
 
+        if req.is_dummy():
+            return []
+
         peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
         chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
         assert len(chunked_dst_kv_indice) == len(kv_indices)
@@ -256,6 +334,7 @@ class NixlKVManager(BaseKVManager):
         handles = [kv_xfer_handle]
         # Only the last chunk we need to send the aux data.
         if is_last:
+            assert aux_index is not None
             aux_xfer_handle = self.send_aux(
                 peer_name,
                 aux_index,
@@ -325,6 +404,13 @@ class NixlKVManager(BaseKVManager):
         """This thread recvs transfer info from the decode engine"""
         while True:
             waiting_req_bytes = self.server_socket.recv_multipart()
+            logger.debug(
+                f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}"
+            )
+            assert (
+                waiting_req_bytes[0] == GUARD
+            ), f"First message should be {GUARD}. Foreign traffic?"
+            waiting_req_bytes = waiting_req_bytes[1:]
             room = waiting_req_bytes[0].decode("ascii")
             if room == "None":
                 continue
@@ -372,14 +458,13 @@ class NixlKVSender(BaseKVSender):
 
     def poll(self) -> KVPoll:
         if not self.has_sent:
-            return KVPoll.WaitingForInput
-
+            return KVPoll.WaitingForInput  # type: ignore
         states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
         if all([x == "DONE" for x in states]):
-            return KVPoll.Success
+            return KVPoll.Success  # type: ignore
         if any([x == "ERR" for x in states]):
             raise Exception("KVSender transfer encountered an error.")
-        return KVPoll.WaitingForInput
+        return KVPoll.WaitingForInput  # type: ignore
 
     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -401,7 +486,7 @@ class NixlKVReceiver(BaseKVReceiver):
         # NOTE: key distinguished by bootstrap_addr and engine_rank
         bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
 
-        if bootstrap_key not in self.kv_mgr.
+        if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
             self.bootstrap_info = self._get_bootstrap_info_from_server(
                 self.kv_mgr.kv_args.engine_rank
             )
@@ -410,25 +495,79 @@ class NixlKVReceiver(BaseKVReceiver):
                     f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                 )
             else:
-                self.kv_mgr.
+                self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
         else:
-            self.bootstrap_info = self.kv_mgr.
-
+            self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]
         assert self.bootstrap_info is not None
 
-    def _get_bootstrap_info_from_server(self, engine_rank):
+    # return a list of remotes in a dict, [(remote_engine_rank -> NixlEngineInfo), ...]
+    # In each dict, there are multiple possible remotes named "equal sources".
+    # We only need to select one to split the traffic. i.e. we totally select len(list) remotes.
+    def _get_bootstrap_info_from_server(
+        self, engine_rank
+    ) -> Optional[List[Dict[int, NixlEngineInfo]]]:
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
-            response = requests.get(url)
-            if response.status_code == 200:
+            if self.kv_mgr.enable_dp_attention:
+                url = f"http://{self.bootstrap_addr}/route"
+                response = requests.get(url)
+                if response.status_code != 200:
+                    logger.error(
+                        f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                    )
+                    return None
+
                 bootstrap_info = response.json()
-                return bootstrap_info
-            else:
-                logger.error(
-                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                assert isinstance(bootstrap_info, dict)
+                bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}
+
+                # split out who need to send to this rank.
+                # currently for dpsk mla model, those ranks share the same latent cache.
+                # pick one as the real source
+
+                prefill_tp_size = len(bootstrap_info.keys())
+
+                assert (
+                    prefill_tp_size >= self.kv_mgr.tp_size_of_dp
+                ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"
+
+                num_remote_tp_rank_we_managed = (
+                    prefill_tp_size // self.kv_mgr.tp_size_of_dp
+                )
+
+                # We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num)
+                remote_tp_ranks = list(range(0, prefill_tp_size))
+                # split it into tp_size_of_dp parts and get our part
+                remote_tp_ranks_grouped = [
+                    remote_tp_ranks[i : i + num_remote_tp_rank_we_managed]
+                    for i in range(0, prefill_tp_size, self.kv_mgr.tp_size_of_dp)
+                ]
+                managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank]
+
+                assert len(managed_ranks) == num_remote_tp_rank_we_managed
+
+                logger.debug(
+                    f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
                 )
-                return None
+
+                return [
+                    {
+                        rk: bootstrap_info[rk]
+                        for rk in bootstrap_info.keys()
+                        if rk in managed_ranks
+                    }
+                ]
+            else:
+                url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
+                response = requests.get(url)
+                if response.status_code == 200:
+                    bootstrap_info = response.json()
+                    return [{engine_rank: bootstrap_info}]
+                else:
+                    logger.error(
+                        f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                    )
+                    return None
         except Exception as e:
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None
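For orientation, these are the shapes the two branches above return (values illustrative). Each list element is one "equal sources" dict; any entry inside a dict is an interchangeable source for the same KV data:

# Non-DP path: a single equal-sources dict keyed by the requested rank.
non_dp = [{3: {"rank_ip": "10.0.0.1", "rank_port": 5757, "agent_name": "a3"}}]

# DP-attention path: the subset of prefill ranks this decode rank manages;
# for an MLA model these ranks hold the same latent KV cache.
dp = [
    {
        0: {"rank_ip": "10.0.0.1", "rank_port": 5757, "agent_name": "a0"},
        1: {"rank_ip": "10.0.0.2", "rank_port": 5757, "agent_name": "a1"},
    }
]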
@@ -440,43 +579,67 @@ class NixlKVReceiver(BaseKVReceiver):
         return socket
 
     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
-        logger.debug(
-            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
-        )
 
-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
-        self._connect("tcp://" + self.prefill_server_url).send_multipart(
-            [
-                str(self.bootstrap_room).encode("ascii"),
-                get_local_ip_by_remote().encode("ascii"),
-                str(self.kv_mgr.rank_port).encode("ascii"),
-                self.kv_mgr.agent.get_agent_metadata(),
-                packed_kv_data_ptrs,
-                kv_indices.tobytes(),
-                packed_aux_data_ptrs,
-                str(aux_index).encode("ascii"),
-                str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+        assert self.bootstrap_info is not None
+        assert self.bootstrap_room is not None
+
+        for equal_sources in self.bootstrap_info:
+            remote_rank = list(equal_sources.keys())[
+                self.bootstrap_room % len(equal_sources)
             ]
-        )
+            self.prefill_server_url = f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}"
+            logger.debug(
+                f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}"
+            )
+
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            logger.debug(
+                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
+            )
+            self._connect("tcp://" + self.prefill_server_url).send_multipart(
+                [
+                    GUARD,
+                    str(self.bootstrap_room).encode("ascii"),
+                    get_local_ip_by_remote().encode("ascii"),
+                    str(self.kv_mgr.rank_port).encode("ascii"),
+                    self.kv_mgr.agent.get_agent_metadata(),
+                    packed_kv_data_ptrs,
+                    kv_indices.tobytes(),
+                    packed_aux_data_ptrs,
+                    str(aux_index).encode("ascii"),
+                    str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                ]
+            )
+
+            for dummy_rank in equal_sources.keys():
+                if dummy_rank == remote_rank:
+                    continue
+                dummy_info = equal_sources[dummy_rank]
+                dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
+                self._connect("tcp://" + dummy_url).send_multipart(
+                    [
+                        GUARD,
+                        str(self.bootstrap_room).encode("ascii"),
+                    ]
+                )
+
         self.started_transfer = True
 
     def poll(self) -> KVPoll:
         if not self.started_transfer:
-            return KVPoll.WaitingForInput
+            return KVPoll.WaitingForInput  # type: ignore
 
         self.kv_mgr.update_transfer_status()
 
-        if self.kv_mgr.check_transfer_done(self.bootstrap_room):
-            return KVPoll.Success
-        return KVPoll.WaitingForInput
+        if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
+            return KVPoll.Success  # type: ignore
+        return KVPoll.WaitingForInput  # type: ignore
 
     def failure_exception(self):
         raise Exception("Fake KVReceiver Exception")
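The init loop above picks exactly one real source per equal-sources dict, deterministically from the bootstrap room id, and sends every other candidate a two-frame dummy message so its sender can complete. A sketch of the selection (values illustrative):

bootstrap_room = 12345
equal_sources = {0: "info-rank0", 1: "info-rank1"}  # rank -> NixlEngineInfo in real code

remote_rank = list(equal_sources.keys())[bootstrap_room % len(equal_sources)]
dummy_ranks = [r for r in equal_sources if r != remote_rank]

# 12345 % 2 == 1, so rank 1 gets the full request and rank 0 gets the dummy.
assert remote_rank == 1 and dummy_ranks == [0]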
@@ -484,6 +647,7 @@ class NixlKVReceiver(BaseKVReceiver):
 
 class NixlKVBootstrapServer(BaseKVBootstrapServer):
     def __init__(self, port: int):
+        logger.debug(f"NixlKVBootstrapServer started on port {port}")
         self.port = port
         self.app = web.Application()
         self.store = dict()
@@ -564,13 +728,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
         engine_rank = int(data["engine_rank"])
         agent_name = data["agent_name"]
 
-        # Add lock to make sure thread-safe
         if role == "Prefill":
-            self.prefill_port_table[engine_rank] = {
-                "rank_ip": rank_ip,
-                "rank_port": rank_port,
-                "agent_name": agent_name,
-            }
+            async with self.lock:
+                self.prefill_port_table[engine_rank] = {
+                    "rank_ip": rank_ip,
+                    "rank_port": rank_port,
+                    "agent_name": agent_name,
+                }
             logger.info(
                 f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
             )
@@ -580,7 +744,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         if not engine_rank:
-
+            logger.debug(
+                f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
+            )
+            # Return a dict of all engine_rank
+            async with self.lock:
+                bootstrap_info = self.prefill_port_table
+            return web.json_response(bootstrap_info, status=200)
 
         # Find corresponding prefill info
         async with self.lock:
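A hedged sketch of the two GET /route behaviors after this change (host and port are illustrative): with engine_rank the server returns that rank's info as before; without it, it now returns the whole prefill table keyed by engine rank:

import requests

base = "http://10.0.0.5:8998"  # illustrative bootstrap server address

one = requests.get(f"{base}/route", params={"engine_rank": 0}).json()
# -> {"rank_ip": ..., "rank_port": ..., "agent_name": ...}

table = requests.get(f"{base}/route").json()
# -> {"0": {...}, "1": {...}, ...}  (JSON keys are strings; the client int()s them)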
--- a/sglang/srt/disaggregation/utils.py
+++ b/sglang/srt/disaggregation/utils.py
@@ -1,13 +1,18 @@
 from __future__ import annotations
 
+import dataclasses
+import warnings
 from collections import deque
 from enum import Enum
-from typing import List
+from typing import List, Optional
 
 import numpy as np
+import requests
 import torch
 import torch.distributed as dist
 
+from sglang.srt.utils import get_ip
+
 
 class DisaggregationMode(Enum):
     NULL = "null"
@@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
 def kv_to_page_num(num_kv_indices: int, page_size: int):
     # ceil(num_kv_indices / page_size)
     return (num_kv_indices + page_size - 1) // page_size
+
+
+@dataclasses.dataclass
+class PDRegistryRequest:
+    """A request to register a machine itself to the LB."""
+
+    mode: str
+    registry_url: str
+    bootstrap_port: Optional[int] = None
+
+    def __post_init__(self):
+        if self.mode == "prefill" and self.bootstrap_port is None:
+            raise ValueError("Bootstrap port must be set in PREFILL mode.")
+        elif self.mode == "decode" and self.bootstrap_port is not None:
+            raise ValueError("Bootstrap port must not be set in DECODE mode.")
+        elif self.mode not in ["prefill", "decode"]:
+            raise ValueError(
+                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
+            )
+
+
+def register_disaggregation_server(
+    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
+):
+    boostrap_port = bootstrap_port if mode == "prefill" else None
+    registry_request = PDRegistryRequest(
+        mode=mode,
+        registry_url=f"http://{get_ip()}:{server_port}",
+        bootstrap_port=boostrap_port,
+    )
+    res = requests.post(
+        f"{pdlb_url}/register",
+        json=dataclasses.asdict(registry_request),
+    )
+    if res.status_code != 200:
+        warnings.warn(
+            f"Failed to register disaggregation server: {res.status_code} {res.text}"
+        )
--- a/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -296,7 +296,6 @@ class CustomAllreduce:
             self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
         )
         self.register_buffer(self.buffer)
-        self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
 
         self.disabled = False
 
@@ -430,13 +429,7 @@ class CustomAllreduce:
 
         if _is_hip:
             if self.full_nvlink:
-
-                if self.MSCCL:
-                    return False
-                else:
-                    return inp_size < self.max_size
-            else:
-                return inp_size < self.max_size
+                return inp_size < self.max_size
             return False
 
         return False
--- /dev/null
+++ b/sglang/srt/distributed/device_communicators/npu_communicator.py
@@ -0,0 +1,39 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.utils import is_npu
+
+
+class NpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not is_npu():
+            self.disabled = True
+            return
+        self.disabled = False
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(x, group=self.group)
+        return x
+
+    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += x.dim()
+        input_size = x.size()
+        output_size = (input_size[0] * world_size,) + input_size[1:]
+        # Allocate output tensor.
+        output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device)
+        # All-gather.
+        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        # Reshape
+        output_tensor = output_tensor.reshape((world_size,) + input_size)
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(
+            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
+        )
+        return output_tensor
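The all_gather above always gathers along dim 0 and then moves the world-size axis to the requested dim; the reshape algebra can be checked on CPU without an NPU or process group (the stand-in below fakes the collective by repeating one shard):

import torch

world_size, dim = 4, 1
x = torch.randn(2, 3)                              # one rank's shard

# Stand-in for dist.all_gather_into_tensor: concatenate along dim 0.
gathered = torch.cat([x] * world_size, dim=0)      # (8, 3)

out = gathered.reshape((world_size,) + x.size())   # (4, 2, 3)
out = out.movedim(0, dim)                          # (2, 4, 3)
out = out.reshape(x.size()[:dim] + (world_size * x.size(dim),) + x.size()[dim + 1 :])
assert out.shape == (2, 12)                        # shards now stacked along dim 1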
--- a/sglang/srt/distributed/device_communicators/pynccl.py
+++ b/sglang/srt/distributed/device_communicators/pynccl.py
@@ -75,7 +75,8 @@ class PyNcclCommunicator:
         self.available = True
         self.disabled = False
 
-        logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
+        if self.rank == 0:
+            logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
 
         if self.rank == 0:
             # get the unique id from NCCL
--- a/sglang/srt/distributed/device_communicators/shm_broadcast.py
+++ b/sglang/srt/distributed/device_communicators/shm_broadcast.py
@@ -225,7 +225,8 @@ class MessageQueue:
             remote_subscribe_port = get_open_port()
             if is_valid_ipv6_address(connect_ip):
                 self.remote_socket.setsockopt(IPV6, 1)
-            socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
+                connect_ip = f"[{connect_ip}]"
+            socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
             self.remote_socket.bind(socket_addr)
 
         else: