sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import (
23
23
  )
24
24
  from sglang.srt.disaggregation.utils import DisaggregationMode
25
25
  from sglang.srt.server_args import ServerArgs
26
- from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
26
+ from sglang.srt.utils import (
27
+ format_tcp_address,
28
+ get_free_port,
29
+ get_ip,
30
+ get_local_ip_by_remote,
31
+ is_valid_ipv6_address,
32
+ maybe_wrap_ipv6_address,
33
+ )
27
34
 
28
35
  logger = logging.getLogger(__name__)
29
36
 
@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager):
65
72
  def _register_to_bootstrap(self):
66
73
  """Register KVSender to bootstrap server via HTTP POST."""
67
74
  if self.dist_init_addr:
68
- ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
75
+ if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
76
+ if self.dist_init_addr.endswith("]"):
77
+ host = self.dist_init_addr
78
+ else:
79
+ host, _ = self.dist_init_addr.rsplit(":", 1)
80
+ else:
81
+ host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
69
82
  else:
70
- ip_address = get_ip()
83
+ host = get_ip()
84
+ host = maybe_wrap_ipv6_address(host)
71
85
 
72
- bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
86
+ bootstrap_server_url = f"{host}:{self.bootstrap_port}"
73
87
  url = f"http://{bootstrap_server_url}/route"
74
88
  payload = {
75
89
  "role": "Prefill",
@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager):
92
106
  logger.error(f"Prefill Failed to register to bootstrap server: {e}")
93
107
 
94
108
  @cache
95
- def _connect(self, endpoint: str):
109
+ def _connect(self, endpoint: str, is_ipv6: bool = False):
96
110
  socket = zmq.Context().socket(zmq.PUSH)
111
+ if is_ipv6:
112
+ socket.setsockopt(zmq.IPV6, 1)
97
113
  socket.connect(endpoint)
98
114
  return socket
99
115
 
@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver):
263
279
  return None
264
280
 
265
281
  @classmethod
266
- def _connect(cls, endpoint: str):
282
+ def _connect(cls, endpoint: str, is_ipv6: bool = False):
267
283
  with cls._global_lock:
268
284
  if endpoint not in cls._socket_cache:
269
285
  sock = cls._ctx.socket(zmq.PUSH)
286
+ if is_ipv6:
287
+ sock.setsockopt(zmq.IPV6, 1)
270
288
  sock.connect(endpoint)
271
289
  cls._socket_cache[endpoint] = sock
272
290
  cls._socket_locks[endpoint] = threading.Lock()
273
291
  return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
274
292
 
293
+ @classmethod
294
+ def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
295
+ ip_address = bootstrap_info["rank_ip"]
296
+ port = bootstrap_info["rank_port"]
297
+ is_ipv6_address = is_valid_ipv6_address(ip_address)
298
+ sock, lock = cls._connect(
299
+ format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
300
+ )
301
+ return sock, lock
302
+
275
303
  def _register_kv_args(self):
276
304
  pass
277
305
 
@@ -1,10 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
+ from http import HTTPStatus
4
5
  from typing import TYPE_CHECKING
5
6
 
6
7
  import torch
7
8
 
9
+ from sglang.srt.disaggregation.utils import prepare_abort
8
10
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
9
11
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
10
12
 
@@ -102,7 +104,17 @@ class ScheduleBatchDisaggregationDecodeMixin:
102
104
  self.output_ids.append(req.output_ids[-1])
103
105
  self.tree_cache.cache_unfinished_req(req)
104
106
  if req.grammar is not None:
105
- req.grammar.accept_token(req.output_ids[-1])
107
+ # FIXME: this try-except block is for handling unexpected xgrammar issue.
108
+ try:
109
+ req.grammar.accept_token(req.output_ids[-1])
110
+ except ValueError as e:
111
+ # Grammar accept_token can raise ValueError if the token is not in the grammar.
112
+ # This can happen if the grammar is not set correctly or the token is invalid.
113
+ error_message = f"Grammar accept_token failed for req {req.rid} with token {req.output_ids[-1]}: {e}"
114
+ self.tree_cache.cache_finished_req(req)
115
+ prepare_abort(
116
+ req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
117
+ )
106
118
  req.grammar.finished = req.finished()
107
119
  self.output_ids = torch.tensor(self.output_ids, device=self.device)
108
120
 
@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException
17
17
  from fastapi.responses import ORJSONResponse, Response, StreamingResponse
18
18
 
19
19
  from sglang.srt.disaggregation.utils import PDRegistryRequest
20
+ from sglang.srt.utils import maybe_wrap_ipv6_address
20
21
 
21
22
  AIOHTTP_STREAM_READ_CHUNK_SIZE = (
22
23
  1024 * 64
@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):
271
272
 
272
273
  # Parse and transform prefill_server for bootstrap data
273
274
  parsed_url = urllib.parse.urlparse(prefill_server)
274
- hostname = parsed_url.hostname
275
+ hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
275
276
  modified_request = request_data.copy()
276
277
 
277
278
  batch_size = _get_request_batch_size(modified_request)
@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):
309
310
 
310
311
  # Parse and transform prefill_server for bootstrap data
311
312
  parsed_url = urllib.parse.urlparse(prefill_server)
312
- hostname = parsed_url.hostname
313
+ hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
313
314
  modified_request = request_data.copy()
314
315
  modified_request.update(
315
316
  {
@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import (
35
35
  from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
36
36
  from sglang.srt.disaggregation.utils import DisaggregationMode
37
37
  from sglang.srt.server_args import ServerArgs
38
- from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
38
+ from sglang.srt.utils import (
39
+ format_tcp_address,
40
+ get_free_port,
41
+ get_int_env_var,
42
+ get_ip,
43
+ get_local_ip_auto,
44
+ is_valid_ipv6_address,
45
+ maybe_wrap_ipv6_address,
46
+ )
39
47
 
40
48
  logger = logging.getLogger(__name__)
41
49
 
@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager):
148
156
  self.request_status: Dict[int, KVPoll] = {}
149
157
  self.rank_port = None
150
158
  self.server_socket = zmq.Context().socket(zmq.PULL)
159
+ if is_valid_ipv6_address(self.local_ip):
160
+ self.server_socket.setsockopt(zmq.IPV6, 1)
161
+
151
162
  self.register_buffer_to_engine()
152
163
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
153
164
  self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager):
240
251
  self.engine.register(aux_data_ptr, aux_data_len)
241
252
 
242
253
  @cache
243
- def _connect(self, endpoint: str):
254
+ def _connect(self, endpoint: str, is_ipv6: bool = False):
244
255
  socket = zmq.Context().socket(zmq.PUSH)
256
+ if is_ipv6:
257
+ socket.setsockopt(zmq.IPV6, 1)
245
258
  socket.connect(endpoint)
246
259
  return socket
247
260
 
@@ -471,9 +484,9 @@ class MooncakeKVManager(BaseKVManager):
471
484
  def sync_status_to_decode_endpoint(
472
485
  self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
473
486
  ):
474
- if ":" in remote:
475
- remote = remote.split(":")[0]
476
- self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
487
+ self._connect(
488
+ format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
489
+ ).send_multipart(
477
490
  [
478
491
  str(room).encode("ascii"),
479
492
  str(status).encode("ascii"),
@@ -616,9 +629,12 @@ class MooncakeKVManager(BaseKVManager):
616
629
  f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
617
630
  )
618
631
 
632
+ def _bind_server_socket(self):
633
+ self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
634
+
619
635
  def start_prefill_thread(self):
620
636
  self.rank_port = get_free_port()
621
- self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
637
+ self._bind_server_socket()
622
638
 
623
639
  def bootstrap_thread():
624
640
  """This thread recvs pre-alloc notification from the decode engine"""
@@ -657,7 +673,7 @@ class MooncakeKVManager(BaseKVManager):
657
673
 
658
674
  def start_decode_thread(self):
659
675
  self.rank_port = get_free_port()
660
- self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
676
+ self._bind_server_socket()
661
677
 
662
678
  def decode_thread():
663
679
  while True:
@@ -776,7 +792,7 @@ class MooncakeKVManager(BaseKVManager):
776
792
  # requests with the same dst_sessions will be added into the same
777
793
  # queue, which enables early abort with failed sessions.
778
794
  dst_infos = self.transfer_infos[bootstrap_room].keys()
779
- session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
795
+ session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
780
796
  shard_idx = session_port_sum % len(self.transfer_queues)
781
797
 
782
798
  self.transfer_queues[shard_idx].put(
@@ -814,11 +830,18 @@ class MooncakeKVManager(BaseKVManager):
814
830
  def _register_to_bootstrap(self):
815
831
  """Register KVSender to bootstrap server via HTTP POST."""
816
832
  if self.dist_init_addr:
817
- ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
833
+ if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
834
+ if self.dist_init_addr.endswith("]"):
835
+ host = self.dist_init_addr
836
+ else:
837
+ host, _ = self.dist_init_addr.rsplit(":", 1)
838
+ else:
839
+ host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
818
840
  else:
819
- ip_address = get_ip()
841
+ host = get_ip()
842
+ host = maybe_wrap_ipv6_address(host)
820
843
 
821
- bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
844
+ bootstrap_server_url = f"{host}:{self.bootstrap_port}"
822
845
  url = f"http://{bootstrap_server_url}/route"
823
846
  payload = {
824
847
  "role": "Prefill",
@@ -969,6 +992,14 @@ class MooncakeKVSender(BaseKVSender):
969
992
  )
970
993
  raise KVTransferError(self.bootstrap_room, failure_reason)
971
994
 
995
+ def abort(self):
996
+ self.kv_mgr.record_failure(
997
+ self.bootstrap_room,
998
+ "Aborted by AbortReq.",
999
+ )
1000
+ # Explicitly set the status to failure since this request has been aborted
1001
+ self.conclude_state = KVPoll.Failed
1002
+
972
1003
 
973
1004
  class MooncakeKVReceiver(BaseKVReceiver):
974
1005
  _ctx = zmq.Context()
@@ -1163,9 +1194,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
1163
1194
 
1164
1195
  def _register_kv_args(self):
1165
1196
  for bootstrap_info in self.bootstrap_infos:
1166
- self.prefill_server_url = (
1167
- f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
1168
- )
1169
1197
  packed_kv_data_ptrs = b"".join(
1170
1198
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
1171
1199
  )
@@ -1179,7 +1207,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
1179
1207
  dst_tp_size = str(tp_size).encode("ascii")
1180
1208
  dst_kv_item_len = str(kv_item_len).encode("ascii")
1181
1209
 
1182
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
1210
+ sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
1183
1211
  with lock:
1184
1212
  sock.send_multipart(
1185
1213
  [
@@ -1196,23 +1224,32 @@ class MooncakeKVReceiver(BaseKVReceiver):
1196
1224
  )
1197
1225
 
1198
1226
  @classmethod
1199
- def _connect(cls, endpoint: str):
1227
+ def _connect(cls, endpoint: str, is_ipv6: bool = False):
1200
1228
  with cls._global_lock:
1201
1229
  if endpoint not in cls._socket_cache:
1202
1230
  sock = cls._ctx.socket(zmq.PUSH)
1231
+ if is_ipv6:
1232
+ sock.setsockopt(zmq.IPV6, 1)
1203
1233
  sock.connect(endpoint)
1204
1234
  cls._socket_cache[endpoint] = sock
1205
1235
  cls._socket_locks[endpoint] = threading.Lock()
1206
1236
  return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
1207
1237
 
1238
+ @classmethod
1239
+ def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
1240
+ ip_address = bootstrap_info["rank_ip"]
1241
+ port = bootstrap_info["rank_port"]
1242
+ is_ipv6_address = is_valid_ipv6_address(ip_address)
1243
+ sock, lock = cls._connect(
1244
+ format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
1245
+ )
1246
+ return sock, lock
1247
+
1208
1248
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
1209
1249
  for bootstrap_info in self.bootstrap_infos:
1210
- self.prefill_server_url = (
1211
- f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
1212
- )
1250
+ sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
1213
1251
  is_dummy = bootstrap_info["is_dummy"]
1214
1252
 
1215
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
1216
1253
  with lock:
1217
1254
  sock.send_multipart(
1218
1255
  [
@@ -1276,6 +1313,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
1276
1313
  )
1277
1314
  raise KVTransferError(self.bootstrap_room, failure_reason)
1278
1315
 
1316
+ def abort(self):
1317
+ self.kv_mgr.record_failure(
1318
+ self.bootstrap_room,
1319
+ "Aborted by AbortReq.",
1320
+ )
1321
+ # Explicitly set the status to failure since this request has been aborted
1322
+ self.conclude_state = KVPoll.Failed
1323
+
1279
1324
 
1280
1325
  class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1281
1326
  def __init__(self, port: int):
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from typing import List, Optional
3
3
 
4
- from sglang.srt.utils import get_bool_env_var, get_free_port
4
+ from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
5
5
 
6
6
  logger = logging.getLogger(__name__)
7
7
 
@@ -27,7 +27,9 @@ class MooncakeTransferEngine:
27
27
  hostname=self.hostname,
28
28
  device_name=self.ib_device,
29
29
  )
30
- self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
30
+ self.session_id = (
31
+ f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
32
+ )
31
33
 
32
34
  def register(self, ptr, length):
33
35
  try:
@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import (
27
27
  from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
28
28
  from sglang.srt.disaggregation.utils import DisaggregationMode
29
29
  from sglang.srt.server_args import ServerArgs
30
- from sglang.srt.utils import get_local_ip_by_remote
30
+ from sglang.srt.utils import (
31
+ format_tcp_address,
32
+ get_local_ip_auto,
33
+ is_valid_ipv6_address,
34
+ )
31
35
 
32
36
  logger = logging.getLogger(__name__)
33
37
 
@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager):
124
128
  "to run SGLang with NixlTransferEngine."
125
129
  ) from e
126
130
  self.agent = nixl_agent(str(uuid.uuid4()))
131
+ self.local_ip = get_local_ip_auto()
127
132
  self.server_socket = zmq.Context().socket(zmq.PULL)
133
+ if is_valid_ipv6_address(self.local_ip):
134
+ self.server_socket.setsockopt(zmq.IPV6, 1)
128
135
  self.register_buffer_to_engine()
129
136
 
130
137
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager):
337
344
  return False
338
345
  return self.transfer_statuses[room].is_done()
339
346
 
347
+ def _bind_server_socket(self):
348
+ self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
349
+
340
350
  def _start_bootstrap_thread(self):
341
- self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
351
+ self._bind_server_socket()
342
352
 
343
353
  def bootstrap_thread():
344
354
  """This thread recvs transfer info from the decode engine"""
@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver):
452
462
 
453
463
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
454
464
  for bootstrap_info in self.bootstrap_infos:
455
- self.prefill_server_url = (
456
- f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
457
- )
458
465
  logger.debug(
459
466
  f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
460
467
  )
468
+ sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
461
469
  is_dummy = bootstrap_info["is_dummy"]
462
470
  logger.debug(
463
- f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
471
+ f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
464
472
  )
465
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
466
473
  with lock:
467
474
  sock.send_multipart(
468
475
  [
469
476
  GUARD,
470
477
  str(self.bootstrap_room).encode("ascii"),
471
- get_local_ip_by_remote().encode("ascii"),
478
+ self.kv_mgr.local_ip.encode("ascii"),
472
479
  str(self.kv_mgr.rank_port).encode("ascii"),
473
480
  self.kv_mgr.agent.name.encode("ascii"),
474
481
  kv_indices.tobytes() if not is_dummy else b"",
@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver):
494
501
 
495
502
  def _register_kv_args(self):
496
503
  for bootstrap_info in self.bootstrap_infos:
497
- self.prefill_server_url = (
498
- f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
499
- )
504
+ sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
500
505
  packed_kv_data_ptrs = b"".join(
501
506
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
502
507
  )
@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver):
504
509
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
505
510
  )
506
511
 
507
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
508
512
  with lock:
509
513
  sock.send_multipart(
510
514
  [
511
515
  GUARD,
512
516
  "None".encode("ascii"),
513
- get_local_ip_by_remote().encode("ascii"),
517
+ self.kv_mgr.local_ip.encode("ascii"),
514
518
  str(self.kv_mgr.rank_port).encode("ascii"),
515
519
  self.kv_mgr.agent.name.encode("ascii"),
516
520
  self.kv_mgr.agent.get_agent_metadata(),
@@ -425,7 +425,19 @@ class SchedulerDisaggregationPrefillMixin:
425
425
  self.send_kv_chunk(req, last_chunk=True)
426
426
 
427
427
  if req.grammar is not None:
428
- req.grammar.accept_token(next_token_id)
428
+ # FIXME: this try-except block is for handling unexpected xgrammar issue.
429
+ try:
430
+ req.grammar.accept_token(next_token_id)
431
+ except ValueError as e:
432
+ # Grammar accept_token can raise ValueError if the token is not in the grammar.
433
+ # This can happen if the grammar is not set correctly or the token is invalid.
434
+ error_message = f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
435
+ self.tree_cache.cache_finished_req(req)
436
+ prepare_abort(
437
+ req,
438
+ error_message,
439
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
440
+ )
429
441
  req.grammar.finished = req.finished()
430
442
  else:
431
443
  # being chunked reqs' prefill is not finished
@@ -4,18 +4,18 @@ import ctypes
4
4
  import logging
5
5
  import os
6
6
  from contextlib import contextmanager
7
- from functools import wraps
8
- from typing import Any, Callable, List, Optional, TypeVar, Union
7
+ from typing import Any, List, Optional, Union
9
8
 
10
9
  import torch
11
10
  import torch.distributed as dist
12
11
  from torch.distributed import ProcessGroup
13
- from typing_extensions import ParamSpec
14
12
 
15
13
  from sglang.srt import _custom_ops as ops
16
14
  from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
17
15
  from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
18
16
  gpu_p2p_access_check,
17
+ is_full_nvlink,
18
+ is_weak_contiguous,
19
19
  )
20
20
  from sglang.srt.distributed.parallel_state import in_the_same_node_as
21
21
  from sglang.srt.utils import is_cuda, is_hip
@@ -25,23 +25,6 @@ logger = logging.getLogger(__name__)
25
25
  _is_cuda = is_cuda()
26
26
  _is_hip = is_hip()
27
27
 
28
- if _is_cuda:
29
- try:
30
- import pynvml
31
- except ImportError as e:
32
- logger.warning("Failed to import pynvml with %r", e)
33
-
34
- if _is_hip:
35
- try:
36
- from amdsmi import (
37
- AmdSmiException,
38
- amdsmi_get_processor_handles,
39
- amdsmi_init,
40
- amdsmi_shut_down,
41
- amdsmi_topo_get_link_type,
42
- )
43
- except ImportError as e:
44
- logger.warning("Failed to import amdsmi with %r", e)
45
28
 
46
29
  try:
47
30
  if ops.use_vllm_custom_allreduce and not _is_hip:
@@ -57,70 +40,6 @@ except Exception:
57
40
 
58
41
  logger = logging.getLogger(__name__)
59
42
 
60
- _P = ParamSpec("_P")
61
- _R = TypeVar("_R")
62
-
63
-
64
- def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
65
- @wraps(fn)
66
- def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
67
- if _is_hip:
68
- try:
69
- amdsmi_init()
70
- return fn(*args, **kwargs)
71
- finally:
72
- amdsmi_shut_down()
73
- else:
74
- pynvml.nvmlInit()
75
- try:
76
- return fn(*args, **kwargs)
77
- finally:
78
- pynvml.nvmlShutdown()
79
-
80
- return wrapper
81
-
82
-
83
- @with_nvml_context
84
- def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
85
- if _is_hip:
86
- """
87
- query if the set of gpus are fully connected by xgmi (1 hop)
88
- """
89
- handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
90
- for i, handle in enumerate(handles):
91
- for j, peer_handle in enumerate(handles):
92
- if i < j:
93
- try:
94
- link_type = amdsmi_topo_get_link_type(handle, peer_handle)
95
- # type is 2 for XGMI
96
- if link_type["hops"] != 1 or link_type["type"] != 2:
97
- return False
98
- except AmdSmiException as error:
99
- logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
100
- return False
101
- return True
102
- else:
103
- """
104
- query if the set of gpus are fully connected by nvlink (1 hop)
105
- """
106
- handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
107
- for i, handle in enumerate(handles):
108
- for j, peer_handle in enumerate(handles):
109
- if i < j:
110
- try:
111
- p2p_status = pynvml.nvmlDeviceGetP2PStatus(
112
- handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
113
- )
114
- if p2p_status != pynvml.NVML_P2P_STATUS_OK:
115
- return False
116
- except pynvml.NVMLError:
117
- logger.exception(
118
- "NVLink detection failed. This is normal if your"
119
- " machine has no NVLink equipped."
120
- )
121
- return False
122
- return True
123
-
124
43
 
125
44
  def _can_p2p(rank: int, world_size: int) -> bool:
126
45
  # SGLANG_SKIP_P2P_CHECK can be set to False in sglang
@@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool:
136
55
  return True
137
56
 
138
57
 
139
- def is_weak_contiguous(inp: torch.Tensor):
140
- return inp.is_contiguous() or (
141
- inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
142
- == inp.numel() * inp.element_size()
143
- )
144
-
145
-
146
58
  class CustomAllreduce:
147
59
  _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
148
60
  _MAX_CAR_SIZE = 8192 * 1024