sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -68,16 +68,28 @@ class TransferInfo:
68
68
  mooncake_session_id: str
69
69
  dst_kv_indices: npt.NDArray[np.int64]
70
70
  dst_aux_index: int
71
+ required_dst_info_num: int
72
+ is_dummy: bool
71
73
 
72
74
  @classmethod
73
75
  def from_zmq(cls, msg: List[bytes]):
76
+ if msg[4] == b"" and msg[5] == b"":
77
+ is_dummy = True
78
+ dst_kv_indices = np.array([], dtype=np.int64)
79
+ dst_aux_index = None
80
+ else:
81
+ dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
82
+ dst_aux_index = int(msg[5].decode("ascii"))
83
+ is_dummy = False
74
84
  return cls(
75
85
  room=int(msg[0].decode("ascii")),
76
86
  endpoint=msg[1].decode("ascii"),
77
87
  dst_port=int(msg[2].decode("ascii")),
78
88
  mooncake_session_id=msg[3].decode("ascii"),
79
- dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
80
- dst_aux_index=int(msg[5].decode("ascii")),
89
+ dst_kv_indices=dst_kv_indices,
90
+ dst_aux_index=dst_aux_index,
91
+ required_dst_info_num=int(msg[6].decode("ascii")),
92
+ is_dummy=is_dummy,
81
93
  )
82
94
 
83
95
 
@@ -108,6 +120,7 @@ class MooncakeKVManager(BaseKVManager):
108
120
  args: KVArgs,
109
121
  disaggregation_mode: DisaggregationMode,
110
122
  server_args: ServerArgs,
123
+ is_mla_backend: Optional[bool] = False,
111
124
  ):
112
125
  self.kv_args = args
113
126
  self.engine = MooncakeTransferEngine(
@@ -115,6 +128,7 @@ class MooncakeKVManager(BaseKVManager):
115
128
  gpu_id=self.kv_args.gpu_id,
116
129
  ib_device=self.kv_args.ib_device,
117
130
  )
131
+ self.is_mla_backend = is_mla_backend
118
132
  self.disaggregation_mode = disaggregation_mode
119
133
  # for p/d multi node infer
120
134
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
@@ -132,7 +146,7 @@ class MooncakeKVManager(BaseKVManager):
132
146
  self.register_buffer_to_engine()
133
147
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
134
148
  self.transfer_queue = queue.Queue()
135
- self.transfer_infos: Dict[int, TransferInfo] = {}
149
+ self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
136
150
  self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
137
151
  self.start_prefill_thread()
138
152
  self._register_to_bootstrap()
@@ -145,6 +159,7 @@ class MooncakeKVManager(BaseKVManager):
145
159
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
146
160
  self.start_decode_thread()
147
161
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
162
+ self.prefill_tp_size_table: Dict[str, int] = {}
148
163
  self.prefill_dp_size_table: Dict[str, int] = {}
149
164
  else:
150
165
  raise ValueError(
@@ -218,7 +233,7 @@ class MooncakeKVManager(BaseKVManager):
218
233
  status = future.result()
219
234
  if status != 0:
220
235
  # Immediate shutdown on first error (existing tasks will finish)
221
- executor.shutdown(wait=False)
236
+ self.executor.shutdown(wait=False)
222
237
  for f in futures:
223
238
  f.cancel()
224
239
  return status
@@ -250,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
250
265
  self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
251
266
  [
252
267
  str(room).encode("ascii"),
253
- str(self.request_status[room]).encode("ascii"),
268
+ str(self.check_status(room)).encode("ascii"),
254
269
  ]
255
270
  )
256
271
 
@@ -264,8 +279,8 @@ class MooncakeKVManager(BaseKVManager):
264
279
  while True:
265
280
  waiting_req_bytes = self.server_socket.recv_multipart()
266
281
  room = waiting_req_bytes[0].decode("ascii")
282
+ mooncake_session_id = waiting_req_bytes[3].decode("ascii")
267
283
  if room == "None":
268
- mooncake_session_id = waiting_req_bytes[3].decode("ascii")
269
284
  self.decode_kv_args_table[mooncake_session_id] = (
270
285
  KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
271
286
  )
@@ -273,53 +288,84 @@ class MooncakeKVManager(BaseKVManager):
273
288
  f"Register KVArgs from {mooncake_session_id} successfully"
274
289
  )
275
290
  continue
276
- room = int(room)
277
- self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
278
-
279
- # NOTE: after bootstrapping we can mark the req as waiting for input
280
- self.request_status[room] = KVPoll.WaitingForInput
291
+ else:
292
+ required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
293
+ room = int(room)
294
+ if room not in self.transfer_infos:
295
+ self.transfer_infos[room] = {}
296
+
297
+ self.transfer_infos[room][mooncake_session_id] = (
298
+ TransferInfo.from_zmq(waiting_req_bytes)
299
+ )
300
+ # NOTE: after bootstrapping we can mark the req as waiting for input
301
+ if len(self.transfer_infos[room]) == required_dst_info_num:
302
+ self.update_status(room, KVPoll.WaitingForInput)
281
303
 
282
304
  def transfer_thread():
283
305
  # TODO: Shall we use KVPoll.Transferring state?
284
306
  while True:
285
307
  try:
286
308
  kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
287
- req = self.transfer_infos[kv_chunk.room]
288
- chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
289
- assert len(chunked_dst_kv_indice) == len(
290
- kv_chunk.prefill_kv_indices
291
- ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
292
-
293
- ret = self.send_kvcache(
294
- req.mooncake_session_id,
295
- kv_chunk.prefill_kv_indices,
296
- self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
297
- chunked_dst_kv_indice,
298
- )
299
- if ret != 0:
300
- self.request_status[kv_chunk.room] = KVPoll.Failed
301
- self.sync_status_to_decode_endpoint(
302
- req.endpoint, req.dst_port, req.room
303
- )
304
- continue
305
-
306
- if kv_chunk.is_last:
307
- # Only the last chunk we need to send the aux data
308
- ret = self.send_aux(
309
- req.mooncake_session_id,
310
- kv_chunk.prefill_aux_index,
311
- self.decode_kv_args_table[
312
- req.mooncake_session_id
313
- ].dst_aux_ptrs,
314
- req.dst_aux_index,
315
- )
316
- self.request_status[req.room] = (
317
- KVPoll.Success if ret == 0 else KVPoll.Failed
318
- )
319
- self.sync_status_to_decode_endpoint(
320
- req.endpoint, req.dst_port, req.room
321
- )
322
- self.transfer_infos.pop(req.room)
309
+ reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
310
+ polls = []
311
+ dst_ranks_infos = []
312
+ for req in reqs_to_be_processed:
313
+ if not req.is_dummy:
314
+ chunked_dst_kv_indice = req.dst_kv_indices[
315
+ kv_chunk.index_slice
316
+ ]
317
+ assert len(chunked_dst_kv_indice) == len(
318
+ kv_chunk.prefill_kv_indices
319
+ ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
320
+
321
+ ret = self.send_kvcache(
322
+ req.mooncake_session_id,
323
+ kv_chunk.prefill_kv_indices,
324
+ self.decode_kv_args_table[
325
+ req.mooncake_session_id
326
+ ].dst_kv_ptrs,
327
+ chunked_dst_kv_indice,
328
+ )
329
+ if ret != 0:
330
+ self.update_status(kv_chunk.room, KVPoll.Failed)
331
+ self.sync_status_to_decode_endpoint(
332
+ req.endpoint, req.dst_port, req.room
333
+ )
334
+ continue
335
+
336
+ if kv_chunk.is_last:
337
+ # Only the last chunk we need to send the aux data
338
+ ret = self.send_aux(
339
+ req.mooncake_session_id,
340
+ kv_chunk.prefill_aux_index,
341
+ self.decode_kv_args_table[
342
+ req.mooncake_session_id
343
+ ].dst_aux_ptrs,
344
+ req.dst_aux_index,
345
+ )
346
+ polls.append(True if ret == 0 else False)
347
+ dst_ranks_infos.append(
348
+ (req.endpoint, req.dst_port, req.room)
349
+ )
350
+
351
+ # Only sync status when all the dst ranks have received the kvcache
352
+ if len(polls) == req.required_dst_info_num:
353
+ self.update_status(
354
+ req.room,
355
+ KVPoll.Success if all(polls) else KVPoll.Failed,
356
+ )
357
+ for endpoint, dst_port, room in dst_ranks_infos:
358
+ self.sync_status_to_decode_endpoint(
359
+ endpoint, dst_port, room
360
+ )
361
+ else:
362
+ # Dummy request means the decode instance is not used, so its status can be marked as success directly
363
+ # Dummy request does not need to sync status to decode endpoint
364
+ if kv_chunk.is_last:
365
+ self.update_status(req.room, KVPoll.Success)
366
+
367
+ if self.check_status(kv_chunk.room) == KVPoll.Success:
368
+ self.transfer_infos.pop(kv_chunk.room)
323
369
 
324
370
  except queue.Empty:
325
371
  continue
@@ -336,7 +382,7 @@ class MooncakeKVManager(BaseKVManager):
336
382
  (bootstrap_room, status) = self.server_socket.recv_multipart()
337
383
  status = int(status.decode("ascii"))
338
384
  bootstrap_room = int(bootstrap_room.decode("ascii"))
339
- self.request_status[bootstrap_room] = status
385
+ self.update_status(bootstrap_room, status)
340
386
 
341
387
  threading.Thread(target=decode_thread).start()
342
388
 
@@ -360,11 +406,9 @@ class MooncakeKVManager(BaseKVManager):
360
406
  prefill_aux_index=aux_index,
361
407
  )
362
408
  )
363
- self.request_status[bootstrap_room] = KVPoll.WaitingForInput
409
+ self.update_status(bootstrap_room, KVPoll.WaitingForInput)
364
410
 
365
411
  def check_status(self, bootstrap_room: int):
366
- # TOOD: do we really need the poll()?
367
-
368
412
  return self.request_status[bootstrap_room]
369
413
 
370
414
  def update_status(self, bootstrap_room: int, status: KVPoll):
@@ -469,54 +513,111 @@ class MooncakeKVReceiver(BaseKVReceiver):
469
513
  self.session_id = self.kv_mgr.get_session_id()
470
514
  self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
471
515
 
472
- if not self.kv_mgr.enable_dp_attention:
473
- # We assume dp_attention should be activated simultaneously for
474
- # both prefill role and decode role. If the decode instance does
475
- # not enable dp_attention, then dp_attention is not enabled on the
476
- # prefill instance as well. Therefore, we should skip questioning
477
- # the prefill dp size to reduce bootstrap overhead.
478
- self.prefill_dp_size = 1
479
- elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
480
- self.prefill_dp_size, tp_size_per_dp_rank = (
516
+ if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
517
+ self.prefill_tp_size, self.prefill_dp_size = (
481
518
  self._get_prefill_dp_size_from_server()
482
519
  )
483
- # Currently, we don't allow prefill instance and decode instance to
484
- # have different TP sizes per DP rank.
485
- assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
486
- if self.prefill_dp_size is None:
520
+ if self.prefill_tp_size is None or self.prefill_dp_size is None:
487
521
  logger.error(
488
- f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
522
+ f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
489
523
  )
490
524
  else:
525
+ self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
526
+ self.prefill_tp_size
527
+ )
491
528
  self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
492
529
  self.prefill_dp_size
493
530
  )
494
531
  else:
532
+ self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
533
+ self.bootstrap_addr
534
+ ]
495
535
  self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
496
536
  self.bootstrap_addr
497
537
  ]
498
538
 
499
- # NOTE: key distinguished by bootstrap_addr and engine_rank
539
+ # Currently, we don't allow prefill instance and decode instance to
540
+ # have different TP sizes per DP rank, except for models using MLA.
541
+ local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
542
+ prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
543
+ if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
544
+ self.target_tp_rank = (
545
+ self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
546
+ )
547
+ self.required_dst_info_num = 1
548
+ self.target_tp_ranks = [self.target_tp_rank]
549
+ elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
550
+ assert (
551
+ self.kv_mgr.is_mla_backend
552
+ ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
553
+ self.target_tp_rank = (
554
+ self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
555
+ ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
556
+ self.required_dst_info_num = (
557
+ local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
558
+ )
559
+ self.target_tp_ranks = [self.target_tp_rank]
560
+ else:
561
+ assert (
562
+ self.kv_mgr.is_mla_backend
563
+ ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
564
+
565
+ # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
566
+ self.target_tp_ranks = [
567
+ rank
568
+ for rank in range(
569
+ (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
570
+ * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
571
+ (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
572
+ * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
573
+ )
574
+ ]
575
+
576
+ # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
577
+ # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
578
+ # or the KVPoll will never be set correctly
579
+ self.target_tp_rank = self.target_tp_ranks[0]
580
+ self.required_dst_info_num = 1
581
+
500
582
  self.target_dp_group = bootstrap_room % self.prefill_dp_size
501
- bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
583
+
584
+ # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
585
+ bootstrap_key = (
586
+ f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
587
+ )
502
588
 
503
589
  if bootstrap_key not in self.kv_mgr.connection_pool:
504
- self.bootstrap_info = self._get_bootstrap_info_from_server(
505
- self.kv_mgr.kv_args.engine_rank,
506
- self.target_dp_group,
507
- )
508
- if self.bootstrap_info is None:
590
+ bootstrap_infos = []
591
+ for target_tp_rank in self.target_tp_ranks:
592
+ bootstrap_info = self._get_bootstrap_info_from_server(
593
+ target_tp_rank,
594
+ self.target_dp_group,
595
+ )
596
+ if bootstrap_info is not None:
597
+ # NOTE: only support MLA for now: select one prefill rank as real rank
598
+ bootstrap_info["is_dummy"] = not bool(
599
+ target_tp_rank == self.target_tp_rank
600
+ or self.target_tp_rank is None
601
+ )
602
+ bootstrap_infos.append(bootstrap_info)
603
+ else:
604
+ logger.error(
605
+ f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
606
+ )
607
+ self.bootstrap_infos = bootstrap_infos
608
+
609
+ if len(self.bootstrap_infos) == 0:
509
610
  logger.error(
510
611
  f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
511
612
  )
512
613
  else:
513
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
614
+ self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
514
615
  # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
515
616
  self._register_kv_args()
516
617
  else:
517
- self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
618
+ self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
518
619
 
519
- assert self.bootstrap_info is not None
620
+ assert len(self.bootstrap_infos) > 0
520
621
  self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
521
622
 
522
623
  def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
@@ -543,8 +644,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
543
644
  response = requests.get(url)
544
645
  if response.status_code == 200:
545
646
  prefill_parallel_info = response.json()
546
- return int(prefill_parallel_info["prefill_dp_size"]), int(
547
- prefill_parallel_info["tp_size_per_dp_rank"]
647
+ return int(prefill_parallel_info["prefill_tp_size"]), int(
648
+ prefill_parallel_info["prefill_dp_size"]
548
649
  )
549
650
  else:
550
651
  logger.error(
@@ -556,29 +657,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
556
657
  return None
557
658
 
558
659
  def _register_kv_args(self):
559
- self.prefill_server_url = (
560
- f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
561
- )
562
-
563
- packed_kv_data_ptrs = b"".join(
564
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
565
- )
566
- packed_aux_data_ptrs = b"".join(
567
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
568
- )
569
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
570
- with lock:
571
- sock.send_multipart(
572
- [
573
- "None".encode("ascii"),
574
- get_local_ip_by_remote().encode("ascii"),
575
- str(self.kv_mgr.rank_port).encode("ascii"),
576
- self.session_id.encode("ascii"),
577
- packed_kv_data_ptrs,
578
- packed_aux_data_ptrs,
579
- ]
660
+ for bootstrap_info in self.bootstrap_infos:
661
+ self.prefill_server_url = (
662
+ f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
663
+ )
664
+ packed_kv_data_ptrs = b"".join(
665
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
666
+ )
667
+ packed_aux_data_ptrs = b"".join(
668
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
580
669
  )
581
670
 
671
+ sock, lock = self._connect("tcp://" + self.prefill_server_url)
672
+ with lock:
673
+ sock.send_multipart(
674
+ [
675
+ "None".encode("ascii"),
676
+ get_local_ip_by_remote().encode("ascii"),
677
+ str(self.kv_mgr.rank_port).encode("ascii"),
678
+ self.session_id.encode("ascii"),
679
+ packed_kv_data_ptrs,
680
+ packed_aux_data_ptrs,
681
+ ]
682
+ )
683
+
582
684
  @classmethod
583
685
  def _connect(cls, endpoint: str):
584
686
  with cls._global_lock:
@@ -590,25 +692,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
590
692
  return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
591
693
 
592
694
  def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
593
- self.prefill_server_url = (
594
- f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
595
- )
596
- logger.debug(
597
- f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
598
- )
599
-
600
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
601
- with lock:
602
- sock.send_multipart(
603
- [
604
- str(self.bootstrap_room).encode("ascii"),
605
- get_local_ip_by_remote().encode("ascii"),
606
- str(self.kv_mgr.rank_port).encode("ascii"),
607
- self.session_id.encode("ascii"),
608
- kv_indices.tobytes(),
609
- str(aux_index).encode("ascii"),
610
- ]
695
+ for bootstrap_info in self.bootstrap_infos:
696
+ self.prefill_server_url = (
697
+ f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
611
698
  )
699
+ logger.debug(
700
+ f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
701
+ )
702
+ is_dummy = bootstrap_info["is_dummy"]
703
+
704
+ sock, lock = self._connect("tcp://" + self.prefill_server_url)
705
+ with lock:
706
+ sock.send_multipart(
707
+ [
708
+ str(self.bootstrap_room).encode("ascii"),
709
+ get_local_ip_by_remote().encode("ascii"),
710
+ str(self.kv_mgr.rank_port).encode("ascii"),
711
+ self.session_id.encode("ascii"),
712
+ kv_indices.tobytes() if not is_dummy else b"",
713
+ str(aux_index).encode("ascii") if not is_dummy else b"",
714
+ str(self.required_dst_info_num).encode("ascii"),
715
+ ]
716
+ )
612
717
 
613
718
  def poll(self) -> KVPoll:
614
719
  return self.kv_mgr.check_status(self.bootstrap_room)
@@ -624,6 +729,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
624
729
  self.store = dict()
625
730
  self.lock = asyncio.Lock()
626
731
  self._setup_routes()
732
+ self.tp_size = None
627
733
  self.dp_size = None
628
734
  self.tp_size_per_dp_rank = None
629
735
  self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
@@ -658,6 +764,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
658
764
  rank_port = int(data["rank_port"])
659
765
  engine_rank = int(data["engine_rank"])
660
766
 
767
+ if self.tp_size is None:
768
+ self.tp_size = tp_size
769
+
661
770
  if self.dp_size is None:
662
771
  self.dp_size = dp_size
663
772
 
@@ -693,17 +802,15 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
693
802
  # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
694
803
  if int(engine_rank) == -1 and int(target_dp_group) == -1:
695
804
  prefill_parallel_info = {
805
+ "prefill_tp_size": self.tp_size,
696
806
  "prefill_dp_size": self.dp_size,
697
- "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
698
807
  }
699
808
  return web.json_response(prefill_parallel_info, status=200)
700
809
 
701
810
  # Find corresponding prefill info
702
- tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
703
-
704
811
  async with self.lock:
705
812
  bootstrap_info = self.prefill_port_table[int(target_dp_group)][
706
- tp_rank_in_dp_group
813
+ int(engine_rank)
707
814
  ]
708
815
 
709
816
  if bootstrap_info is not None:
@@ -132,6 +132,7 @@ class NixlKVManager(BaseKVManager):
132
132
  args: KVArgs,
133
133
  disaggregation_mode: DisaggregationMode,
134
134
  server_args: ServerArgs,
135
+ is_mla_backend: Optional[bool] = False,
135
136
  ):
136
137
  try:
137
138
  from nixl._api import nixl_agent
@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.utils import (
34
34
  ReqToMetadataIdxAllocator,
35
35
  TransferBackend,
36
36
  get_kv_class,
37
+ is_mla_backend,
37
38
  kv_to_page_indices,
38
39
  kv_to_page_num,
39
40
  poll_and_all_reduce,
@@ -69,6 +70,7 @@ class PrefillBootstrapQueue:
69
70
  scheduler: Scheduler,
70
71
  ):
71
72
  self.token_to_kv_pool = token_to_kv_pool
73
+ self.is_mla_backend = is_mla_backend(token_to_kv_pool)
72
74
  self.aux_dtype = aux_dtype
73
75
 
74
76
  self.metadata_buffers = metadata_buffers
@@ -112,7 +114,10 @@ class PrefillBootstrapQueue:
112
114
  kv_args.gpu_id = self.scheduler.gpu_id
113
115
  kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
114
116
  kv_manager = kv_manager_class(
115
- kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args
117
+ kv_args,
118
+ DisaggregationMode.PREFILL,
119
+ self.scheduler.server_args,
120
+ self.is_mla_backend,
116
121
  )
117
122
  return kv_manager
118
123
 
@@ -277,19 +282,17 @@ class SchedulerDisaggregationPrefillMixin:
277
282
  next_token_ids,
278
283
  extend_input_len_per_req,
279
284
  extend_logprob_start_len_per_req,
280
- bid,
281
285
  ) = (
282
286
  result.logits_output,
283
287
  result.next_token_ids,
284
288
  result.extend_input_len_per_req,
285
289
  result.extend_logprob_start_len_per_req,
286
- result.bid,
287
290
  )
288
291
 
289
292
  # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
290
293
  if self.enable_overlap:
291
294
  # wait
292
- _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
295
+ _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
293
296
  else:
294
297
  next_token_ids = result.next_token_ids.tolist()
295
298
 
@@ -112,7 +112,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
112
112
 
113
113
 
114
114
  def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
115
- # 1. The page is guaruanteed to be full except the last page.
115
+ # 1. The page is guaranteed to be full except the last page.
116
116
  # 2. page index = kv_index // page_size
117
117
  # The return vector is kv_indices[::page_size] // page_size
118
118
  if page_size == 1: # shortcut
@@ -162,3 +162,9 @@ def register_disaggregation_server(
162
162
  warnings.warn(
163
163
  f"Failed to register disaggregation server: {res.status_code} {res.text}"
164
164
  )
165
+
166
+
167
+ def is_mla_backend(target_kv_pool) -> bool:
168
+ from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
169
+
170
+ return isinstance(target_kv_pool, MLATokenToKVPool)
@@ -285,6 +285,21 @@ class Engine(EngineBase):
285
285
  ret = loop.run_until_complete(generator.__anext__())
286
286
  return ret
287
287
 
288
+ async def async_encode(
289
+ self,
290
+ prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
291
+ image_data: Optional[Union[List[str], str]] = None,
292
+ ) -> Dict:
293
+ """
294
+ Asynchronous version of encode method.
295
+
296
+ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
297
+ Please refer to `EmbeddingReqInput` for the documentation.
298
+ """
299
+ obj = EmbeddingReqInput(text=prompt, image_data=image_data)
300
+ generator = self.tokenizer_manager.generate_request(obj, None)
301
+ return await generator.__anext__()
302
+
288
303
  def shutdown(self):
289
304
  """Shutdown the engine"""
290
305
  kill_process_tree(os.getpid(), include_parent=False)
@@ -315,7 +330,7 @@ class Engine(EngineBase):
315
330
  return {
316
331
  **dataclasses.asdict(self.tokenizer_manager.server_args),
317
332
  **self.scheduler_info,
318
- **internal_states,
333
+ "internal_states": internal_states,
319
334
  "version": __version__,
320
335
  }
321
336
 
@@ -471,7 +486,7 @@ def _set_envs_and_config(server_args: ServerArgs):
471
486
  if _is_cuda:
472
487
  assert_pkg_version(
473
488
  "sgl-kernel",
474
- "0.1.1",
489
+ "0.1.2.post1",
475
490
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
476
491
  )
477
492