sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +49 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +2 -8
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +27 -4
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -4
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/topk.py +5 -13
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/modelopt_quant.py +8 -4
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +53 -6
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/schedule_batch.py +13 -3
- sglang/srt/managers/scheduler.py +13 -25
- sglang/srt/managers/tokenizer_manager.py +28 -25
- sglang/srt/managers/tp_worker.py +2 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +30 -16
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +41 -23
- sglang/srt/models/deepseek_v2.py +1 -2
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +0 -4
- sglang/srt/models/qwen3_moe.py +1 -6
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +76 -55
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +17 -68
- sglang/test/test_activation.py +50 -1
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/_custom_ops.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
2
2
|
import logging
|
3
|
-
from typing import List, Tuple
|
3
|
+
from typing import List, Optional, Tuple
|
4
4
|
|
5
5
|
import torch
|
6
6
|
|
@@ -114,6 +114,34 @@ else:
|
|
114
114
|
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
115
115
|
return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
|
116
116
|
|
117
|
+
# ROCM custom quick allreduce
|
118
|
+
|
119
|
+
def init_custom_qr(
|
120
|
+
rank: int, world_size: int, qr_max_size: Optional[int] = None
|
121
|
+
) -> int:
|
122
|
+
return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
|
123
|
+
|
124
|
+
def qr_get_handle(fa: int) -> torch.Tensor:
|
125
|
+
return sgl_kernel.allreduce.qr_get_handle(fa)
|
126
|
+
|
127
|
+
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
|
128
|
+
sgl_kernel.allreduce.qr_open_handles(fa, handles)
|
129
|
+
|
130
|
+
def qr_all_reduce(
|
131
|
+
fa: int,
|
132
|
+
inp: torch.Tensor,
|
133
|
+
out: torch.Tensor,
|
134
|
+
quant_level: int,
|
135
|
+
cast_bf2half: bool,
|
136
|
+
) -> None:
|
137
|
+
sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
|
138
|
+
|
139
|
+
def qr_destroy(fa: int) -> None:
|
140
|
+
sgl_kernel.allreduce.qr_destroy(fa)
|
141
|
+
|
142
|
+
def qr_max_size() -> int:
|
143
|
+
return sgl_kernel.allreduce.qr_max_size()
|
144
|
+
|
117
145
|
|
118
146
|
def mscclpp_generate_unique_id() -> bytes:
|
119
147
|
return sgl_kernel.allreduce.mscclpp_generate_unique_id()
|
@@ -475,7 +475,7 @@ class ModelConfig:
|
|
475
475
|
|
476
476
|
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
477
477
|
eos_ids = getattr(self.hf_config, "eos_token_id", None)
|
478
|
-
if eos_ids:
|
478
|
+
if eos_ids is not None:
|
479
479
|
# it can be either int or list of int
|
480
480
|
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
481
481
|
if eos_ids is None:
|
sglang/srt/conversation.py
CHANGED
@@ -984,7 +984,7 @@ register_conv_template(
|
|
984
984
|
|
985
985
|
@register_conv_template_matching_function
|
986
986
|
def match_internvl(model_path: str):
|
987
|
-
if re.search(r"
|
987
|
+
if re.search(r"internvl", model_path, re.IGNORECASE):
|
988
988
|
return "internvl-2-5"
|
989
989
|
|
990
990
|
|
@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import (
|
|
23
23
|
)
|
24
24
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
25
25
|
from sglang.srt.server_args import ServerArgs
|
26
|
-
from sglang.srt.utils import
|
26
|
+
from sglang.srt.utils import (
|
27
|
+
format_tcp_address,
|
28
|
+
get_free_port,
|
29
|
+
get_ip,
|
30
|
+
get_local_ip_by_remote,
|
31
|
+
is_valid_ipv6_address,
|
32
|
+
maybe_wrap_ipv6_address,
|
33
|
+
)
|
27
34
|
|
28
35
|
logger = logging.getLogger(__name__)
|
29
36
|
|
@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager):
|
|
65
72
|
def _register_to_bootstrap(self):
|
66
73
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
67
74
|
if self.dist_init_addr:
|
68
|
-
|
75
|
+
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
76
|
+
if self.dist_init_addr.endswith("]"):
|
77
|
+
host = self.dist_init_addr
|
78
|
+
else:
|
79
|
+
host, _ = self.dist_init_addr.rsplit(":", 1)
|
80
|
+
else:
|
81
|
+
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
69
82
|
else:
|
70
|
-
|
83
|
+
host = get_ip()
|
84
|
+
host = maybe_wrap_ipv6_address(host)
|
71
85
|
|
72
|
-
bootstrap_server_url = f"{
|
86
|
+
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
73
87
|
url = f"http://{bootstrap_server_url}/route"
|
74
88
|
payload = {
|
75
89
|
"role": "Prefill",
|
@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager):
|
|
92
106
|
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
|
93
107
|
|
94
108
|
@cache
|
95
|
-
def _connect(self, endpoint: str):
|
109
|
+
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
96
110
|
socket = zmq.Context().socket(zmq.PUSH)
|
111
|
+
if is_ipv6:
|
112
|
+
socket.setsockopt(zmq.IPV6, 1)
|
97
113
|
socket.connect(endpoint)
|
98
114
|
return socket
|
99
115
|
|
@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
263
279
|
return None
|
264
280
|
|
265
281
|
@classmethod
|
266
|
-
def _connect(cls, endpoint: str):
|
282
|
+
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
267
283
|
with cls._global_lock:
|
268
284
|
if endpoint not in cls._socket_cache:
|
269
285
|
sock = cls._ctx.socket(zmq.PUSH)
|
286
|
+
if is_ipv6:
|
287
|
+
sock.setsockopt(zmq.IPV6, 1)
|
270
288
|
sock.connect(endpoint)
|
271
289
|
cls._socket_cache[endpoint] = sock
|
272
290
|
cls._socket_locks[endpoint] = threading.Lock()
|
273
291
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
274
292
|
|
293
|
+
@classmethod
|
294
|
+
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
295
|
+
ip_address = bootstrap_info["rank_ip"]
|
296
|
+
port = bootstrap_info["rank_port"]
|
297
|
+
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
298
|
+
sock, lock = cls._connect(
|
299
|
+
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
300
|
+
)
|
301
|
+
return sock, lock
|
302
|
+
|
275
303
|
def _register_kv_args(self):
|
276
304
|
pass
|
277
305
|
|
@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException
|
|
17
17
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
18
18
|
|
19
19
|
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
20
|
+
from sglang.srt.utils import maybe_wrap_ipv6_address
|
20
21
|
|
21
22
|
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
22
23
|
1024 * 64
|
@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):
|
|
271
272
|
|
272
273
|
# Parse and transform prefill_server for bootstrap data
|
273
274
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
274
|
-
hostname = parsed_url.hostname
|
275
|
+
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
275
276
|
modified_request = request_data.copy()
|
276
277
|
|
277
278
|
batch_size = _get_request_batch_size(modified_request)
|
@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
|
309
310
|
|
310
311
|
# Parse and transform prefill_server for bootstrap data
|
311
312
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
312
|
-
hostname = parsed_url.hostname
|
313
|
+
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
313
314
|
modified_request = request_data.copy()
|
314
315
|
modified_request.update(
|
315
316
|
{
|
@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import (
|
|
35
35
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
36
36
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
37
37
|
from sglang.srt.server_args import ServerArgs
|
38
|
-
from sglang.srt.utils import
|
38
|
+
from sglang.srt.utils import (
|
39
|
+
format_tcp_address,
|
40
|
+
get_free_port,
|
41
|
+
get_int_env_var,
|
42
|
+
get_ip,
|
43
|
+
get_local_ip_auto,
|
44
|
+
is_valid_ipv6_address,
|
45
|
+
maybe_wrap_ipv6_address,
|
46
|
+
)
|
39
47
|
|
40
48
|
logger = logging.getLogger(__name__)
|
41
49
|
|
@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
148
156
|
self.request_status: Dict[int, KVPoll] = {}
|
149
157
|
self.rank_port = None
|
150
158
|
self.server_socket = zmq.Context().socket(zmq.PULL)
|
159
|
+
if is_valid_ipv6_address(self.local_ip):
|
160
|
+
self.server_socket.setsockopt(zmq.IPV6, 1)
|
161
|
+
|
151
162
|
self.register_buffer_to_engine()
|
152
163
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
153
164
|
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager):
|
|
240
251
|
self.engine.register(aux_data_ptr, aux_data_len)
|
241
252
|
|
242
253
|
@cache
|
243
|
-
def _connect(self, endpoint: str):
|
254
|
+
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
244
255
|
socket = zmq.Context().socket(zmq.PUSH)
|
256
|
+
if is_ipv6:
|
257
|
+
socket.setsockopt(zmq.IPV6, 1)
|
245
258
|
socket.connect(endpoint)
|
246
259
|
return socket
|
247
260
|
|
@@ -471,9 +484,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
471
484
|
def sync_status_to_decode_endpoint(
|
472
485
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
473
486
|
):
|
474
|
-
|
475
|
-
remote =
|
476
|
-
|
487
|
+
self._connect(
|
488
|
+
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
|
489
|
+
).send_multipart(
|
477
490
|
[
|
478
491
|
str(room).encode("ascii"),
|
479
492
|
str(status).encode("ascii"),
|
@@ -616,9 +629,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
616
629
|
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
617
630
|
)
|
618
631
|
|
632
|
+
def _bind_server_socket(self):
|
633
|
+
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
634
|
+
|
619
635
|
def start_prefill_thread(self):
|
620
636
|
self.rank_port = get_free_port()
|
621
|
-
self.
|
637
|
+
self._bind_server_socket()
|
622
638
|
|
623
639
|
def bootstrap_thread():
|
624
640
|
"""This thread recvs pre-alloc notification from the decode engine"""
|
@@ -657,7 +673,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
657
673
|
|
658
674
|
def start_decode_thread(self):
|
659
675
|
self.rank_port = get_free_port()
|
660
|
-
self.
|
676
|
+
self._bind_server_socket()
|
661
677
|
|
662
678
|
def decode_thread():
|
663
679
|
while True:
|
@@ -776,7 +792,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
776
792
|
# requests with the same dst_sessions will be added into the same
|
777
793
|
# queue, which enables early abort with failed sessions.
|
778
794
|
dst_infos = self.transfer_infos[bootstrap_room].keys()
|
779
|
-
session_port_sum = sum(int(session.
|
795
|
+
session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
|
780
796
|
shard_idx = session_port_sum % len(self.transfer_queues)
|
781
797
|
|
782
798
|
self.transfer_queues[shard_idx].put(
|
@@ -814,11 +830,18 @@ class MooncakeKVManager(BaseKVManager):
|
|
814
830
|
def _register_to_bootstrap(self):
|
815
831
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
816
832
|
if self.dist_init_addr:
|
817
|
-
|
833
|
+
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
834
|
+
if self.dist_init_addr.endswith("]"):
|
835
|
+
host = self.dist_init_addr
|
836
|
+
else:
|
837
|
+
host, _ = self.dist_init_addr.rsplit(":", 1)
|
838
|
+
else:
|
839
|
+
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
818
840
|
else:
|
819
|
-
|
841
|
+
host = get_ip()
|
842
|
+
host = maybe_wrap_ipv6_address(host)
|
820
843
|
|
821
|
-
bootstrap_server_url = f"{
|
844
|
+
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
822
845
|
url = f"http://{bootstrap_server_url}/route"
|
823
846
|
payload = {
|
824
847
|
"role": "Prefill",
|
@@ -1163,9 +1186,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1163
1186
|
|
1164
1187
|
def _register_kv_args(self):
|
1165
1188
|
for bootstrap_info in self.bootstrap_infos:
|
1166
|
-
self.prefill_server_url = (
|
1167
|
-
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
1168
|
-
)
|
1169
1189
|
packed_kv_data_ptrs = b"".join(
|
1170
1190
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
1171
1191
|
)
|
@@ -1179,7 +1199,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1179
1199
|
dst_tp_size = str(tp_size).encode("ascii")
|
1180
1200
|
dst_kv_item_len = str(kv_item_len).encode("ascii")
|
1181
1201
|
|
1182
|
-
sock, lock = self.
|
1202
|
+
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
1183
1203
|
with lock:
|
1184
1204
|
sock.send_multipart(
|
1185
1205
|
[
|
@@ -1196,23 +1216,32 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1196
1216
|
)
|
1197
1217
|
|
1198
1218
|
@classmethod
|
1199
|
-
def _connect(cls, endpoint: str):
|
1219
|
+
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
1200
1220
|
with cls._global_lock:
|
1201
1221
|
if endpoint not in cls._socket_cache:
|
1202
1222
|
sock = cls._ctx.socket(zmq.PUSH)
|
1223
|
+
if is_ipv6:
|
1224
|
+
sock.setsockopt(zmq.IPV6, 1)
|
1203
1225
|
sock.connect(endpoint)
|
1204
1226
|
cls._socket_cache[endpoint] = sock
|
1205
1227
|
cls._socket_locks[endpoint] = threading.Lock()
|
1206
1228
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
1207
1229
|
|
1230
|
+
@classmethod
|
1231
|
+
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
1232
|
+
ip_address = bootstrap_info["rank_ip"]
|
1233
|
+
port = bootstrap_info["rank_port"]
|
1234
|
+
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
1235
|
+
sock, lock = cls._connect(
|
1236
|
+
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
1237
|
+
)
|
1238
|
+
return sock, lock
|
1239
|
+
|
1208
1240
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
1209
1241
|
for bootstrap_info in self.bootstrap_infos:
|
1210
|
-
|
1211
|
-
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
1212
|
-
)
|
1242
|
+
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
1213
1243
|
is_dummy = bootstrap_info["is_dummy"]
|
1214
1244
|
|
1215
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
1216
1245
|
with lock:
|
1217
1246
|
sock.send_multipart(
|
1218
1247
|
[
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import logging
|
2
2
|
from typing import List, Optional
|
3
3
|
|
4
|
-
from sglang.srt.utils import get_bool_env_var, get_free_port
|
4
|
+
from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
|
5
5
|
|
6
6
|
logger = logging.getLogger(__name__)
|
7
7
|
|
@@ -27,7 +27,9 @@ class MooncakeTransferEngine:
|
|
27
27
|
hostname=self.hostname,
|
28
28
|
device_name=self.ib_device,
|
29
29
|
)
|
30
|
-
self.session_id =
|
30
|
+
self.session_id = (
|
31
|
+
f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
|
32
|
+
)
|
31
33
|
|
32
34
|
def register(self, ptr, length):
|
33
35
|
try:
|
@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import (
|
|
27
27
|
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
28
28
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
29
29
|
from sglang.srt.server_args import ServerArgs
|
30
|
-
from sglang.srt.utils import
|
30
|
+
from sglang.srt.utils import (
|
31
|
+
format_tcp_address,
|
32
|
+
get_local_ip_auto,
|
33
|
+
is_valid_ipv6_address,
|
34
|
+
)
|
31
35
|
|
32
36
|
logger = logging.getLogger(__name__)
|
33
37
|
|
@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager):
|
|
124
128
|
"to run SGLang with NixlTransferEngine."
|
125
129
|
) from e
|
126
130
|
self.agent = nixl_agent(str(uuid.uuid4()))
|
131
|
+
self.local_ip = get_local_ip_auto()
|
127
132
|
self.server_socket = zmq.Context().socket(zmq.PULL)
|
133
|
+
if is_valid_ipv6_address(self.local_ip):
|
134
|
+
self.server_socket.setsockopt(zmq.IPV6, 1)
|
128
135
|
self.register_buffer_to_engine()
|
129
136
|
|
130
137
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager):
|
|
337
344
|
return False
|
338
345
|
return self.transfer_statuses[room].is_done()
|
339
346
|
|
347
|
+
def _bind_server_socket(self):
|
348
|
+
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
349
|
+
|
340
350
|
def _start_bootstrap_thread(self):
|
341
|
-
self.
|
351
|
+
self._bind_server_socket()
|
342
352
|
|
343
353
|
def bootstrap_thread():
|
344
354
|
"""This thread recvs transfer info from the decode engine"""
|
@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
452
462
|
|
453
463
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
454
464
|
for bootstrap_info in self.bootstrap_infos:
|
455
|
-
self.prefill_server_url = (
|
456
|
-
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
457
|
-
)
|
458
465
|
logger.debug(
|
459
466
|
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
460
467
|
)
|
468
|
+
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
461
469
|
is_dummy = bootstrap_info["is_dummy"]
|
462
470
|
logger.debug(
|
463
|
-
f"Sending to
|
471
|
+
f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
|
464
472
|
)
|
465
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
466
473
|
with lock:
|
467
474
|
sock.send_multipart(
|
468
475
|
[
|
469
476
|
GUARD,
|
470
477
|
str(self.bootstrap_room).encode("ascii"),
|
471
|
-
|
478
|
+
self.kv_mgr.local_ip.encode("ascii"),
|
472
479
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
473
480
|
self.kv_mgr.agent.name.encode("ascii"),
|
474
481
|
kv_indices.tobytes() if not is_dummy else b"",
|
@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
494
501
|
|
495
502
|
def _register_kv_args(self):
|
496
503
|
for bootstrap_info in self.bootstrap_infos:
|
497
|
-
|
498
|
-
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
499
|
-
)
|
504
|
+
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
500
505
|
packed_kv_data_ptrs = b"".join(
|
501
506
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
502
507
|
)
|
@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
504
509
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
505
510
|
)
|
506
511
|
|
507
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
508
512
|
with lock:
|
509
513
|
sock.send_multipart(
|
510
514
|
[
|
511
515
|
GUARD,
|
512
516
|
"None".encode("ascii"),
|
513
|
-
|
517
|
+
self.kv_mgr.local_ip.encode("ascii"),
|
514
518
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
515
519
|
self.kv_mgr.agent.name.encode("ascii"),
|
516
520
|
self.kv_mgr.agent.get_agent_metadata(),
|
@@ -4,18 +4,18 @@ import ctypes
|
|
4
4
|
import logging
|
5
5
|
import os
|
6
6
|
from contextlib import contextmanager
|
7
|
-
from
|
8
|
-
from typing import Any, Callable, List, Optional, TypeVar, Union
|
7
|
+
from typing import Any, List, Optional, Union
|
9
8
|
|
10
9
|
import torch
|
11
10
|
import torch.distributed as dist
|
12
11
|
from torch.distributed import ProcessGroup
|
13
|
-
from typing_extensions import ParamSpec
|
14
12
|
|
15
13
|
from sglang.srt import _custom_ops as ops
|
16
14
|
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
17
15
|
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
|
18
16
|
gpu_p2p_access_check,
|
17
|
+
is_full_nvlink,
|
18
|
+
is_weak_contiguous,
|
19
19
|
)
|
20
20
|
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
21
21
|
from sglang.srt.utils import is_cuda, is_hip
|
@@ -25,23 +25,6 @@ logger = logging.getLogger(__name__)
|
|
25
25
|
_is_cuda = is_cuda()
|
26
26
|
_is_hip = is_hip()
|
27
27
|
|
28
|
-
if _is_cuda:
|
29
|
-
try:
|
30
|
-
import pynvml
|
31
|
-
except ImportError as e:
|
32
|
-
logger.warning("Failed to import pynvml with %r", e)
|
33
|
-
|
34
|
-
if _is_hip:
|
35
|
-
try:
|
36
|
-
from amdsmi import (
|
37
|
-
AmdSmiException,
|
38
|
-
amdsmi_get_processor_handles,
|
39
|
-
amdsmi_init,
|
40
|
-
amdsmi_shut_down,
|
41
|
-
amdsmi_topo_get_link_type,
|
42
|
-
)
|
43
|
-
except ImportError as e:
|
44
|
-
logger.warning("Failed to import amdsmi with %r", e)
|
45
28
|
|
46
29
|
try:
|
47
30
|
if ops.use_vllm_custom_allreduce and not _is_hip:
|
@@ -57,70 +40,6 @@ except Exception:
|
|
57
40
|
|
58
41
|
logger = logging.getLogger(__name__)
|
59
42
|
|
60
|
-
_P = ParamSpec("_P")
|
61
|
-
_R = TypeVar("_R")
|
62
|
-
|
63
|
-
|
64
|
-
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
65
|
-
@wraps(fn)
|
66
|
-
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
67
|
-
if _is_hip:
|
68
|
-
try:
|
69
|
-
amdsmi_init()
|
70
|
-
return fn(*args, **kwargs)
|
71
|
-
finally:
|
72
|
-
amdsmi_shut_down()
|
73
|
-
else:
|
74
|
-
pynvml.nvmlInit()
|
75
|
-
try:
|
76
|
-
return fn(*args, **kwargs)
|
77
|
-
finally:
|
78
|
-
pynvml.nvmlShutdown()
|
79
|
-
|
80
|
-
return wrapper
|
81
|
-
|
82
|
-
|
83
|
-
@with_nvml_context
|
84
|
-
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
85
|
-
if _is_hip:
|
86
|
-
"""
|
87
|
-
query if the set of gpus are fully connected by xgmi (1 hop)
|
88
|
-
"""
|
89
|
-
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
|
90
|
-
for i, handle in enumerate(handles):
|
91
|
-
for j, peer_handle in enumerate(handles):
|
92
|
-
if i < j:
|
93
|
-
try:
|
94
|
-
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
|
95
|
-
# type is 2 for XGMI
|
96
|
-
if link_type["hops"] != 1 or link_type["type"] != 2:
|
97
|
-
return False
|
98
|
-
except AmdSmiException as error:
|
99
|
-
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
|
100
|
-
return False
|
101
|
-
return True
|
102
|
-
else:
|
103
|
-
"""
|
104
|
-
query if the set of gpus are fully connected by nvlink (1 hop)
|
105
|
-
"""
|
106
|
-
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
|
107
|
-
for i, handle in enumerate(handles):
|
108
|
-
for j, peer_handle in enumerate(handles):
|
109
|
-
if i < j:
|
110
|
-
try:
|
111
|
-
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
112
|
-
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
|
113
|
-
)
|
114
|
-
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
115
|
-
return False
|
116
|
-
except pynvml.NVMLError:
|
117
|
-
logger.exception(
|
118
|
-
"NVLink detection failed. This is normal if your"
|
119
|
-
" machine has no NVLink equipped."
|
120
|
-
)
|
121
|
-
return False
|
122
|
-
return True
|
123
|
-
|
124
43
|
|
125
44
|
def _can_p2p(rank: int, world_size: int) -> bool:
|
126
45
|
# SGLANG_SKIP_P2P_CHECK can be set to False in sglang
|
@@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool:
|
|
136
55
|
return True
|
137
56
|
|
138
57
|
|
139
|
-
def is_weak_contiguous(inp: torch.Tensor):
|
140
|
-
return inp.is_contiguous() or (
|
141
|
-
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
142
|
-
== inp.numel() * inp.element_size()
|
143
|
-
)
|
144
|
-
|
145
|
-
|
146
58
|
class CustomAllreduce:
|
147
59
|
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
148
60
|
_MAX_CAR_SIZE = 8192 * 1024
|