sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90):
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -10,7 +10,7 @@ import threading
10
10
  import uuid
11
11
  from collections import defaultdict
12
12
  from functools import cache
13
- from typing import Dict, List, Optional, Tuple, Union
13
+ from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
14
14
 
15
15
  import numpy as np
16
16
  import numpy.typing as npt
@@ -32,6 +32,38 @@ from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
32
32
 
33
33
  logger = logging.getLogger(__name__)
34
34
 
35
+ NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
36
+
37
+
38
+ # From Mooncake backend.
39
+ def group_concurrent_contiguous(
40
+ src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
41
+ ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
42
+ src_groups = []
43
+ dst_groups = []
44
+ current_src = [src_indices[0]]
45
+ current_dst = [dst_indices[0]]
46
+
47
+ for i in range(1, len(src_indices)):
48
+ src_contiguous = src_indices[i] == src_indices[i - 1] + 1
49
+ dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
50
+ if src_contiguous and dst_contiguous:
51
+ current_src.append(src_indices[i])
52
+ current_dst.append(dst_indices[i])
53
+ else:
54
+ src_groups.append(current_src)
55
+ dst_groups.append(current_dst)
56
+ current_src = [src_indices[i]]
57
+ current_dst = [dst_indices[i]]
58
+
59
+ src_groups.append(current_src)
60
+ dst_groups.append(current_dst)
61
+
62
+ return src_groups, dst_groups
63
+
64
+
65
+ GUARD = "NixlMsgGuard".encode("ascii")
66
+
35
67
 
36
68
  @dataclasses.dataclass
37
69
  class TransferInfo:
@@ -45,19 +77,36 @@ class TransferInfo:
45
77
  dst_aux_index: int
46
78
  dst_gpu_id: int
47
79
 
80
+ def is_dummy(self):
81
+ return self.endpoint == ""
82
+
48
83
  @classmethod
49
84
  def from_zmq(cls, msg: List[bytes]):
50
- return cls(
51
- room=int(msg[0].decode("ascii")),
52
- endpoint=msg[1].decode("ascii"),
53
- dst_port=int(msg[2].decode("ascii")),
54
- agent_metadata=msg[3],
55
- dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
56
- dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
57
- dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
58
- dst_aux_index=int(msg[7].decode("ascii")),
59
- dst_gpu_id=int(msg[8].decode("ascii")),
60
- )
85
+ if len(msg) == 1:
86
+ # dummy msg
87
+ return cls(
88
+ room=int(msg[0].decode("ascii")),
89
+ endpoint="",
90
+ dst_port=0,
91
+ agent_metadata=b"",
92
+ dst_kv_ptrs=[],
93
+ dst_kv_indices=np.array([], dtype=np.int64),
94
+ dst_aux_ptrs=[],
95
+ dst_aux_index=0,
96
+ dst_gpu_id=0,
97
+ )
98
+ else:
99
+ return cls(
100
+ room=int(msg[0].decode("ascii")),
101
+ endpoint=msg[1].decode("ascii"),
102
+ dst_port=int(msg[2].decode("ascii")),
103
+ agent_metadata=msg[3],
104
+ dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
105
+ dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
106
+ dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
107
+ dst_aux_index=int(msg[7].decode("ascii")),
108
+ dst_gpu_id=int(msg[8].decode("ascii")),
109
+ )
61
110
 
62
111
 
63
112
  @dataclasses.dataclass
@@ -98,6 +147,19 @@ class NixlKVManager(BaseKVManager):
98
147
  # for p/d multi node infer
99
148
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
100
149
  self.dist_init_addr = server_args.dist_init_addr
150
+ self.tp_size = server_args.tp_size
151
+
152
+ self.tp_rank = args.engine_rank
153
+ self.enable_dp_attention = server_args.enable_dp_attention
154
+ if self.enable_dp_attention:
155
+ assert (
156
+ server_args.dp_size > 1
157
+ ), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
158
+ self.dp_size = server_args.dp_size
159
+ self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
160
+ self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
161
+ self.dp_rank = args.engine_rank // self.tp_size_of_dp
162
+
101
163
  self.rank_port = None
102
164
  self.server_socket = zmq.Context().socket(zmq.PULL)
103
165
  self.register_buffer_to_engine()
@@ -110,7 +172,8 @@ class NixlKVManager(BaseKVManager):
110
172
  self._start_bootstrap_thread()
111
173
  self._register_to_bootstrap()
112
174
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
113
- self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
175
+ # bootstrap key -> (remote_engine_rank -> possible remote source info)
176
+ self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {}
114
177
  self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
115
178
  TransferStatus
116
179
  )
@@ -126,6 +189,7 @@ class NixlKVManager(BaseKVManager):
126
189
  ):
127
190
  kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
128
191
  self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
192
+ logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
129
193
  if not self.kv_descs:
130
194
  raise Exception("NIXL memory registration failed for kv tensors")
131
195
  aux_addrs = []
@@ -134,6 +198,7 @@ class NixlKVManager(BaseKVManager):
134
198
  ):
135
199
  aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
136
200
  self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
201
+ logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
137
202
  if not self.aux_descs:
138
203
  raise Exception("NIXL memory registration failed for aux tensors")
139
204
 
@@ -157,6 +222,12 @@ class NixlKVManager(BaseKVManager):
157
222
  dst_gpu_id: int,
158
223
  notif: str,
159
224
  ):
225
+ # group by indices
226
+ prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
227
+ prefill_kv_indices, dst_kv_indices
228
+ )
229
+
230
+ logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
160
231
  # Make descs
161
232
  num_layers = len(self.kv_args.kv_data_ptrs)
162
233
  src_addrs = []
@@ -166,12 +237,16 @@ class NixlKVManager(BaseKVManager):
166
237
  dst_ptr = dst_kv_ptrs[layer_id]
167
238
  item_len = self.kv_args.kv_item_lens[layer_id]
168
239
 
169
- for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
170
- src_addr = src_ptr + int(prefill_index) * item_len
171
- dst_addr = dst_ptr + int(decode_index) * item_len
172
- length = item_len
240
+ for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
241
+ src_addr = src_ptr + int(prefill_index[0]) * item_len
242
+ dst_addr = dst_ptr + int(decode_index[0]) * item_len
243
+ length = item_len * len(prefill_index)
173
244
  src_addrs.append((src_addr, length, self.kv_args.gpu_id))
174
245
  dst_addrs.append((dst_addr, length, dst_gpu_id))
246
+
247
+ logger.debug(
248
+ f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
249
+ )
175
250
  src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
176
251
  dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
177
252
  # Transfer data
@@ -180,7 +255,7 @@ class NixlKVManager(BaseKVManager):
180
255
  src_descs,
181
256
  dst_descs,
182
257
  peer_name,
183
- notif.encode("ascii"),
258
+ notif.encode("ascii"), # type: ignore
184
259
  )
185
260
  if not xfer_handle:
186
261
  raise Exception("KVSender failed to create transfer")
@@ -213,7 +288,7 @@ class NixlKVManager(BaseKVManager):
213
288
  src_descs,
214
289
  dst_descs,
215
290
  peer_name,
216
- notif.encode("ascii"),
291
+ notif.encode("ascii"), # type: ignore
217
292
  )
218
293
  if not xfer_handle:
219
294
  raise Exception("KVSender failed to create transfer")
@@ -240,6 +315,9 @@ class NixlKVManager(BaseKVManager):
240
315
  req = self.transfer_infos[bootstrap_room]
241
316
  assert bootstrap_room == req.room
242
317
 
318
+ if req.is_dummy():
319
+ return []
320
+
243
321
  peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
244
322
  chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
245
323
  assert len(chunked_dst_kv_indice) == len(kv_indices)
@@ -256,6 +334,7 @@ class NixlKVManager(BaseKVManager):
256
334
  handles = [kv_xfer_handle]
257
335
  # Only the last chunk we need to send the aux data.
258
336
  if is_last:
337
+ assert aux_index is not None
259
338
  aux_xfer_handle = self.send_aux(
260
339
  peer_name,
261
340
  aux_index,
@@ -325,6 +404,13 @@ class NixlKVManager(BaseKVManager):
325
404
  """This thread recvs transfer info from the decode engine"""
326
405
  while True:
327
406
  waiting_req_bytes = self.server_socket.recv_multipart()
407
+ logger.debug(
408
+ f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}"
409
+ )
410
+ assert (
411
+ waiting_req_bytes[0] == GUARD
412
+ ), f"First message should be {GUARD}. Foreign traffic?"
413
+ waiting_req_bytes = waiting_req_bytes[1:]
328
414
  room = waiting_req_bytes[0].decode("ascii")
329
415
  if room == "None":
330
416
  continue
@@ -372,14 +458,13 @@ class NixlKVSender(BaseKVSender):
372
458
 
373
459
  def poll(self) -> KVPoll:
374
460
  if not self.has_sent:
375
- return KVPoll.WaitingForInput
376
-
461
+ return KVPoll.WaitingForInput # type: ignore
377
462
  states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
378
463
  if all([x == "DONE" for x in states]):
379
- return KVPoll.Success
464
+ return KVPoll.Success # type: ignore
380
465
  if any([x == "ERR" for x in states]):
381
466
  raise Exception("KVSender transfer encountered an error.")
382
- return KVPoll.WaitingForInput
467
+ return KVPoll.WaitingForInput # type: ignore
383
468
 
384
469
  def failure_exception(self):
385
470
  raise Exception("Fake KVSender Exception")
@@ -401,7 +486,7 @@ class NixlKVReceiver(BaseKVReceiver):
401
486
  # NOTE: key distinguished by bootstrap_addr and engine_rank
402
487
  bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
403
488
 
404
- if bootstrap_key not in self.kv_mgr.connection_pool:
489
+ if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
405
490
  self.bootstrap_info = self._get_bootstrap_info_from_server(
406
491
  self.kv_mgr.kv_args.engine_rank
407
492
  )
@@ -410,25 +495,79 @@ class NixlKVReceiver(BaseKVReceiver):
410
495
  f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
411
496
  )
412
497
  else:
413
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
498
+ self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
414
499
  else:
415
- self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
416
-
500
+ self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]
417
501
  assert self.bootstrap_info is not None
418
502
 
419
- def _get_bootstrap_info_from_server(self, engine_rank):
503
+ # return a list of remotes in a dict, [(remote_engine_rank -> NixlEngineInfo), ...]
504
+ # In each dict, there are multiple possible remotes named "equal sources".
505
+ # We only need to select one to split the traffic. i.e. we totally select len(list) remotes.
506
+ def _get_bootstrap_info_from_server(
507
+ self, engine_rank
508
+ ) -> Optional[List[Dict[int, NixlEngineInfo]]]:
420
509
  """Fetch the bootstrap info from the bootstrap server."""
421
510
  try:
422
- url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
423
- response = requests.get(url)
424
- if response.status_code == 200:
511
+ if self.kv_mgr.enable_dp_attention:
512
+ url = f"http://{self.bootstrap_addr}/route"
513
+ response = requests.get(url)
514
+ if response.status_code != 200:
515
+ logger.error(
516
+ f"Failed to get prefill server info: {response.status_code}, {response.text}"
517
+ )
518
+ return None
519
+
425
520
  bootstrap_info = response.json()
426
- return bootstrap_info
427
- else:
428
- logger.error(
429
- f"Failed to get prefill server info: {response.status_code}, {response.text}"
521
+ assert isinstance(bootstrap_info, dict)
522
+ bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}
523
+
524
+ # split out who need to send to this rank.
525
+ # currently for dpsk mla model, those ranks share the same latent cache.
526
+ # pick one as the real source
527
+
528
+ prefill_tp_size = len(bootstrap_info.keys())
529
+
530
+ assert (
531
+ prefill_tp_size >= self.kv_mgr.tp_size_of_dp
532
+ ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"
533
+
534
+ num_remote_tp_rank_we_managed = (
535
+ prefill_tp_size // self.kv_mgr.tp_size_of_dp
536
+ )
537
+
538
+ # We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num)
539
+ remote_tp_ranks = list(range(0, prefill_tp_size))
540
+ # split it into tp_size_of_dp parts and get our part
541
+ remote_tp_ranks_grouped = [
542
+ remote_tp_ranks[i : i + num_remote_tp_rank_we_managed]
543
+ for i in range(0, prefill_tp_size, self.kv_mgr.tp_size_of_dp)
544
+ ]
545
+ managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank]
546
+
547
+ assert len(managed_ranks) == num_remote_tp_rank_we_managed
548
+
549
+ logger.debug(
550
+ f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
430
551
  )
431
- return None
552
+
553
+ return [
554
+ {
555
+ rk: bootstrap_info[rk]
556
+ for rk in bootstrap_info.keys()
557
+ if rk in managed_ranks
558
+ }
559
+ ]
560
+ else:
561
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
562
+ response = requests.get(url)
563
+ if response.status_code == 200:
564
+ bootstrap_info = response.json()
565
+ return [{engine_rank: bootstrap_info}]
566
+ else:
567
+ logger.error(
568
+ f"Failed to get prefill server info: {response.status_code}, {response.text}"
569
+ )
570
+ return None
432
571
  except Exception as e:
433
572
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
434
573
  return None
@@ -440,43 +579,67 @@ class NixlKVReceiver(BaseKVReceiver):
440
579
  return socket
441
580
 
442
581
  def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
443
- self.prefill_server_url = (
444
- f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
445
- )
446
- logger.debug(
447
- f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
448
- )
449
582
 
450
- packed_kv_data_ptrs = b"".join(
451
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
452
- )
453
- packed_aux_data_ptrs = b"".join(
454
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
455
- )
456
- self._connect("tcp://" + self.prefill_server_url).send_multipart(
457
- [
458
- str(self.bootstrap_room).encode("ascii"),
459
- get_local_ip_by_remote().encode("ascii"),
460
- str(self.kv_mgr.rank_port).encode("ascii"),
461
- self.kv_mgr.agent.get_agent_metadata(),
462
- packed_kv_data_ptrs,
463
- kv_indices.tobytes(),
464
- packed_aux_data_ptrs,
465
- str(aux_index).encode("ascii"),
466
- str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
583
+ assert self.bootstrap_info is not None
584
+ assert self.bootstrap_room is not None
585
+
586
+ for equal_sources in self.bootstrap_info:
587
+ remote_rank = list(equal_sources.keys())[
588
+ self.bootstrap_room % len(equal_sources)
467
589
  ]
468
- )
590
+ self.prefill_server_url = f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}"
591
+ logger.debug(
592
+ f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}"
593
+ )
594
+
595
+ packed_kv_data_ptrs = b"".join(
596
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
597
+ )
598
+ packed_aux_data_ptrs = b"".join(
599
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
600
+ )
601
+
602
+ logger.debug(
603
+ f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
604
+ )
605
+ self._connect("tcp://" + self.prefill_server_url).send_multipart(
606
+ [
607
+ GUARD,
608
+ str(self.bootstrap_room).encode("ascii"),
609
+ get_local_ip_by_remote().encode("ascii"),
610
+ str(self.kv_mgr.rank_port).encode("ascii"),
611
+ self.kv_mgr.agent.get_agent_metadata(),
612
+ packed_kv_data_ptrs,
613
+ kv_indices.tobytes(),
614
+ packed_aux_data_ptrs,
615
+ str(aux_index).encode("ascii"),
616
+ str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
617
+ ]
618
+ )
619
+
620
+ for dummy_rank in equal_sources.keys():
621
+ if dummy_rank == remote_rank:
622
+ continue
623
+ dummy_info = equal_sources[dummy_rank]
624
+ dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
625
+ self._connect("tcp://" + dummy_url).send_multipart(
626
+ [
627
+ GUARD,
628
+ str(self.bootstrap_room).encode("ascii"),
629
+ ]
630
+ )
631
+
469
632
  self.started_transfer = True
470
633
 
471
634
  def poll(self) -> KVPoll:
472
635
  if not self.started_transfer:
473
- return KVPoll.WaitingForInput
636
+ return KVPoll.WaitingForInput # type: ignore
474
637
 
475
638
  self.kv_mgr.update_transfer_status()
476
639
 
477
- if self.kv_mgr.check_transfer_done(self.bootstrap_room):
478
- return KVPoll.Success
479
- return KVPoll.WaitingForInput
640
+ if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
641
+ return KVPoll.Success # type: ignore
642
+ return KVPoll.WaitingForInput # type: ignore
480
643
 
481
644
  def failure_exception(self):
482
645
  raise Exception("Fake KVReceiver Exception")
@@ -484,6 +647,7 @@ class NixlKVReceiver(BaseKVReceiver):
484
647
 
485
648
  class NixlKVBootstrapServer(BaseKVBootstrapServer):
486
649
  def __init__(self, port: int):
650
+ logger.debug(f"NixlKVBootstrapServer started on port {port}")
487
651
  self.port = port
488
652
  self.app = web.Application()
489
653
  self.store = dict()
@@ -564,13 +728,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
564
728
  engine_rank = int(data["engine_rank"])
565
729
  agent_name = data["agent_name"]
566
730
 
567
- # Add lock to make sure thread-safe
568
731
  if role == "Prefill":
569
- self.prefill_port_table[engine_rank] = {
570
- "rank_ip": rank_ip,
571
- "rank_port": rank_port,
572
- "agent_name": agent_name,
573
- }
732
+ async with self.lock:
733
+ self.prefill_port_table[engine_rank] = {
734
+ "rank_ip": rank_ip,
735
+ "rank_port": rank_port,
736
+ "agent_name": agent_name,
737
+ }
574
738
  logger.info(
575
739
  f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
576
740
  )
@@ -580,7 +744,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
580
744
  async def _handle_route_get(self, request: web.Request):
581
745
  engine_rank = request.query.get("engine_rank")
582
746
  if not engine_rank:
583
- return web.Response(text="Missing rank", status=400)
747
+ logger.debug(
748
+ f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
749
+ )
750
+ # Return a dict of all engine_rank
751
+ async with self.lock:
752
+ bootstrap_info = self.prefill_port_table
753
+ return web.json_response(bootstrap_info, status=200)
584
754
 
585
755
  # Find corresponding prefill info
586
756
  async with self.lock:
@@ -1,13 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import dataclasses
4
+ import warnings
3
5
  from collections import deque
4
6
  from enum import Enum
5
- from typing import List
7
+ from typing import List, Optional
6
8
 
7
9
  import numpy as np
10
+ import requests
8
11
  import torch
9
12
  import torch.distributed as dist
10
13
 
14
+ from sglang.srt.utils import get_ip
15
+
11
16
 
12
17
  class DisaggregationMode(Enum):
13
18
  NULL = "null"
@@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
119
124
  def kv_to_page_num(num_kv_indices: int, page_size: int):
120
125
  # ceil(num_kv_indices / page_size)
121
126
  return (num_kv_indices + page_size - 1) // page_size
127
+
128
+
129
+ @dataclasses.dataclass
130
+ class PDRegistryRequest:
131
+ """A request to register a machine itself to the LB."""
132
+
133
+ mode: str
134
+ registry_url: str
135
+ bootstrap_port: Optional[int] = None
136
+
137
+ def __post_init__(self):
138
+ if self.mode == "prefill" and self.bootstrap_port is None:
139
+ raise ValueError("Bootstrap port must be set in PREFILL mode.")
140
+ elif self.mode == "decode" and self.bootstrap_port is not None:
141
+ raise ValueError("Bootstrap port must not be set in DECODE mode.")
142
+ elif self.mode not in ["prefill", "decode"]:
143
+ raise ValueError(
144
+ f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
145
+ )
146
+
147
+
148
+ def register_disaggregation_server(
149
+ mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
150
+ ):
151
+ boostrap_port = bootstrap_port if mode == "prefill" else None
152
+ registry_request = PDRegistryRequest(
153
+ mode=mode,
154
+ registry_url=f"http://{get_ip()}:{server_port}",
155
+ bootstrap_port=boostrap_port,
156
+ )
157
+ res = requests.post(
158
+ f"{pdlb_url}/register",
159
+ json=dataclasses.asdict(registry_request),
160
+ )
161
+ if res.status_code != 200:
162
+ warnings.warn(
163
+ f"Failed to register disaggregation server: {res.status_code} {res.text}"
164
+ )
@@ -296,7 +296,6 @@ class CustomAllreduce:
296
296
  self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
297
297
  )
298
298
  self.register_buffer(self.buffer)
299
- self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
300
299
 
301
300
  self.disabled = False
302
301
 
@@ -430,13 +429,7 @@ class CustomAllreduce:
430
429
 
431
430
  if _is_hip:
432
431
  if self.full_nvlink:
433
- if self.world_size == 8:
434
- if self.MSCCL:
435
- return False
436
- else:
437
- return inp_size < self.max_size
438
- else:
439
- return inp_size < self.max_size
432
+ return inp_size < self.max_size
440
433
  return False
441
434
 
442
435
  return False
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+
5
+ from sglang.srt.utils import is_npu
6
+
7
+
8
+ class NpuCommunicator:
9
+
10
+ def __init__(self, group: ProcessGroup):
11
+ if not is_npu():
12
+ self.disabled = True
13
+ return
14
+ self.disabled = False
15
+ self.group = group
16
+ self.world_size = dist.get_world_size(self.group)
17
+
18
+ def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
19
+ dist.all_reduce(x, group=self.group)
20
+ return x
21
+
22
+ def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
23
+ world_size = self.world_size
24
+ if dim < 0:
25
+ # Convert negative dim to positive.
26
+ dim += x.dim()
27
+ input_size = x.size()
28
+ output_size = (input_size[0] * world_size,) + input_size[1:]
29
+ # Allocate output tensor.
30
+ output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device)
31
+ # All-gather.
32
+ dist.all_gather_into_tensor(output_tensor, x, group=self.group)
33
+ # Reshape
34
+ output_tensor = output_tensor.reshape((world_size,) + input_size)
35
+ output_tensor = output_tensor.movedim(0, dim)
36
+ output_tensor = output_tensor.reshape(
37
+ input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
38
+ )
39
+ return output_tensor
@@ -75,7 +75,8 @@ class PyNcclCommunicator:
75
75
  self.available = True
76
76
  self.disabled = False
77
77
 
78
- logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
78
+ if self.rank == 0:
79
+ logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
79
80
 
80
81
  if self.rank == 0:
81
82
  # get the unique id from NCCL
@@ -225,7 +225,8 @@ class MessageQueue:
225
225
  remote_subscribe_port = get_open_port()
226
226
  if is_valid_ipv6_address(connect_ip):
227
227
  self.remote_socket.setsockopt(IPV6, 1)
228
- socket_addr = f"tcp://*:{remote_subscribe_port}"
228
+ connect_ip = f"[{connect_ip}]"
229
+ socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
229
230
  self.remote_socket.bind(socket_addr)
230
231
 
231
232
  else: