sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import (
|
|
23
23
|
)
|
24
24
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
25
25
|
from sglang.srt.server_args import ServerArgs
|
26
|
-
from sglang.srt.utils import
|
26
|
+
from sglang.srt.utils import (
|
27
|
+
format_tcp_address,
|
28
|
+
get_free_port,
|
29
|
+
get_ip,
|
30
|
+
get_local_ip_by_remote,
|
31
|
+
is_valid_ipv6_address,
|
32
|
+
maybe_wrap_ipv6_address,
|
33
|
+
)
|
27
34
|
|
28
35
|
logger = logging.getLogger(__name__)
|
29
36
|
|
@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager):
|
|
65
72
|
def _register_to_bootstrap(self):
|
66
73
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
67
74
|
if self.dist_init_addr:
|
68
|
-
|
75
|
+
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
76
|
+
if self.dist_init_addr.endswith("]"):
|
77
|
+
host = self.dist_init_addr
|
78
|
+
else:
|
79
|
+
host, _ = self.dist_init_addr.rsplit(":", 1)
|
80
|
+
else:
|
81
|
+
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
69
82
|
else:
|
70
|
-
|
83
|
+
host = get_ip()
|
84
|
+
host = maybe_wrap_ipv6_address(host)
|
71
85
|
|
72
|
-
bootstrap_server_url = f"{
|
86
|
+
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
73
87
|
url = f"http://{bootstrap_server_url}/route"
|
74
88
|
payload = {
|
75
89
|
"role": "Prefill",
|
@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager):
|
|
92
106
|
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
|
93
107
|
|
94
108
|
@cache
|
95
|
-
def _connect(self, endpoint: str):
|
109
|
+
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
96
110
|
socket = zmq.Context().socket(zmq.PUSH)
|
111
|
+
if is_ipv6:
|
112
|
+
socket.setsockopt(zmq.IPV6, 1)
|
97
113
|
socket.connect(endpoint)
|
98
114
|
return socket
|
99
115
|
|
@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
263
279
|
return None
|
264
280
|
|
265
281
|
@classmethod
|
266
|
-
def _connect(cls, endpoint: str):
|
282
|
+
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
267
283
|
with cls._global_lock:
|
268
284
|
if endpoint not in cls._socket_cache:
|
269
285
|
sock = cls._ctx.socket(zmq.PUSH)
|
286
|
+
if is_ipv6:
|
287
|
+
sock.setsockopt(zmq.IPV6, 1)
|
270
288
|
sock.connect(endpoint)
|
271
289
|
cls._socket_cache[endpoint] = sock
|
272
290
|
cls._socket_locks[endpoint] = threading.Lock()
|
273
291
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
274
292
|
|
293
|
+
@classmethod
|
294
|
+
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
295
|
+
ip_address = bootstrap_info["rank_ip"]
|
296
|
+
port = bootstrap_info["rank_port"]
|
297
|
+
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
298
|
+
sock, lock = cls._connect(
|
299
|
+
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
300
|
+
)
|
301
|
+
return sock, lock
|
302
|
+
|
275
303
|
def _register_kv_args(self):
|
276
304
|
pass
|
277
305
|
|
@@ -1,10 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import logging
|
4
|
+
from http import HTTPStatus
|
4
5
|
from typing import TYPE_CHECKING
|
5
6
|
|
6
7
|
import torch
|
7
8
|
|
9
|
+
from sglang.srt.disaggregation.utils import prepare_abort
|
8
10
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
9
11
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
10
12
|
|
@@ -102,7 +104,17 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|
102
104
|
self.output_ids.append(req.output_ids[-1])
|
103
105
|
self.tree_cache.cache_unfinished_req(req)
|
104
106
|
if req.grammar is not None:
|
105
|
-
|
107
|
+
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
108
|
+
try:
|
109
|
+
req.grammar.accept_token(req.output_ids[-1])
|
110
|
+
except ValueError as e:
|
111
|
+
# Grammar accept_token can raise ValueError if the token is not in the grammar.
|
112
|
+
# This can happen if the grammar is not set correctly or the token is invalid.
|
113
|
+
error_message = f"Grammar accept_token failed for req {req.rid} with token {req.output_ids[-1]}: {e}"
|
114
|
+
self.tree_cache.cache_finished_req(req)
|
115
|
+
prepare_abort(
|
116
|
+
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
|
117
|
+
)
|
106
118
|
req.grammar.finished = req.finished()
|
107
119
|
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
108
120
|
|
@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException
|
|
17
17
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
18
18
|
|
19
19
|
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
20
|
+
from sglang.srt.utils import maybe_wrap_ipv6_address
|
20
21
|
|
21
22
|
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
22
23
|
1024 * 64
|
@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):
|
|
271
272
|
|
272
273
|
# Parse and transform prefill_server for bootstrap data
|
273
274
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
274
|
-
hostname = parsed_url.hostname
|
275
|
+
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
275
276
|
modified_request = request_data.copy()
|
276
277
|
|
277
278
|
batch_size = _get_request_batch_size(modified_request)
|
@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
|
309
310
|
|
310
311
|
# Parse and transform prefill_server for bootstrap data
|
311
312
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
312
|
-
hostname = parsed_url.hostname
|
313
|
+
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
313
314
|
modified_request = request_data.copy()
|
314
315
|
modified_request.update(
|
315
316
|
{
|
@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import (
|
|
35
35
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
36
36
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
37
37
|
from sglang.srt.server_args import ServerArgs
|
38
|
-
from sglang.srt.utils import
|
38
|
+
from sglang.srt.utils import (
|
39
|
+
format_tcp_address,
|
40
|
+
get_free_port,
|
41
|
+
get_int_env_var,
|
42
|
+
get_ip,
|
43
|
+
get_local_ip_auto,
|
44
|
+
is_valid_ipv6_address,
|
45
|
+
maybe_wrap_ipv6_address,
|
46
|
+
)
|
39
47
|
|
40
48
|
logger = logging.getLogger(__name__)
|
41
49
|
|
@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
148
156
|
self.request_status: Dict[int, KVPoll] = {}
|
149
157
|
self.rank_port = None
|
150
158
|
self.server_socket = zmq.Context().socket(zmq.PULL)
|
159
|
+
if is_valid_ipv6_address(self.local_ip):
|
160
|
+
self.server_socket.setsockopt(zmq.IPV6, 1)
|
161
|
+
|
151
162
|
self.register_buffer_to_engine()
|
152
163
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
153
164
|
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager):
|
|
240
251
|
self.engine.register(aux_data_ptr, aux_data_len)
|
241
252
|
|
242
253
|
@cache
|
243
|
-
def _connect(self, endpoint: str):
|
254
|
+
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
244
255
|
socket = zmq.Context().socket(zmq.PUSH)
|
256
|
+
if is_ipv6:
|
257
|
+
socket.setsockopt(zmq.IPV6, 1)
|
245
258
|
socket.connect(endpoint)
|
246
259
|
return socket
|
247
260
|
|
@@ -471,9 +484,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
471
484
|
def sync_status_to_decode_endpoint(
|
472
485
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
473
486
|
):
|
474
|
-
|
475
|
-
remote =
|
476
|
-
|
487
|
+
self._connect(
|
488
|
+
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
|
489
|
+
).send_multipart(
|
477
490
|
[
|
478
491
|
str(room).encode("ascii"),
|
479
492
|
str(status).encode("ascii"),
|
@@ -616,9 +629,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
616
629
|
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
617
630
|
)
|
618
631
|
|
632
|
+
def _bind_server_socket(self):
|
633
|
+
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
634
|
+
|
619
635
|
def start_prefill_thread(self):
|
620
636
|
self.rank_port = get_free_port()
|
621
|
-
self.
|
637
|
+
self._bind_server_socket()
|
622
638
|
|
623
639
|
def bootstrap_thread():
|
624
640
|
"""This thread recvs pre-alloc notification from the decode engine"""
|
@@ -657,7 +673,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
657
673
|
|
658
674
|
def start_decode_thread(self):
|
659
675
|
self.rank_port = get_free_port()
|
660
|
-
self.
|
676
|
+
self._bind_server_socket()
|
661
677
|
|
662
678
|
def decode_thread():
|
663
679
|
while True:
|
@@ -776,7 +792,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
776
792
|
# requests with the same dst_sessions will be added into the same
|
777
793
|
# queue, which enables early abort with failed sessions.
|
778
794
|
dst_infos = self.transfer_infos[bootstrap_room].keys()
|
779
|
-
session_port_sum = sum(int(session.
|
795
|
+
session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
|
780
796
|
shard_idx = session_port_sum % len(self.transfer_queues)
|
781
797
|
|
782
798
|
self.transfer_queues[shard_idx].put(
|
@@ -814,11 +830,18 @@ class MooncakeKVManager(BaseKVManager):
|
|
814
830
|
def _register_to_bootstrap(self):
|
815
831
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
816
832
|
if self.dist_init_addr:
|
817
|
-
|
833
|
+
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
834
|
+
if self.dist_init_addr.endswith("]"):
|
835
|
+
host = self.dist_init_addr
|
836
|
+
else:
|
837
|
+
host, _ = self.dist_init_addr.rsplit(":", 1)
|
838
|
+
else:
|
839
|
+
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
818
840
|
else:
|
819
|
-
|
841
|
+
host = get_ip()
|
842
|
+
host = maybe_wrap_ipv6_address(host)
|
820
843
|
|
821
|
-
bootstrap_server_url = f"{
|
844
|
+
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
822
845
|
url = f"http://{bootstrap_server_url}/route"
|
823
846
|
payload = {
|
824
847
|
"role": "Prefill",
|
@@ -969,6 +992,14 @@ class MooncakeKVSender(BaseKVSender):
|
|
969
992
|
)
|
970
993
|
raise KVTransferError(self.bootstrap_room, failure_reason)
|
971
994
|
|
995
|
+
def abort(self):
|
996
|
+
self.kv_mgr.record_failure(
|
997
|
+
self.bootstrap_room,
|
998
|
+
"Aborted by AbortReq.",
|
999
|
+
)
|
1000
|
+
# Explicitly set the status to failure since this request has been aborted
|
1001
|
+
self.conclude_state = KVPoll.Failed
|
1002
|
+
|
972
1003
|
|
973
1004
|
class MooncakeKVReceiver(BaseKVReceiver):
|
974
1005
|
_ctx = zmq.Context()
|
@@ -1163,9 +1194,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1163
1194
|
|
1164
1195
|
def _register_kv_args(self):
|
1165
1196
|
for bootstrap_info in self.bootstrap_infos:
|
1166
|
-
self.prefill_server_url = (
|
1167
|
-
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
1168
|
-
)
|
1169
1197
|
packed_kv_data_ptrs = b"".join(
|
1170
1198
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
1171
1199
|
)
|
@@ -1179,7 +1207,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1179
1207
|
dst_tp_size = str(tp_size).encode("ascii")
|
1180
1208
|
dst_kv_item_len = str(kv_item_len).encode("ascii")
|
1181
1209
|
|
1182
|
-
sock, lock = self.
|
1210
|
+
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
1183
1211
|
with lock:
|
1184
1212
|
sock.send_multipart(
|
1185
1213
|
[
|
@@ -1196,23 +1224,32 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1196
1224
|
)
|
1197
1225
|
|
1198
1226
|
@classmethod
|
1199
|
-
def _connect(cls, endpoint: str):
|
1227
|
+
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
1200
1228
|
with cls._global_lock:
|
1201
1229
|
if endpoint not in cls._socket_cache:
|
1202
1230
|
sock = cls._ctx.socket(zmq.PUSH)
|
1231
|
+
if is_ipv6:
|
1232
|
+
sock.setsockopt(zmq.IPV6, 1)
|
1203
1233
|
sock.connect(endpoint)
|
1204
1234
|
cls._socket_cache[endpoint] = sock
|
1205
1235
|
cls._socket_locks[endpoint] = threading.Lock()
|
1206
1236
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
1207
1237
|
|
1238
|
+
@classmethod
|
1239
|
+
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
1240
|
+
ip_address = bootstrap_info["rank_ip"]
|
1241
|
+
port = bootstrap_info["rank_port"]
|
1242
|
+
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
1243
|
+
sock, lock = cls._connect(
|
1244
|
+
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
1245
|
+
)
|
1246
|
+
return sock, lock
|
1247
|
+
|
1208
1248
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
1209
1249
|
for bootstrap_info in self.bootstrap_infos:
|
1210
|
-
|
1211
|
-
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
1212
|
-
)
|
1250
|
+
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
1213
1251
|
is_dummy = bootstrap_info["is_dummy"]
|
1214
1252
|
|
1215
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
1216
1253
|
with lock:
|
1217
1254
|
sock.send_multipart(
|
1218
1255
|
[
|
@@ -1276,6 +1313,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1276
1313
|
)
|
1277
1314
|
raise KVTransferError(self.bootstrap_room, failure_reason)
|
1278
1315
|
|
1316
|
+
def abort(self):
|
1317
|
+
self.kv_mgr.record_failure(
|
1318
|
+
self.bootstrap_room,
|
1319
|
+
"Aborted by AbortReq.",
|
1320
|
+
)
|
1321
|
+
# Explicitly set the status to failure since this request has been aborted
|
1322
|
+
self.conclude_state = KVPoll.Failed
|
1323
|
+
|
1279
1324
|
|
1280
1325
|
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
1281
1326
|
def __init__(self, port: int):
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import logging
|
2
2
|
from typing import List, Optional
|
3
3
|
|
4
|
-
from sglang.srt.utils import get_bool_env_var, get_free_port
|
4
|
+
from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
|
5
5
|
|
6
6
|
logger = logging.getLogger(__name__)
|
7
7
|
|
@@ -27,7 +27,9 @@ class MooncakeTransferEngine:
|
|
27
27
|
hostname=self.hostname,
|
28
28
|
device_name=self.ib_device,
|
29
29
|
)
|
30
|
-
self.session_id =
|
30
|
+
self.session_id = (
|
31
|
+
f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
|
32
|
+
)
|
31
33
|
|
32
34
|
def register(self, ptr, length):
|
33
35
|
try:
|
@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import (
|
|
27
27
|
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
28
28
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
29
29
|
from sglang.srt.server_args import ServerArgs
|
30
|
-
from sglang.srt.utils import
|
30
|
+
from sglang.srt.utils import (
|
31
|
+
format_tcp_address,
|
32
|
+
get_local_ip_auto,
|
33
|
+
is_valid_ipv6_address,
|
34
|
+
)
|
31
35
|
|
32
36
|
logger = logging.getLogger(__name__)
|
33
37
|
|
@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager):
|
|
124
128
|
"to run SGLang with NixlTransferEngine."
|
125
129
|
) from e
|
126
130
|
self.agent = nixl_agent(str(uuid.uuid4()))
|
131
|
+
self.local_ip = get_local_ip_auto()
|
127
132
|
self.server_socket = zmq.Context().socket(zmq.PULL)
|
133
|
+
if is_valid_ipv6_address(self.local_ip):
|
134
|
+
self.server_socket.setsockopt(zmq.IPV6, 1)
|
128
135
|
self.register_buffer_to_engine()
|
129
136
|
|
130
137
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager):
|
|
337
344
|
return False
|
338
345
|
return self.transfer_statuses[room].is_done()
|
339
346
|
|
347
|
+
def _bind_server_socket(self):
|
348
|
+
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
349
|
+
|
340
350
|
def _start_bootstrap_thread(self):
|
341
|
-
self.
|
351
|
+
self._bind_server_socket()
|
342
352
|
|
343
353
|
def bootstrap_thread():
|
344
354
|
"""This thread recvs transfer info from the decode engine"""
|
@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
452
462
|
|
453
463
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
454
464
|
for bootstrap_info in self.bootstrap_infos:
|
455
|
-
self.prefill_server_url = (
|
456
|
-
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
457
|
-
)
|
458
465
|
logger.debug(
|
459
466
|
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
460
467
|
)
|
468
|
+
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
461
469
|
is_dummy = bootstrap_info["is_dummy"]
|
462
470
|
logger.debug(
|
463
|
-
f"Sending to
|
471
|
+
f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
|
464
472
|
)
|
465
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
466
473
|
with lock:
|
467
474
|
sock.send_multipart(
|
468
475
|
[
|
469
476
|
GUARD,
|
470
477
|
str(self.bootstrap_room).encode("ascii"),
|
471
|
-
|
478
|
+
self.kv_mgr.local_ip.encode("ascii"),
|
472
479
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
473
480
|
self.kv_mgr.agent.name.encode("ascii"),
|
474
481
|
kv_indices.tobytes() if not is_dummy else b"",
|
@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
494
501
|
|
495
502
|
def _register_kv_args(self):
|
496
503
|
for bootstrap_info in self.bootstrap_infos:
|
497
|
-
|
498
|
-
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
499
|
-
)
|
504
|
+
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
500
505
|
packed_kv_data_ptrs = b"".join(
|
501
506
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
502
507
|
)
|
@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
504
509
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
505
510
|
)
|
506
511
|
|
507
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
508
512
|
with lock:
|
509
513
|
sock.send_multipart(
|
510
514
|
[
|
511
515
|
GUARD,
|
512
516
|
"None".encode("ascii"),
|
513
|
-
|
517
|
+
self.kv_mgr.local_ip.encode("ascii"),
|
514
518
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
515
519
|
self.kv_mgr.agent.name.encode("ascii"),
|
516
520
|
self.kv_mgr.agent.get_agent_metadata(),
|
@@ -425,7 +425,19 @@ class SchedulerDisaggregationPrefillMixin:
|
|
425
425
|
self.send_kv_chunk(req, last_chunk=True)
|
426
426
|
|
427
427
|
if req.grammar is not None:
|
428
|
-
|
428
|
+
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
429
|
+
try:
|
430
|
+
req.grammar.accept_token(next_token_id)
|
431
|
+
except ValueError as e:
|
432
|
+
# Grammar accept_token can raise ValueError if the token is not in the grammar.
|
433
|
+
# This can happen if the grammar is not set correctly or the token is invalid.
|
434
|
+
error_message = f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
|
435
|
+
self.tree_cache.cache_finished_req(req)
|
436
|
+
prepare_abort(
|
437
|
+
req,
|
438
|
+
error_message,
|
439
|
+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
440
|
+
)
|
429
441
|
req.grammar.finished = req.finished()
|
430
442
|
else:
|
431
443
|
# being chunked reqs' prefill is not finished
|
@@ -4,18 +4,18 @@ import ctypes
|
|
4
4
|
import logging
|
5
5
|
import os
|
6
6
|
from contextlib import contextmanager
|
7
|
-
from
|
8
|
-
from typing import Any, Callable, List, Optional, TypeVar, Union
|
7
|
+
from typing import Any, List, Optional, Union
|
9
8
|
|
10
9
|
import torch
|
11
10
|
import torch.distributed as dist
|
12
11
|
from torch.distributed import ProcessGroup
|
13
|
-
from typing_extensions import ParamSpec
|
14
12
|
|
15
13
|
from sglang.srt import _custom_ops as ops
|
16
14
|
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
17
15
|
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
|
18
16
|
gpu_p2p_access_check,
|
17
|
+
is_full_nvlink,
|
18
|
+
is_weak_contiguous,
|
19
19
|
)
|
20
20
|
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
21
21
|
from sglang.srt.utils import is_cuda, is_hip
|
@@ -25,23 +25,6 @@ logger = logging.getLogger(__name__)
|
|
25
25
|
_is_cuda = is_cuda()
|
26
26
|
_is_hip = is_hip()
|
27
27
|
|
28
|
-
if _is_cuda:
|
29
|
-
try:
|
30
|
-
import pynvml
|
31
|
-
except ImportError as e:
|
32
|
-
logger.warning("Failed to import pynvml with %r", e)
|
33
|
-
|
34
|
-
if _is_hip:
|
35
|
-
try:
|
36
|
-
from amdsmi import (
|
37
|
-
AmdSmiException,
|
38
|
-
amdsmi_get_processor_handles,
|
39
|
-
amdsmi_init,
|
40
|
-
amdsmi_shut_down,
|
41
|
-
amdsmi_topo_get_link_type,
|
42
|
-
)
|
43
|
-
except ImportError as e:
|
44
|
-
logger.warning("Failed to import amdsmi with %r", e)
|
45
28
|
|
46
29
|
try:
|
47
30
|
if ops.use_vllm_custom_allreduce and not _is_hip:
|
@@ -57,70 +40,6 @@ except Exception:
|
|
57
40
|
|
58
41
|
logger = logging.getLogger(__name__)
|
59
42
|
|
60
|
-
_P = ParamSpec("_P")
|
61
|
-
_R = TypeVar("_R")
|
62
|
-
|
63
|
-
|
64
|
-
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
65
|
-
@wraps(fn)
|
66
|
-
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
67
|
-
if _is_hip:
|
68
|
-
try:
|
69
|
-
amdsmi_init()
|
70
|
-
return fn(*args, **kwargs)
|
71
|
-
finally:
|
72
|
-
amdsmi_shut_down()
|
73
|
-
else:
|
74
|
-
pynvml.nvmlInit()
|
75
|
-
try:
|
76
|
-
return fn(*args, **kwargs)
|
77
|
-
finally:
|
78
|
-
pynvml.nvmlShutdown()
|
79
|
-
|
80
|
-
return wrapper
|
81
|
-
|
82
|
-
|
83
|
-
@with_nvml_context
|
84
|
-
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
85
|
-
if _is_hip:
|
86
|
-
"""
|
87
|
-
query if the set of gpus are fully connected by xgmi (1 hop)
|
88
|
-
"""
|
89
|
-
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
|
90
|
-
for i, handle in enumerate(handles):
|
91
|
-
for j, peer_handle in enumerate(handles):
|
92
|
-
if i < j:
|
93
|
-
try:
|
94
|
-
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
|
95
|
-
# type is 2 for XGMI
|
96
|
-
if link_type["hops"] != 1 or link_type["type"] != 2:
|
97
|
-
return False
|
98
|
-
except AmdSmiException as error:
|
99
|
-
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
|
100
|
-
return False
|
101
|
-
return True
|
102
|
-
else:
|
103
|
-
"""
|
104
|
-
query if the set of gpus are fully connected by nvlink (1 hop)
|
105
|
-
"""
|
106
|
-
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
|
107
|
-
for i, handle in enumerate(handles):
|
108
|
-
for j, peer_handle in enumerate(handles):
|
109
|
-
if i < j:
|
110
|
-
try:
|
111
|
-
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
112
|
-
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
|
113
|
-
)
|
114
|
-
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
115
|
-
return False
|
116
|
-
except pynvml.NVMLError:
|
117
|
-
logger.exception(
|
118
|
-
"NVLink detection failed. This is normal if your"
|
119
|
-
" machine has no NVLink equipped."
|
120
|
-
)
|
121
|
-
return False
|
122
|
-
return True
|
123
|
-
|
124
43
|
|
125
44
|
def _can_p2p(rank: int, world_size: int) -> bool:
|
126
45
|
# SGLANG_SKIP_P2P_CHECK can be set to False in sglang
|
@@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool:
|
|
136
55
|
return True
|
137
56
|
|
138
57
|
|
139
|
-
def is_weak_contiguous(inp: torch.Tensor):
|
140
|
-
return inp.is_contiguous() or (
|
141
|
-
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
142
|
-
== inp.numel() * inp.element_size()
|
143
|
-
)
|
144
|
-
|
145
|
-
|
146
58
|
class CustomAllreduce:
|
147
59
|
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
148
60
|
_MAX_CAR_SIZE = 8192 * 1024
|