sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/disaggregation/nixl/conn.py
+++ b/sglang/srt/disaggregation/nixl/conn.py
@@ -31,23 +31,19 @@ from sglang.srt.utils import get_local_ip_by_remote
 
 logger = logging.getLogger(__name__)
 
-NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
-
 GUARD = "NixlMsgGuard".encode("ascii")
 
 
 @dataclasses.dataclass
 class TransferInfo:
+    """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
+
     room: int
     endpoint: str
     dst_port: int
-    agent_metadata: bytes
     agent_name: str
-    dst_kv_ptrs: list[int]
     dst_kv_indices: npt.NDArray[np.int32]
-    dst_aux_ptrs: list[int]
     dst_aux_index: int
-    dst_gpu_id: int
     required_dst_info_num: int
 
     def is_dummy(self):
@@ -59,14 +55,37 @@ class TransferInfo:
             room=int(msg[0].decode("ascii")),
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
-            agent_metadata=msg[3],
-            agent_name=msg[4].decode("ascii"),
+            agent_name=msg[3].decode("ascii"),
+            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
+            dst_aux_index=int(msg[5].decode("ascii")),
+            required_dst_info_num=int(msg[6].decode("ascii")),
+        )
+
+
+@dataclasses.dataclass
+class KVArgsRegisterInfo:
+    """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
+
+    room: str
+    endpoint: str
+    dst_port: int
+    agent_name: str
+    agent_metadata: bytes
+    dst_kv_ptrs: list[int]
+    dst_aux_ptrs: list[int]
+    gpu_id: int
+
+    @classmethod
+    def from_zmq(cls, msg: List[bytes]):
+        return cls(
+            room=str(msg[0].decode("ascii")),
+            endpoint=msg[1].decode("ascii"),
+            dst_port=int(msg[2].decode("ascii")),
+            agent_name=msg[3].decode("ascii"),
+            agent_metadata=msg[4],
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
-            dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
-            dst_aux_index=int(msg[8].decode("ascii")),
-            dst_gpu_id=int(msg[9].decode("ascii")),
-            required_dst_info_num=int(msg[10].decode("ascii")),
+            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+            gpu_id=int(msg[7].decode("ascii")),
         )
 
 
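The handshake is now split in two: `KVArgsRegisterInfo` carries the connection-lifetime data (agent metadata, KV/aux base pointers, GPU id) exactly once, while each request's `TransferInfo` shrinks to indices only. The pointer lists travel as packed unsigned 64-bit integers; here is a minimal round-trip sketch of that encoding (the helper names are illustrative, not from the package):

```python
import struct

def pack_ptrs(ptrs: list[int]) -> bytes:
    # One native-endian unsigned 64-bit integer ("Q") per pointer,
    # matching the packed_kv_data_ptrs / packed_aux_data_ptrs payloads.
    return b"".join(struct.pack("Q", p) for p in ptrs)

def unpack_ptrs(buf: bytes) -> list[int]:
    # len(buf) // 8 recovers the pointer count, as in KVArgsRegisterInfo.from_zmq.
    return list(struct.unpack(f"{len(buf) // 8}Q", buf))

kv_ptrs = [0x7F0000000000, 0x7F0000100000]
assert unpack_ptrs(pack_ptrs(kv_ptrs)) == kv_ptrs
```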
@@ -109,9 +128,9 @@ class NixlKVManager(CommonKVManager):
         self.register_buffer_to_engine()
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.request_status = {}
-            self.transfer_infos: Dict[int, TransferInfo] = {}
-            self.peer_names: Dict[str, str] = {}
+            self.request_status: Dict[int, KVPoll] = {}
+            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
+            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self._start_bootstrap_thread()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
@@ -140,7 +159,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
-        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
+        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
         logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
@@ -149,15 +168,18 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
-        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
+        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
         logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
 
-    def _add_remote(self, agent_name: str, agent_metadata: bytes):
-        if agent_name not in self.peer_names:
-            self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata)
-        return self.peer_names[agent_name]
+    def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
+        agent_name = decode_kv_args.agent_name
+        if agent_name in self.decode_kv_args_table:
+            logger.info(f"Peer {agent_name} was already registered, ignoring.")
+            return
+        self.decode_kv_args_table[agent_name] = decode_kv_args
+        self.agent.add_remote_agent(decode_kv_args.agent_metadata)
 
     def send_kvcache(
         self,
@@ -193,8 +215,8 @@ class NixlKVManager(CommonKVManager):
         logger.debug(
             f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
         )
-        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -226,8 +248,8 @@ class NixlKVManager(CommonKVManager):
         decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
         src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
         dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
-        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
+        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -262,17 +284,17 @@ class NixlKVManager(CommonKVManager):
             if req.is_dummy():
                 continue
 
-            peer_name = self._add_remote(req.agent_name, req.agent_metadata)
             chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
             assert len(chunked_dst_kv_indice) == len(kv_indices)
+            assert req.agent_name in self.decode_kv_args_table
 
             notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
             kv_xfer_handle = self.send_kvcache(
-                peer_name,
+                req.agent_name,
                 kv_indices,
-                req.dst_kv_ptrs,
+                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
                 chunked_dst_kv_indice,
-                req.dst_gpu_id,
+                self.decode_kv_args_table[req.agent_name].gpu_id,
                 notif,
             )
             handles.append(kv_xfer_handle)
@@ -280,13 +302,15 @@ class NixlKVManager(CommonKVManager):
             if is_last:
                 assert aux_index is not None
                 aux_xfer_handle = self.send_aux(
-                    peer_name,
+                    req.agent_name,
                     aux_index,
-                    req.dst_aux_ptrs,
+                    self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
                     req.dst_aux_index,
                     str(req.room) + "_aux",
                 )
                 handles.append(aux_xfer_handle)
+        if is_last:
+            del self.transfer_infos[bootstrap_room]
         return handles
 
     def update_transfer_status(self):
@@ -328,16 +352,23 @@ class NixlKVManager(CommonKVManager):
                 ), f"First message should be {GUARD}. Foreign traffic?"
                 waiting_req_bytes = waiting_req_bytes[1:]
                 room = waiting_req_bytes[0].decode("ascii")
-
-                required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
+                agent_name = waiting_req_bytes[3].decode("ascii")
+                if room == "None":
+                    # Register new peer and save KV base pointers.
+                    self._add_remote_peer(
+                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
+                    )
+                    logger.debug(f"Register KVArgs from {agent_name} successfully")
+                    continue
                 room = int(room)
-                agent_name = waiting_req_bytes[4].decode("ascii")
                 if room not in self.transfer_infos:
                     self.transfer_infos[room] = {}
                 self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
                     waiting_req_bytes
                 )
-
+                required_dst_info_num = self.transfer_infos[room][
+                    agent_name
+                ].required_dst_info_num
                 logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
                 if len(self.transfer_infos[room]) == required_dst_info_num:
                     logger.debug(f"{room=} is bootstrapped")
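The bootstrap thread distinguishes the two message kinds by the room field alone: registration messages carry the literal string "None", per-request messages carry a numeric room id. A stripped-down sketch of that dispatch, acting on a message with the GUARD frame already stripped (the dict names stand in for `decode_kv_args_table` and `transfer_infos`):

```python
from typing import Dict, List

def dispatch(msg: List[bytes], peers: Dict[str, List[bytes]], rooms: Dict[int, dict]) -> None:
    room = msg[0].decode("ascii")
    agent_name = msg[3].decode("ascii")
    if room == "None":
        # One-time registration: remember this peer's base pointers.
        peers[agent_name] = msg
        return
    # Per-request transfer info: group by room until every expected
    # destination rank has checked in.
    rooms.setdefault(int(room), {})[agent_name] = msg
```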
@@ -391,6 +422,7 @@ class NixlKVSender(BaseKVSender):
         self.chunk_id += 1
         if is_last:
             self.has_sent = True
+            del self.kv_mgr.request_status[self.bootstrap_room]
 
     def poll(self) -> KVPoll:
         if not self.has_sent:
@@ -415,6 +447,7 @@ class NixlKVReceiver(CommonKVReceiver):
         data_parallel_rank: Optional[int] = None,
     ):
         self.started_transfer = False
+        self.conclude_state = None
         super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
 
     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
@@ -426,17 +459,8 @@ class NixlKVReceiver(CommonKVReceiver):
                 f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
             )
             is_dummy = bootstrap_info["is_dummy"]
-
-            # TODO: send_kv_args earlier
-            packed_kv_data_ptrs = b"".join(
-                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-            )
-            packed_aux_data_ptrs = b"".join(
-                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-            )
-
             logger.debug(
-                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
+                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
             )
             sock, lock = self._connect("tcp://" + self.prefill_server_url)
             with lock:
@@ -446,13 +470,9 @@ class NixlKVReceiver(CommonKVReceiver):
                         str(self.bootstrap_room).encode("ascii"),
                         get_local_ip_by_remote().encode("ascii"),
                         str(self.kv_mgr.rank_port).encode("ascii"),
-                        self.kv_mgr.agent.get_agent_metadata(),
                         self.kv_mgr.agent.name.encode("ascii"),
-                        packed_kv_data_ptrs,
                         kv_indices.tobytes() if not is_dummy else b"",
-                        packed_aux_data_ptrs,
                         str(aux_index).encode("ascii"),
-                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                         str(self.required_dst_info_num).encode("ascii"),
                     ]
                 )
@@ -460,17 +480,45 @@ class NixlKVReceiver(CommonKVReceiver):
         self.started_transfer = True
 
     def poll(self) -> KVPoll:
+        if self.conclude_state is not None:
+            return self.conclude_state
         if not self.started_transfer:
             return KVPoll.WaitingForInput  # type: ignore
 
         self.kv_mgr.update_transfer_status()
-
         if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
+            self.conclude_state = KVPoll.Success
+            del self.kv_mgr.transfer_statuses[self.bootstrap_room]
             return KVPoll.Success  # type: ignore
         return KVPoll.WaitingForInput  # type: ignore
 
     def _register_kv_args(self):
-        pass
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
+            )
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        GUARD,
+                        "None".encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.kv_mgr.agent.name.encode("ascii"),
+                        self.kv_mgr.agent.get_agent_metadata(),
+                        packed_kv_data_ptrs,
+                        packed_aux_data_ptrs,
+                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                    ]
+                )
 
     def failure_exception(self):
         raise Exception("Fake KVReceiver Exception")
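Since a successful `poll()` now frees the room's entry in `transfer_statuses`, a second poll of the same receiver would otherwise hit missing state; caching the terminal result makes polling idempotent. The pattern in isolation (a sketch, with `check_done` standing in for `check_transfer_done`):

```python
class OneShotPoller:
    """Cache the terminal poll result so bookkeeping can be freed exactly once."""

    def __init__(self, check_done):
        self.check_done = check_done  # callable; True once the transfer finished
        self.conclude_state = None

    def poll(self) -> str:
        if self.conclude_state is not None:
            return self.conclude_state  # later polls never touch freed state
        if self.check_done():
            self.conclude_state = "Success"
            # ...delete per-transfer bookkeeping here, exactly once...
            return "Success"
        return "WaitingForInput"
```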
--- a/sglang/srt/disaggregation/prefill.py
+++ b/sglang/srt/disaggregation/prefill.py
@@ -93,8 +93,6 @@ class PrefillBootstrapQueue:
         self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
         self.queue: List[Req] = []
-        self.pp_rank = pp_rank
-        self.pp_size = pp_size
         self.gloo_group = gloo_group
         self.max_total_num_tokens = max_total_num_tokens
         self.scheduler = scheduler
@@ -124,6 +122,9 @@ class PrefillBootstrapQueue:
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
+        if not self.is_mla_backend:
+            kv_args.kv_head_num = self.token_to_kv_pool.head_num
+        kv_args.page_size = self.token_to_kv_pool.page_size
 
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
@@ -275,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()
 
         if require_mlp_sync(self.server_args):
-            batch, _ = self.prepare_mlp_sync_batch(batch)
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
 
         if batch:
@@ -309,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()
 
         if require_mlp_sync(self.server_args):
-            batch, _ = self.prepare_mlp_sync_batch(batch)
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
         if batch:
             result = self.run_batch(batch)
--- a/sglang/srt/disaggregation/utils.py
+++ b/sglang/srt/disaggregation/utils.py
@@ -74,7 +74,7 @@ class ReqToMetadataIdxAllocator:
     def available_size(self):
        return len(self.free_slots)
 
-    def alloc(self) -> List[int]:
+    def alloc(self) -> Optional[int]:
         if len(self.free_slots) == 0:
             return None
 
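The old annotation was simply wrong: `alloc()` returns a single slot index, or `None` when the pool is exhausted, so callers must None-check the result rather than expect a list. A self-contained sketch of the contract (a hypothetical allocator, not the package's class):

```python
from typing import Optional

class SlotAllocator:
    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self) -> Optional[int]:
        if len(self.free_slots) == 0:
            return None  # exhausted; the caller should retry later
        return self.free_slots.pop(0)

allocator = SlotAllocator(size=2)
idx = allocator.alloc()
if idx is not None:
    print(f"got metadata slot {idx}")
```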
@@ -107,9 +107,6 @@ class MetadataBuffers:
         # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
         self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
 
-        self.output_hidden_states = torch.zeros(
-            (size, hidden_size), dtype=dtype, device=device
-        )
         self.output_token_logprobs_val = torch.zeros(
             (size, 16), dtype=torch.float32, device=device
         )
@@ -122,51 +119,50 @@ class MetadataBuffers:
         self.output_top_logprobs_idx = torch.zeros(
             (size, max_top_logprobs_num), dtype=torch.int32, device=device
         )
+        self.output_hidden_states = torch.zeros(
+            (size, hidden_size), dtype=dtype, device=device
+        )
 
     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
-            self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
-            self.output_hidden_states.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_hidden_states.nbytes,
         ]
         item_lens = [
             self.output_ids[0].nbytes,
-            self.output_hidden_states[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens
 
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
-            self.output_hidden_states[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_hidden_states[idx],
         )
 
     def set_buf(self, req: Req):
 
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
-        if req.hidden_states_tensor is not None:
-            self.output_hidden_states[req.metadata_buffer_index].copy_(
-                req.hidden_states_tensor
-            )
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -189,6 +185,11 @@ class MetadataBuffers:
                 ] = torch.tensor(
                     req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                 )
+        # for PD + spec decode
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )
 
 
#########################
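Moving `output_hidden_states` to the tail of every ptr/len/buf tuple groups the small always-transferred metadata first and leaves the one large buffer last; it is only written for PD + speculative decoding, and the old inline TODO already flagged skipping it otherwise. The size asymmetry is easy to see (a sketch with assumed shapes):

```python
import torch

size, hidden_size = 128, 4096
output_ids = torch.zeros((size, 16), dtype=torch.int32)          # padded past the 64 B RDMA minimum
hidden = torch.zeros((size, hidden_size), dtype=torch.bfloat16)  # only needed for spec decode

print(output_ids[0].nbytes)  # 64 bytes per request
print(hidden[0].nbytes)      # 8192 bytes per request, two orders of magnitude larger
```

With the big optional buffer last, a later change could plausibly drop it by truncating the lists instead of re-indexing every consumer.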
--- a/sglang/srt/distributed/parallel_state.py
+++ b/sglang/srt/distributed/parallel_state.py
@@ -42,8 +42,10 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
+    get_int_env_var,
     is_cuda_alike,
     is_npu,
+    is_shm_available,
     supports_custom_op,
 )
 
@@ -222,6 +224,7 @@ class GroupCoordinator:
         self.local_rank = local_rank
         self.device_group = None
         self.cpu_group = None
+        self.local_size = get_int_env_var("LOCAL_SIZE", 0)
 
         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
@@ -440,9 +443,12 @@ class GroupCoordinator:
             return input_
 
         if input_.is_cpu:
-            import intel_extension_for_pytorch as ipex
-
-            ipex.distributed.all_reduce(input_, group=self.device_group)
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                torch.ops.sgl_kernel.shm_allreduce(
+                    input_, torch.distributed.ReduceOp.SUM
+                )
+            else:
+                torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
         if not supports_custom_op():
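On CPU, the all-reduce now prefers the sgl-kernel shared-memory collective and falls back to plain `torch.distributed.all_reduce` (e.g. over gloo) when `is_shm_available(dtype, world_size, local_size)` says no. A sketch of the dispatch shape; `shm_ok` stands in for that check, and the `torch.ops.sgl_kernel.shm_allreduce` op only exists where the sgl-kernel extension is installed:

```python
import torch
import torch.distributed as dist

def cpu_all_reduce(t: torch.Tensor, group, shm_ok: bool) -> torch.Tensor:
    # In-place SUM all-reduce on a CPU tensor. shm_ok must be computed
    # identically on every rank (it is a pure function of dtype, world size,
    # and local size); otherwise ranks take different collectives and deadlock.
    if shm_ok:
        torch.ops.sgl_kernel.shm_allreduce(t, dist.ReduceOp.SUM)  # intra-node shm fast path
    else:
        dist.all_reduce(t, group=group)  # portable fallback
    return t
```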
@@ -570,6 +576,16 @@ class GroupCoordinator:
         output_tensor = torch.empty(
             output_size, dtype=input_.dtype, device=input_.device
         )
+
+        if input_.is_cpu:
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
+                return output_tensor
+
         # All-gather.
         self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
@@ -683,18 +699,25 @@ class GroupCoordinator:
         )
 
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
+            device=torch.cuda.current_device()
+        )
 
         size_tensor = torch.tensor(
-            [object_tensor.numel()], dtype=torch.long, device="cpu"
+            [object_tensor.numel()],
+            dtype=torch.long,
+            device=torch.cuda.current_device(),
         )
 
         # Send object size
-
-        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        torch.distributed.send(
+            size_tensor, dst=self.ranks[dst], group=self.device_group
+        )
 
         # Send object
-        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        torch.distributed.send(
+            object_tensor, dst=self.ranks[dst], group=self.device_group
+        )
 
         return None
 
@@ -708,29 +731,31 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
 
-        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
+        size_tensor = torch.empty(
+            1, dtype=torch.long, device=torch.cuda.current_device()
+        )
 
         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.cpu_group
+            size_tensor, src=self.ranks[src], group=self.device_group
         )
 
         # Tensor to receive serialized objects into.
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device="cpu",
+            device=torch.cuda.current_device(),
         )
 
         rank_object = torch.distributed.recv(
-            object_tensor, src=self.ranks[src], group=self.cpu_group
+            object_tensor, src=self.ranks[src], group=self.device_group
         )
 
         assert (
             rank_object == rank_size
         ), "Received object sender rank does not match the size sender rank."
 
-        obj = pickle.loads(object_tensor.numpy().tobytes())
+        obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
 
         return obj
 
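Both halves of the object channel now stage the pickled payload in device memory and communicate over `device_group` (e.g. NCCL) instead of the CPU group: a long tensor announces the byte count, the payload follows, and the receiver copies back to host before unpickling. Condensed into a standalone pair (a sketch assuming CUDA and an initialized NCCL process group):

```python
import pickle

import torch
import torch.distributed as dist

def send_object_d2d(obj, dst: int, group) -> None:
    payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda()
    size = torch.tensor([payload.numel()], dtype=torch.long, device="cuda")
    dist.send(size, dst=dst, group=group)     # receiver learns the length first
    dist.send(payload, dst=dst, group=group)  # then the pickled bytes

def recv_object_d2d(src: int, group):
    size = torch.empty(1, dtype=torch.long, device="cuda")
    dist.recv(size, src=src, group=group)
    payload = torch.empty(size.item(), dtype=torch.uint8, device="cuda")
    dist.recv(payload, src=src, group=group)
    return pickle.loads(payload.cpu().numpy().tobytes())
```

The comment added in send_tensor_dict below states the trade: an extra D2H copy for serialization, repaid by device-to-device bandwidth and the ability to overlap send and recv.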
@@ -841,14 +866,16 @@ class GroupCoordinator:
         dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
-        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict, dict
         ), f"Expecting a dictionary, got {type(tensor_dict)}"
         metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
-        # `metadata_list` lives in CPU memory.
-        # `send_object_list` has serialization & deserialization,
-        # all happening on CPU. Therefore, we can use the CPU group.
+        # Note: While switching to Device-to-Device (D2D) would introduce an extra
+        # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
+        # show better overall transmission performance with D2D due to:
+        # 1. Superior D2D transfer bandwidth
+        # 2. Ability to overlap send and recv operations
+        # Thus the net performance gain justifies this approach.
         self.send_object(metadata_list, dst=dst)
         for tensor in tensor_list:
             if tensor.numel() == 0:
--- a/sglang/srt/entrypoints/EngineBase.py
+++ b/sglang/srt/entrypoints/EngineBase.py
@@ -48,6 +48,14 @@ class EngineBase(ABC):
         """Update model weights with in-memory tensor data."""
         pass
 
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+        pass
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+        pass
+
     @abstractmethod
     def release_memory_occupation(self):
         """Release GPU memory occupation temporarily."""