sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -10,7 +10,7 @@ import threading
10
10
  import uuid
11
11
  from collections import defaultdict
12
12
  from functools import cache
13
- from typing import Dict, List, Optional, Tuple, Union
13
+ from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
14
14
 
15
15
  import numpy as np
16
16
  import numpy.typing as npt
@@ -32,6 +32,38 @@ from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
32
32
 
33
33
  logger = logging.getLogger(__name__)
34
34
 
35
+ NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
36
+
37
+
38
+ # From Mooncake backend.
39
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Split two parallel index arrays into runs contiguous on both sides.

    Position ``i`` extends the current run only when ``src_indices[i]`` and
    ``dst_indices[i]`` are each exactly one greater than their predecessor.
    Otherwise the current run is closed and a new one starts at ``i``.

    Args:
        src_indices: Source indices, one per element to move.
        dst_indices: Destination indices, paired position-wise with
            ``src_indices``.

    Returns:
        ``(src_groups, dst_groups)``: two lists of equal length whose k-th
        entries are the source and destination indices of the k-th run.
        Both lists are empty when ``src_indices`` is empty.
    """
    # Guard: with nothing to group, indexing element 0 below would raise
    # IndexError; an empty input simply yields no groups.
    if len(src_indices) == 0:
        return [], []

    src_groups = []
    dst_groups = []
    current_src = [src_indices[0]]
    current_dst = [dst_indices[0]]

    for i in range(1, len(src_indices)):
        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
        if src_contiguous and dst_contiguous:
            current_src.append(src_indices[i])
            current_dst.append(dst_indices[i])
        else:
            # A gap on either side ends the run: a single transfer needs
            # both address ranges to be contiguous.
            src_groups.append(current_src)
            dst_groups.append(current_dst)
            current_src = [src_indices[i]]
            current_dst = [dst_indices[i]]

    # Flush the run that was still open when the loop ended.
    src_groups.append(current_src)
    dst_groups.append(current_dst)

    return src_groups, dst_groups
63
+
64
+
65
+ GUARD = "NixlMsgGuard".encode("ascii")
66
+
35
67
 
36
68
  @dataclasses.dataclass
37
69
  class TransferInfo:
@@ -45,19 +77,36 @@ class TransferInfo:
45
77
  dst_aux_index: int
46
78
  dst_gpu_id: int
47
79
 
80
+ def is_dummy(self):
81
+ return self.endpoint == ""
82
+
48
83
  @classmethod
49
84
  def from_zmq(cls, msg: List[bytes]):
50
- return cls(
51
- room=int(msg[0].decode("ascii")),
52
- endpoint=msg[1].decode("ascii"),
53
- dst_port=int(msg[2].decode("ascii")),
54
- agent_metadata=msg[3],
55
- dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
56
- dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
57
- dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
58
- dst_aux_index=int(msg[7].decode("ascii")),
59
- dst_gpu_id=int(msg[8].decode("ascii")),
60
- )
85
+ if len(msg) == 1:
86
+ # dummy msg
87
+ return cls(
88
+ room=int(msg[0].decode("ascii")),
89
+ endpoint="",
90
+ dst_port=0,
91
+ agent_metadata=b"",
92
+ dst_kv_ptrs=[],
93
+ dst_kv_indices=np.array([], dtype=np.int64),
94
+ dst_aux_ptrs=[],
95
+ dst_aux_index=0,
96
+ dst_gpu_id=0,
97
+ )
98
+ else:
99
+ return cls(
100
+ room=int(msg[0].decode("ascii")),
101
+ endpoint=msg[1].decode("ascii"),
102
+ dst_port=int(msg[2].decode("ascii")),
103
+ agent_metadata=msg[3],
104
+ dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
105
+ dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
106
+ dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
107
+ dst_aux_index=int(msg[7].decode("ascii")),
108
+ dst_gpu_id=int(msg[8].decode("ascii")),
109
+ )
61
110
 
62
111
 
63
112
  @dataclasses.dataclass
@@ -83,6 +132,7 @@ class NixlKVManager(BaseKVManager):
83
132
  args: KVArgs,
84
133
  disaggregation_mode: DisaggregationMode,
85
134
  server_args: ServerArgs,
135
+ is_mla_backend: Optional[bool] = False,
86
136
  ):
87
137
  try:
88
138
  from nixl._api import nixl_agent
@@ -98,6 +148,19 @@ class NixlKVManager(BaseKVManager):
98
148
  # for p/d multi node infer
99
149
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
100
150
  self.dist_init_addr = server_args.dist_init_addr
151
+ self.tp_size = server_args.tp_size
152
+
153
+ self.tp_rank = args.engine_rank
154
+ self.enable_dp_attention = server_args.enable_dp_attention
155
+ if self.enable_dp_attention:
156
+ assert (
157
+ server_args.dp_size > 1
158
+ ), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
159
+ self.dp_size = server_args.dp_size
160
+ self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
161
+ self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
162
+ self.dp_rank = args.engine_rank // self.tp_size_of_dp
163
+
101
164
  self.rank_port = None
102
165
  self.server_socket = zmq.Context().socket(zmq.PULL)
103
166
  self.register_buffer_to_engine()
@@ -110,7 +173,8 @@ class NixlKVManager(BaseKVManager):
110
173
  self._start_bootstrap_thread()
111
174
  self._register_to_bootstrap()
112
175
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
113
- self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
176
+ # bootstrap key -> (remote_engine_rank -> possible remote source info)
177
+ self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {}
114
178
  self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
115
179
  TransferStatus
116
180
  )
@@ -126,6 +190,7 @@ class NixlKVManager(BaseKVManager):
126
190
  ):
127
191
  kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
128
192
  self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
193
+ logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
129
194
  if not self.kv_descs:
130
195
  raise Exception("NIXL memory registration failed for kv tensors")
131
196
  aux_addrs = []
@@ -134,6 +199,7 @@ class NixlKVManager(BaseKVManager):
134
199
  ):
135
200
  aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
136
201
  self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
202
+ logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
137
203
  if not self.aux_descs:
138
204
  raise Exception("NIXL memory registration failed for aux tensors")
139
205
 
@@ -157,6 +223,12 @@ class NixlKVManager(BaseKVManager):
157
223
  dst_gpu_id: int,
158
224
  notif: str,
159
225
  ):
226
+ # group by indices
227
+ prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
228
+ prefill_kv_indices, dst_kv_indices
229
+ )
230
+
231
+ logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
160
232
  # Make descs
161
233
  num_layers = len(self.kv_args.kv_data_ptrs)
162
234
  src_addrs = []
@@ -166,12 +238,16 @@ class NixlKVManager(BaseKVManager):
166
238
  dst_ptr = dst_kv_ptrs[layer_id]
167
239
  item_len = self.kv_args.kv_item_lens[layer_id]
168
240
 
169
- for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
170
- src_addr = src_ptr + int(prefill_index) * item_len
171
- dst_addr = dst_ptr + int(decode_index) * item_len
172
- length = item_len
241
+ for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
242
+ src_addr = src_ptr + int(prefill_index[0]) * item_len
243
+ dst_addr = dst_ptr + int(decode_index[0]) * item_len
244
+ length = item_len * len(prefill_index)
173
245
  src_addrs.append((src_addr, length, self.kv_args.gpu_id))
174
246
  dst_addrs.append((dst_addr, length, dst_gpu_id))
247
+
248
+ logger.debug(
249
+ f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
250
+ )
175
251
  src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
176
252
  dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
177
253
  # Transfer data
@@ -180,7 +256,7 @@ class NixlKVManager(BaseKVManager):
180
256
  src_descs,
181
257
  dst_descs,
182
258
  peer_name,
183
- notif.encode("ascii"),
259
+ notif.encode("ascii"), # type: ignore
184
260
  )
185
261
  if not xfer_handle:
186
262
  raise Exception("KVSender failed to create transfer")
@@ -213,7 +289,7 @@ class NixlKVManager(BaseKVManager):
213
289
  src_descs,
214
290
  dst_descs,
215
291
  peer_name,
216
- notif.encode("ascii"),
292
+ notif.encode("ascii"), # type: ignore
217
293
  )
218
294
  if not xfer_handle:
219
295
  raise Exception("KVSender failed to create transfer")
@@ -240,6 +316,9 @@ class NixlKVManager(BaseKVManager):
240
316
  req = self.transfer_infos[bootstrap_room]
241
317
  assert bootstrap_room == req.room
242
318
 
319
+ if req.is_dummy():
320
+ return []
321
+
243
322
  peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
244
323
  chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
245
324
  assert len(chunked_dst_kv_indice) == len(kv_indices)
@@ -256,6 +335,7 @@ class NixlKVManager(BaseKVManager):
256
335
  handles = [kv_xfer_handle]
257
336
  # Only the last chunk we need to send the aux data.
258
337
  if is_last:
338
+ assert aux_index is not None
259
339
  aux_xfer_handle = self.send_aux(
260
340
  peer_name,
261
341
  aux_index,
@@ -325,6 +405,13 @@ class NixlKVManager(BaseKVManager):
325
405
  """This thread recvs transfer info from the decode engine"""
326
406
  while True:
327
407
  waiting_req_bytes = self.server_socket.recv_multipart()
408
+ logger.debug(
409
+ f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}"
410
+ )
411
+ assert (
412
+ waiting_req_bytes[0] == GUARD
413
+ ), f"First message should be {GUARD}. Foreign traffic?"
414
+ waiting_req_bytes = waiting_req_bytes[1:]
328
415
  room = waiting_req_bytes[0].decode("ascii")
329
416
  if room == "None":
330
417
  continue
@@ -372,14 +459,13 @@ class NixlKVSender(BaseKVSender):
372
459
 
373
460
  def poll(self) -> KVPoll:
374
461
  if not self.has_sent:
375
- return KVPoll.WaitingForInput
376
-
462
+ return KVPoll.WaitingForInput # type: ignore
377
463
  states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
378
464
  if all([x == "DONE" for x in states]):
379
- return KVPoll.Success
465
+ return KVPoll.Success # type: ignore
380
466
  if any([x == "ERR" for x in states]):
381
467
  raise Exception("KVSender transfer encountered an error.")
382
- return KVPoll.WaitingForInput
468
+ return KVPoll.WaitingForInput # type: ignore
383
469
 
384
470
  def failure_exception(self):
385
471
  raise Exception("Fake KVSender Exception")
@@ -401,7 +487,7 @@ class NixlKVReceiver(BaseKVReceiver):
401
487
  # NOTE: key distinguished by bootstrap_addr and engine_rank
402
488
  bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
403
489
 
404
- if bootstrap_key not in self.kv_mgr.connection_pool:
490
+ if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
405
491
  self.bootstrap_info = self._get_bootstrap_info_from_server(
406
492
  self.kv_mgr.kv_args.engine_rank
407
493
  )
@@ -410,25 +496,79 @@ class NixlKVReceiver(BaseKVReceiver):
410
496
  f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
411
497
  )
412
498
  else:
413
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
499
+ self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
414
500
  else:
415
- self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
416
-
501
+ self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]
417
502
  assert self.bootstrap_info is not None
418
503
 
419
- def _get_bootstrap_info_from_server(self, engine_rank):
504
+ # Returns a list of dicts: [{remote_engine_rank: NixlEngineInfo, ...}, ...].
505
+ # Each dict holds several interchangeable remotes, called "equal sources".
506
+ # Only one remote is picked from each dict to split the traffic, i.e. len(list) remotes are selected in total.
507
def _get_bootstrap_info_from_server(
    self, engine_rank
) -> Optional[List[Dict[int, NixlEngineInfo]]]:
    """Fetch the bootstrap info from the bootstrap server.

    Without dp attention, the ``/route`` endpoint is queried for
    ``engine_rank`` only and the result is wrapped as
    ``[{engine_rank: info}]``.

    With dp attention, ``/route`` is queried without a rank, which returns
    the info of every prefill rank. Prefill ranks are split into
    ``tp_size_of_dp`` consecutive groups of
    ``prefill_tp_size // tp_size_of_dp`` ranks, and this receiver keeps the
    group selected by its ``attn_tp_rank``.

    Returns:
        A list of dicts mapping remote engine rank to its info, or ``None``
        when the server answers with a non-200 status or any exception is
        raised. Failed assertions count as exceptions here because they are
        raised inside the ``try``.
    """
    try:
        if self.kv_mgr.enable_dp_attention:
            url = f"http://{self.bootstrap_addr}/route"
            response = requests.get(url)
            if response.status_code != 200:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None

            bootstrap_info = response.json()
            assert isinstance(bootstrap_info, dict)
            # JSON object keys arrive as strings; convert them back to ranks.
            bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}

            # split out who need to send to this rank.
            # currently for dpsk mla model, those ranks share the same latent cache.
            # pick one as the real source

            prefill_tp_size = len(bootstrap_info)

            assert (
                prefill_tp_size >= self.kv_mgr.tp_size_of_dp
            ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"

            num_remote_tp_rank_we_managed = (
                prefill_tp_size // self.kv_mgr.tp_size_of_dp
            )

            # We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num).
            # The previous grouping stepped by tp_size_of_dp while slicing
            # `num` elements, which produced overlapping or missing groups
            # whenever the two values differed.
            first_managed_rank = (
                num_remote_tp_rank_we_managed * self.kv_mgr.attn_tp_rank
            )
            managed_ranks = list(
                range(
                    first_managed_rank,
                    first_managed_rank + num_remote_tp_rank_we_managed,
                )
            )

            logger.debug(
                f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
            )

            return [
                {
                    rk: bootstrap_info[rk]
                    for rk in bootstrap_info
                    if rk in managed_ranks
                }
            ]
        else:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
            response = requests.get(url)
            if response.status_code == 200:
                bootstrap_info = response.json()
                return [{engine_rank: bootstrap_info}]
            else:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
    except Exception as e:
        logger.error(f"Error fetching prefill info from bootstrap: {e}")
        return None
@@ -440,43 +580,67 @@ class NixlKVReceiver(BaseKVReceiver):
440
580
  return socket
441
581
 
442
582
  def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
443
- self.prefill_server_url = (
444
- f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
445
- )
446
- logger.debug(
447
- f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
448
- )
449
583
 
450
- packed_kv_data_ptrs = b"".join(
451
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
452
- )
453
- packed_aux_data_ptrs = b"".join(
454
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
455
- )
456
- self._connect("tcp://" + self.prefill_server_url).send_multipart(
457
- [
458
- str(self.bootstrap_room).encode("ascii"),
459
- get_local_ip_by_remote().encode("ascii"),
460
- str(self.kv_mgr.rank_port).encode("ascii"),
461
- self.kv_mgr.agent.get_agent_metadata(),
462
- packed_kv_data_ptrs,
463
- kv_indices.tobytes(),
464
- packed_aux_data_ptrs,
465
- str(aux_index).encode("ascii"),
466
- str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
584
+ assert self.bootstrap_info is not None
585
+ assert self.bootstrap_room is not None
586
+
587
+ for equal_sources in self.bootstrap_info:
588
+ remote_rank = list(equal_sources.keys())[
589
+ self.bootstrap_room % len(equal_sources)
467
590
  ]
468
- )
591
+ self.prefill_server_url = f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}"
592
+ logger.debug(
593
+ f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}"
594
+ )
595
+
596
+ packed_kv_data_ptrs = b"".join(
597
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
598
+ )
599
+ packed_aux_data_ptrs = b"".join(
600
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
601
+ )
602
+
603
+ logger.debug(
604
+ f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
605
+ )
606
+ self._connect("tcp://" + self.prefill_server_url).send_multipart(
607
+ [
608
+ GUARD,
609
+ str(self.bootstrap_room).encode("ascii"),
610
+ get_local_ip_by_remote().encode("ascii"),
611
+ str(self.kv_mgr.rank_port).encode("ascii"),
612
+ self.kv_mgr.agent.get_agent_metadata(),
613
+ packed_kv_data_ptrs,
614
+ kv_indices.tobytes(),
615
+ packed_aux_data_ptrs,
616
+ str(aux_index).encode("ascii"),
617
+ str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
618
+ ]
619
+ )
620
+
621
+ for dummy_rank in equal_sources.keys():
622
+ if dummy_rank == remote_rank:
623
+ continue
624
+ dummy_info = equal_sources[dummy_rank]
625
+ dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
626
+ self._connect("tcp://" + dummy_url).send_multipart(
627
+ [
628
+ GUARD,
629
+ str(self.bootstrap_room).encode("ascii"),
630
+ ]
631
+ )
632
+
469
633
  self.started_transfer = True
470
634
 
471
635
  def poll(self) -> KVPoll:
472
636
  if not self.started_transfer:
473
- return KVPoll.WaitingForInput
637
+ return KVPoll.WaitingForInput # type: ignore
474
638
 
475
639
  self.kv_mgr.update_transfer_status()
476
640
 
477
- if self.kv_mgr.check_transfer_done(self.bootstrap_room):
478
- return KVPoll.Success
479
- return KVPoll.WaitingForInput
641
+ if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
642
+ return KVPoll.Success # type: ignore
643
+ return KVPoll.WaitingForInput # type: ignore
480
644
 
481
645
  def failure_exception(self):
482
646
  raise Exception("Fake KVReceiver Exception")
@@ -484,6 +648,7 @@ class NixlKVReceiver(BaseKVReceiver):
484
648
 
485
649
  class NixlKVBootstrapServer(BaseKVBootstrapServer):
486
650
  def __init__(self, port: int):
651
+ logger.debug(f"NixlKVBootstrapServer started on port {port}")
487
652
  self.port = port
488
653
  self.app = web.Application()
489
654
  self.store = dict()
@@ -564,13 +729,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
564
729
  engine_rank = int(data["engine_rank"])
565
730
  agent_name = data["agent_name"]
566
731
 
567
- # Add lock to make sure thread-safe
568
732
  if role == "Prefill":
569
- self.prefill_port_table[engine_rank] = {
570
- "rank_ip": rank_ip,
571
- "rank_port": rank_port,
572
- "agent_name": agent_name,
573
- }
733
+ async with self.lock:
734
+ self.prefill_port_table[engine_rank] = {
735
+ "rank_ip": rank_ip,
736
+ "rank_port": rank_port,
737
+ "agent_name": agent_name,
738
+ }
574
739
  logger.info(
575
740
  f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
576
741
  )
@@ -580,7 +745,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
580
745
  async def _handle_route_get(self, request: web.Request):
581
746
  engine_rank = request.query.get("engine_rank")
582
747
  if not engine_rank:
583
- return web.Response(text="Missing rank", status=400)
748
+ logger.debug(
749
+ f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
750
+ )
751
+ # Return a dict of all engine_rank
752
+ async with self.lock:
753
+ bootstrap_info = self.prefill_port_table
754
+ return web.json_response(bootstrap_info, status=200)
584
755
 
585
756
  # Find corresponding prefill info
586
757
  async with self.lock:
@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.utils import (
34
34
  ReqToMetadataIdxAllocator,
35
35
  TransferBackend,
36
36
  get_kv_class,
37
+ is_mla_backend,
37
38
  kv_to_page_indices,
38
39
  kv_to_page_num,
39
40
  poll_and_all_reduce,
@@ -69,6 +70,7 @@ class PrefillBootstrapQueue:
69
70
  scheduler: Scheduler,
70
71
  ):
71
72
  self.token_to_kv_pool = token_to_kv_pool
73
+ self.is_mla_backend = is_mla_backend(token_to_kv_pool)
72
74
  self.aux_dtype = aux_dtype
73
75
 
74
76
  self.metadata_buffers = metadata_buffers
@@ -112,7 +114,10 @@ class PrefillBootstrapQueue:
112
114
  kv_args.gpu_id = self.scheduler.gpu_id
113
115
  kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
114
116
  kv_manager = kv_manager_class(
115
- kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args
117
+ kv_args,
118
+ DisaggregationMode.PREFILL,
119
+ self.scheduler.server_args,
120
+ self.is_mla_backend,
116
121
  )
117
122
  return kv_manager
118
123
 
@@ -277,19 +282,17 @@ class SchedulerDisaggregationPrefillMixin:
277
282
  next_token_ids,
278
283
  extend_input_len_per_req,
279
284
  extend_logprob_start_len_per_req,
280
- bid,
281
285
  ) = (
282
286
  result.logits_output,
283
287
  result.next_token_ids,
284
288
  result.extend_input_len_per_req,
285
289
  result.extend_logprob_start_len_per_req,
286
- result.bid,
287
290
  )
288
291
 
289
292
  # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
290
293
  if self.enable_overlap:
291
294
  # wait
292
- _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
295
+ _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
293
296
  else:
294
297
  next_token_ids = result.next_token_ids.tolist()
295
298
 
@@ -1,13 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import dataclasses
4
+ import warnings
3
5
  from collections import deque
4
6
  from enum import Enum
5
- from typing import List
7
+ from typing import List, Optional
6
8
 
7
9
  import numpy as np
10
+ import requests
8
11
  import torch
9
12
  import torch.distributed as dist
10
13
 
14
+ from sglang.srt.utils import get_ip
15
+
11
16
 
12
17
  class DisaggregationMode(Enum):
13
18
  NULL = "null"
@@ -107,7 +112,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
107
112
 
108
113
 
109
114
  def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
110
- # 1. The page is guaruanteed to be full except the last page.
115
+ # 1. The page is guaranteed to be full except the last page.
111
116
  # 2. page index = kv_index // page_size
112
117
  # The return vector is kv_indices[::page_size] // page_size
113
118
  if page_size == 1: # shortcut
@@ -119,3 +124,47 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
119
124
  def kv_to_page_num(num_kv_indices: int, page_size: int):
120
125
  # ceil(num_kv_indices / page_size)
121
126
  return (num_kv_indices + page_size - 1) // page_size
127
+
128
+
129
@dataclasses.dataclass
class PDRegistryRequest:
    """Payload used by a server to register itself with the load balancer.

    Raises:
        ValueError: if ``mode`` is neither ``"prefill"`` nor ``"decode"``,
            if a prefill request has no bootstrap port, or if a decode
            request carries one.
    """

    # Either "prefill" or "decode".
    mode: str
    # URL under which the registering server can be reached.
    registry_url: str
    # Required for prefill, forbidden for decode.
    bootstrap_port: Optional[int] = None

    def __post_init__(self):
        # Reject unknown modes first; the two remaining modes are then
        # checked for the presence/absence of the bootstrap port.
        if self.mode not in ["prefill", "decode"]:
            raise ValueError(
                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
            )

        port_given = self.bootstrap_port is not None
        if self.mode == "prefill":
            if not port_given:
                raise ValueError("Bootstrap port must be set in PREFILL mode.")
            return

        if port_given:
            raise ValueError("Bootstrap port must not be set in DECODE mode.")
146
+
147
+
148
def register_disaggregation_server(
    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
):
    """Register this server with the load balancer at ``pdlb_url``.

    Builds a ``PDRegistryRequest`` whose URL is ``http://<get_ip()>:<server_port>``
    and POSTs it as JSON to ``{pdlb_url}/register``. A non-200 answer only
    emits a warning; nothing is returned.

    Args:
        mode: Passed through to ``PDRegistryRequest``.
        server_port: Port of this server, used in the registered URL.
        bootstrap_port: Forwarded only when ``mode`` is ``"prefill"``;
            otherwise ``None`` is sent.
        pdlb_url: Base URL of the load balancer.
    """
    # Only a prefill server advertises its bootstrap port.
    advertised_port = None
    if mode == "prefill":
        advertised_port = bootstrap_port

    payload = dataclasses.asdict(
        PDRegistryRequest(
            mode=mode,
            registry_url=f"http://{get_ip()}:{server_port}",
            bootstrap_port=advertised_port,
        )
    )

    response = requests.post(f"{pdlb_url}/register", json=payload)
    if response.status_code == 200:
        return

    warnings.warn(
        f"Failed to register disaggregation server: {response.status_code} {response.text}"
    )
165
+
166
+
167
def is_mla_backend(target_kv_pool) -> bool:
    """Return whether ``target_kv_pool`` is an ``MLATokenToKVPool`` instance."""
    # Kept as a function-local import, as in the original
    # (presumably to avoid an import cycle -- TODO confirm).
    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

    uses_mla_pool = isinstance(target_kv_pool, MLATokenToKVPool)
    return uses_mla_pool
@@ -296,7 +296,6 @@ class CustomAllreduce:
296
296
  self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
297
297
  )
298
298
  self.register_buffer(self.buffer)
299
- self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
300
299
 
301
300
  self.disabled = False
302
301
 
@@ -430,13 +429,7 @@ class CustomAllreduce:
430
429
 
431
430
  if _is_hip:
432
431
  if self.full_nvlink:
433
- if self.world_size == 8:
434
- if self.MSCCL:
435
- return False
436
- else:
437
- return inp_size < self.max_size
438
- else:
439
- return inp_size < self.max_size
432
+ return inp_size < self.max_size
440
433
  return False
441
434
 
442
435
  return False
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+
5
+ from sglang.srt.utils import is_npu
6
+
7
+
8
class NpuCommunicator:
    """Thin wrapper around ``torch.distributed`` collectives for one group.

    When ``is_npu()`` is false the instance is only marked ``disabled`` and
    neither ``group`` nor ``world_size`` is set.
    """

    def __init__(self, group: ProcessGroup):
        self.disabled = not is_npu()
        if self.disabled:
            return
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` in place over the group and return it."""
        dist.all_reduce(x, group=self.group)
        return x

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Gather ``x`` from every rank and concatenate along ``dim``.

        Args:
            x: Local tensor to contribute.
            dim: Concatenation dimension; negative values count from the end.

        Returns:
            A tensor shaped like ``x`` except that dimension ``dim`` is
            ``world_size`` times larger.
        """
        num_ranks = self.world_size
        if dim < 0:
            # Normalise a negative dim to its non-negative equivalent.
            dim = dim + x.dim()

        local_shape = x.size()
        # The collective writes the per-rank tensors stacked along dim 0.
        gathered = torch.empty(
            (local_shape[0] * num_ranks,) + local_shape[1:],
            dtype=x.dtype,
            device=x.device,
        )
        dist.all_gather_into_tensor(gathered, x, group=self.group)

        # Expose the rank axis, move it next to the target dim, then fold
        # it into that dim.
        per_rank = gathered.reshape((num_ranks,) + local_shape)
        per_rank = per_rank.movedim(0, dim)
        merged_shape = (
            local_shape[:dim]
            + (num_ranks * local_shape[dim],)
            + local_shape[dim + 1 :]
        )
        return per_rank.reshape(merged_shape)