sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py
@@ -37,25 +37,16 @@ logger = logging.getLogger(__name__)
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    src_groups = []
-    dst_groups = []
-    current_src = [src_indices[0]]
-    current_dst = [dst_indices[0]]
-
-    for i in range(1, len(src_indices)):
-        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
-        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
-        if src_contiguous and dst_contiguous:
-            current_src.append(src_indices[i])
-            current_dst.append(dst_indices[i])
-        else:
-            src_groups.append(current_src)
-            dst_groups.append(current_dst)
-            current_src = [src_indices[i]]
-            current_dst = [dst_indices[i]]
+    """Vectorised NumPy implementation."""
+    if src_indices.size == 0:
+        return [], []
+
+    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
+    src_groups = np.split(src_indices, brk)
+    dst_groups = np.split(dst_indices, brk)
 
-    src_groups.append(current_src)
-    dst_groups.append(current_dst)
+    src_groups = [g.tolist() for g in src_groups]
+    dst_groups = [g.tolist() for g in dst_groups]
 
     return src_groups, dst_groups
 
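The rewrite above replaces the element-by-element Python loop with `np.diff`/`np.split`: a break is recorded wherever either index sequence stops increasing by exactly 1, and both arrays are split at the same break points. A minimal standalone sketch of the trick, with illustrative arrays (not sglang code):

```python
import numpy as np

src = np.array([3, 4, 5, 9, 10, 20], dtype=np.int64)
dst = np.array([0, 1, 2, 7, 8, 9], dtype=np.int64)

# A break occurs wherever either sequence stops increasing by exactly 1.
brk = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1
print([g.tolist() for g in np.split(src, brk)])  # [[3, 4, 5], [9, 10], [20]]
print([g.tolist() for g in np.split(dst, brk)])  # [[0, 1, 2], [7, 8], [9]]
```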
@@ -77,16 +68,28 @@ class TransferInfo:
     mooncake_session_id: str
     dst_kv_indices: npt.NDArray[np.int64]
     dst_aux_index: int
+    required_dst_info_num: int
+    is_dummy: bool
 
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
+        if msg[4] == b"" and msg[5] == b"":
+            is_dummy = True
+            dst_kv_indices = np.array([], dtype=np.int64)
+            dst_aux_index = None
+        else:
+            dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
+            dst_aux_index = int(msg[5].decode("ascii"))
+            is_dummy = False
         return cls(
             room=int(msg[0].decode("ascii")),
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
             mooncake_session_id=msg[3].decode("ascii"),
-            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
-            dst_aux_index=int(msg[5].decode("ascii")),
+            dst_kv_indices=dst_kv_indices,
+            dst_aux_index=dst_aux_index,
+            required_dst_info_num=int(msg[6].decode("ascii")),
+            is_dummy=is_dummy,
         )
 
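For orientation, `from_zmq` now decodes a seven-frame ZMQ multipart message; frames 4 and 5 arrive empty for a dummy request (see the MLA handling in the receiver changes below). A hypothetical frame layout, with made-up values:

```python
import numpy as np

kv_indices = np.array([128, 129, 130], dtype=np.int64)
msg = [
    b"42",                 # msg[0]: bootstrap room id
    b"10.0.0.5",           # msg[1]: decode endpoint IP
    b"5757",               # msg[2]: decode ZMQ port
    b"session-abc",        # msg[3]: mooncake session id
    kv_indices.tobytes(),  # msg[4]: dst KV indices (b"" for a dummy request)
    b"7",                  # msg[5]: dst aux index (b"" for a dummy request)
    b"1",                  # msg[6]: required_dst_info_num (new in this release)
]
```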
@@ -117,6 +120,7 @@ class MooncakeKVManager(BaseKVManager):
         args: KVArgs,
         disaggregation_mode: DisaggregationMode,
         server_args: ServerArgs,
+        is_mla_backend: Optional[bool] = False,
     ):
         self.kv_args = args
         self.engine = MooncakeTransferEngine(
@@ -124,6 +128,7 @@ class MooncakeKVManager(BaseKVManager):
             gpu_id=self.kv_args.gpu_id,
             ib_device=self.kv_args.ib_device,
         )
+        self.is_mla_backend = is_mla_backend
         self.disaggregation_mode = disaggregation_mode
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
@@ -141,7 +146,7 @@ class MooncakeKVManager(BaseKVManager):
         self.register_buffer_to_engine()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.transfer_queue = queue.Queue()
-            self.transfer_infos: Dict[int, TransferInfo] = {}
+            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
             self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self.start_prefill_thread()
             self._register_to_bootstrap()
@@ -154,6 +159,7 @@ class MooncakeKVManager(BaseKVManager):
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
+            self.prefill_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
         else:
             raise ValueError(
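`is_mla_backend` is a new constructor keyword: callers tell the manager whether the model uses an MLA attention backend, which is what later permits mismatched TP sizes between prefill and decode. A call sketch, assuming `kv_args` and `server_args` objects are built elsewhere:

```python
# Sketch only: kv_args and server_args are assumed to exist already.
manager = MooncakeKVManager(
    kv_args,
    DisaggregationMode.PREFILL,
    server_args,
    is_mla_backend=True,  # new keyword; defaults to False
)
```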
@@ -227,7 +233,7 @@ class MooncakeKVManager(BaseKVManager):
             status = future.result()
             if status != 0:
                 # Immediate shutdown on first error (existing tasks will finish)
-                executor.shutdown(wait=False)
+                self.executor.shutdown(wait=False)
                 for f in futures:
                     f.cancel()
                 return status
@@ -259,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
         self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
             [
                 str(room).encode("ascii"),
-                str(self.request_status[room]).encode("ascii"),
+                str(self.check_status(room)).encode("ascii"),
             ]
         )
 
@@ -273,8 +279,8 @@ class MooncakeKVManager(BaseKVManager):
             while True:
                 waiting_req_bytes = self.server_socket.recv_multipart()
                 room = waiting_req_bytes[0].decode("ascii")
+                mooncake_session_id = waiting_req_bytes[3].decode("ascii")
                 if room == "None":
-                    mooncake_session_id = waiting_req_bytes[3].decode("ascii")
                     self.decode_kv_args_table[mooncake_session_id] = (
                         KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
                     )
@@ -282,53 +288,84 @@ class MooncakeKVManager(BaseKVManager):
                         f"Register KVArgs from {mooncake_session_id} successfully"
                     )
                     continue
-                room = int(room)
-                self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
-
-                # NOTE: after bootstrapping we can mark the req as waiting for input
-                self.request_status[room] = KVPoll.WaitingForInput
+                else:
+                    required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
+                    room = int(room)
+                    if room not in self.transfer_infos:
+                        self.transfer_infos[room] = {}
+
+                    self.transfer_infos[room][mooncake_session_id] = (
+                        TransferInfo.from_zmq(waiting_req_bytes)
+                    )
+                    # NOTE: after bootstrapping we can mark the req as waiting for input
+                    if len(self.transfer_infos[room]) == required_dst_info_num:
+                        self.update_status(room, KVPoll.WaitingForInput)
 
         def transfer_thread():
             # TODO: Shall we use KVPoll.Transferring state?
             while True:
                 try:
                     kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
-                    req = self.transfer_infos[kv_chunk.room]
-                    chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
-                    assert len(chunked_dst_kv_indice) == len(
-                        kv_chunk.prefill_kv_indices
-                    ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
-
-                    ret = self.send_kvcache(
-                        req.mooncake_session_id,
-                        kv_chunk.prefill_kv_indices,
-                        self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
-                        chunked_dst_kv_indice,
-                    )
-                    if ret != 0:
-                        self.request_status[kv_chunk.room] = KVPoll.Failed
-                        self.sync_status_to_decode_endpoint(
-                            req.endpoint, req.dst_port, req.room
-                        )
-                        continue
-
-                    if kv_chunk.is_last:
-                        # Only the last chunk we need to send the aux data
-                        ret = self.send_aux(
-                            req.mooncake_session_id,
-                            kv_chunk.prefill_aux_index,
-                            self.decode_kv_args_table[
-                                req.mooncake_session_id
-                            ].dst_aux_ptrs,
-                            req.dst_aux_index,
-                        )
-                        self.request_status[req.room] = (
-                            KVPoll.Success if ret == 0 else KVPoll.Failed
-                        )
-                        self.sync_status_to_decode_endpoint(
-                            req.endpoint, req.dst_port, req.room
-                        )
-                        self.transfer_infos.pop(req.room)
+                    reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
+                    polls = []
+                    dst_ranks_infos = []
+                    for req in reqs_to_be_processed:
+                        if not req.is_dummy:
+                            chunked_dst_kv_indice = req.dst_kv_indices[
+                                kv_chunk.index_slice
+                            ]
+                            assert len(chunked_dst_kv_indice) == len(
+                                kv_chunk.prefill_kv_indices
+                            ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
+
+                            ret = self.send_kvcache(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_kv_indices,
+                                self.decode_kv_args_table[
+                                    req.mooncake_session_id
+                                ].dst_kv_ptrs,
+                                chunked_dst_kv_indice,
+                            )
+                            if ret != 0:
+                                self.update_status(kv_chunk.room, KVPoll.Failed)
+                                self.sync_status_to_decode_endpoint(
+                                    req.endpoint, req.dst_port, req.room
+                                )
+                                continue
+
+                            if kv_chunk.is_last:
+                                # Only the last chunk we need to send the aux data
+                                ret = self.send_aux(
+                                    req.mooncake_session_id,
+                                    kv_chunk.prefill_aux_index,
+                                    self.decode_kv_args_table[
+                                        req.mooncake_session_id
+                                    ].dst_aux_ptrs,
+                                    req.dst_aux_index,
+                                )
+                                polls.append(True if ret == 0 else False)
+                                dst_ranks_infos.append(
+                                    (req.endpoint, req.dst_port, req.room)
+                                )
+
+                                # Only sync status when all the dst ranks have received the kvcache
+                                if len(polls) == req.required_dst_info_num:
+                                    self.update_status(
+                                        req.room,
+                                        KVPoll.Success if all(polls) else KVPoll.Failed,
+                                    )
+                                    for endpoint, dst_port, room in dst_ranks_infos:
+                                        self.sync_status_to_decode_endpoint(
+                                            endpoint, dst_port, room
+                                        )
+                        else:
+                            # Dummy request means the decode instance is not used, so its status can be marked as success directly
+                            # Dummy request does not need to sync status to decode endpoint
+                            if kv_chunk.is_last:
+                                self.update_status(req.room, KVPoll.Success)
+
+                    if self.check_status(kv_chunk.room) == KVPoll.Success:
+                        self.transfer_infos.pop(kv_chunk.room)
 
                 except queue.Empty:
                     continue
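Note how `transfer_thread` now defers the verdict: the request status flips only once every required destination rank has reported, and it is Success only if all of them succeeded. A stripped-down sketch of that gating pattern (hypothetical standalone function, not the sglang API):

```python
from typing import List, Optional

def finalize(polls: List[bool], required_dst_info_num: int) -> Optional[str]:
    """Decide a request's status only after all destination ranks report."""
    if len(polls) < required_dst_info_num:
        return None  # still waiting for more ranks
    return "Success" if all(polls) else "Failed"

assert finalize([True], 2) is None
assert finalize([True, True], 2) == "Success"
assert finalize([True, False], 2) == "Failed"
```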
@@ -345,7 +382,7 @@ class MooncakeKVManager(BaseKVManager):
                 (bootstrap_room, status) = self.server_socket.recv_multipart()
                 status = int(status.decode("ascii"))
                 bootstrap_room = int(bootstrap_room.decode("ascii"))
-                self.request_status[bootstrap_room] = status
+                self.update_status(bootstrap_room, status)
 
         threading.Thread(target=decode_thread).start()
 
@@ -369,11 +406,9 @@ class MooncakeKVManager(BaseKVManager):
                 prefill_aux_index=aux_index,
             )
         )
-        self.request_status[bootstrap_room] = KVPoll.WaitingForInput
+        self.update_status(bootstrap_room, KVPoll.WaitingForInput)
 
     def check_status(self, bootstrap_room: int):
-        # TOOD: do we really need the poll()?
-
         return self.request_status[bootstrap_room]
 
     def update_status(self, bootstrap_room: int, status: KVPoll):
@@ -478,54 +513,111 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
 
-        if not self.kv_mgr.enable_dp_attention:
-            # We assume dp_attention should be activated simultaneously for
-            # both prefill role and decode role. If the decode instance does
-            # not enable dp_attention, then dp_attention is not enabled on the
-            # prefill instance as well. Therefore, we should skip questioning
-            # the prefill dp size to reduce bootstrap overhead.
-            self.prefill_dp_size = 1
-        elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
-            self.prefill_dp_size, tp_size_per_dp_rank = (
+        if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
+            self.prefill_tp_size, self.prefill_dp_size = (
                 self._get_prefill_dp_size_from_server()
             )
-            # Currently, we don't allow prefill instance and decode instance to
-            # have different TP sizes per DP rank.
-            assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
-            if self.prefill_dp_size is None:
+            if self.prefill_tp_size is None or self.prefill_dp_size is None:
                 logger.error(
-                    f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
+                    f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
                 )
             else:
+                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
+                    self.prefill_tp_size
+                )
                 self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                     self.prefill_dp_size
                 )
         else:
+            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
+                self.bootstrap_addr
+            ]
             self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                 self.bootstrap_addr
             ]
 
-        # NOTE: key distinguished by bootstrap_addr and engine_rank
+        # Currently, we don't allow prefill instance and decode instance to
+        # have different TP sizes per DP rank, except for models using MLA.
+        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
+        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
+        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
+            self.target_tp_rank = (
+                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+            )
+            self.required_dst_info_num = 1
+            self.target_tp_ranks = [self.target_tp_rank]
+        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
+            assert (
+                self.kv_mgr.is_mla_backend
+            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
+            self.target_tp_rank = (
+                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
+            self.required_dst_info_num = (
+                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
+            )
+            self.target_tp_ranks = [self.target_tp_rank]
+        else:
+            assert (
+                self.kv_mgr.is_mla_backend
+            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
+
+            # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks;
+            self.target_tp_ranks = [
+                rank
+                for rank in range(
+                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
+                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
+                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                )
+            ]
+
+            # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
+            # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
+            # or the KVPoll will never be set correctly
+            self.target_tp_rank = self.target_tp_ranks[0]
+            self.required_dst_info_num = 1
+
         self.target_dp_group = bootstrap_room % self.prefill_dp_size
-        bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
+
+        # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
+        bootstrap_key = (
+            f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
+        )
 
         if bootstrap_key not in self.kv_mgr.connection_pool:
-            self.bootstrap_info = self._get_bootstrap_info_from_server(
-                self.kv_mgr.kv_args.engine_rank,
-                self.target_dp_group,
-            )
-            if self.bootstrap_info is None:
+            bootstrap_infos = []
+            for target_tp_rank in self.target_tp_ranks:
+                bootstrap_info = self._get_bootstrap_info_from_server(
+                    target_tp_rank,
+                    self.target_dp_group,
+                )
+                if bootstrap_info is not None:
+                    # NOTE: only support MLA for now: select one prefill rank as real rank
+                    bootstrap_info["is_dummy"] = not bool(
+                        target_tp_rank == self.target_tp_rank
+                        or self.target_tp_rank is None
+                    )
+                    bootstrap_infos.append(bootstrap_info)
+                else:
+                    logger.error(
+                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
+                    )
+            self.bootstrap_infos = bootstrap_infos
+
+            if len(self.bootstrap_infos) == 0:
                 logger.error(
                     f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                 )
             else:
-                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
+                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
                 # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
                 self._register_kv_args()
         else:
-            self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
+            self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
 
-        assert self.bootstrap_info is not None
+        assert len(self.bootstrap_infos) > 0
         self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
 
     def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
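The branch arithmetic above is easiest to see with concrete numbers. A worked example, assuming an MLA model, a decode instance with `tp_size=8, dp_size=2`, and a prefill instance with `tp_size=4, dp_size=2` (all values illustrative):

```python
local_tp_size_per_dp_rank = 8 // 2    # decode: 4 TP ranks per DP rank
prefill_tp_size_per_dp_rank = 4 // 2  # prefill: 2 TP ranks per DP rank

engine_rank = 6  # a decode rank in the second DP group
ratio = local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank  # 2

# Two decode ranks map onto each prefill rank, so rank 6 targets prefill
# TP rank (6 % 4) // 2 == 1, and that prefill rank waits for `ratio` (= 2)
# decode ranks before marking the request done.
target_tp_rank = (engine_rank % local_tp_size_per_dp_rank) // ratio
required_dst_info_num = ratio
assert (target_tp_rank, required_dst_info_num) == (1, 2)
```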
@@ -552,8 +644,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
             response = requests.get(url)
             if response.status_code == 200:
                 prefill_parallel_info = response.json()
-                return int(prefill_parallel_info["prefill_dp_size"]), int(
-                    prefill_parallel_info["tp_size_per_dp_rank"]
+                return int(prefill_parallel_info["prefill_tp_size"]), int(
+                    prefill_parallel_info["prefill_dp_size"]
                 )
             else:
                 logger.error(
@@ -565,29 +657,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
         return None
 
     def _register_kv_args(self):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
-
-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
-        sock, lock = self._connect("tcp://" + self.prefill_server_url)
-        with lock:
-            sock.send_multipart(
-                [
-                    "None".encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
-                    str(self.kv_mgr.rank_port).encode("ascii"),
-                    self.session_id.encode("ascii"),
-                    packed_kv_data_ptrs,
-                    packed_aux_data_ptrs,
-                ]
-            )
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
+            )
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        "None".encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.session_id.encode("ascii"),
+                        packed_kv_data_ptrs,
+                        packed_aux_data_ptrs,
+                    ]
+                )
 
     @classmethod
     def _connect(cls, endpoint: str):
         with cls._global_lock:
@@ -599,25 +692,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
             return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
 
     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
-        logger.debug(
-            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
-        )
-
-        sock, lock = self._connect("tcp://" + self.prefill_server_url)
-        with lock:
-            sock.send_multipart(
-                [
-                    str(self.bootstrap_room).encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
-                    str(self.kv_mgr.rank_port).encode("ascii"),
-                    self.session_id.encode("ascii"),
-                    kv_indices.tobytes(),
-                    str(aux_index).encode("ascii"),
-                ]
-            )
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
+            )
+            logger.debug(
+                f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
+            )
+            is_dummy = bootstrap_info["is_dummy"]
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        str(self.bootstrap_room).encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.session_id.encode("ascii"),
+                        kv_indices.tobytes() if not is_dummy else b"",
+                        str(aux_index).encode("ascii") if not is_dummy else b"",
+                        str(self.required_dst_info_num).encode("ascii"),
+                    ]
+                )
 
     def poll(self) -> KVPoll:
         return self.kv_mgr.check_status(self.bootstrap_room)
@@ -633,6 +729,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
+        self.tp_size = None
         self.dp_size = None
         self.tp_size_per_dp_rank = None
         self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
@@ -667,6 +764,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         rank_port = int(data["rank_port"])
         engine_rank = int(data["engine_rank"])
 
+        if self.tp_size is None:
+            self.tp_size = tp_size
+
         if self.dp_size is None:
             self.dp_size = dp_size
 
@@ -702,17 +802,15 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
         if int(engine_rank) == -1 and int(target_dp_group) == -1:
             prefill_parallel_info = {
+                "prefill_tp_size": self.tp_size,
                 "prefill_dp_size": self.dp_size,
-                "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
             }
             return web.json_response(prefill_parallel_info, status=200)
 
         # Find corresponding prefill info
-        tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
-
         async with self.lock:
             bootstrap_info = self.prefill_port_table[int(target_dp_group)][
-                tp_rank_in_dp_group
+                int(engine_rank)
             ]
 
         if bootstrap_info is not None:
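On the wire, the sentinel query `engine_rank == -1, target_dp_group == -1` now returns both parallel sizes instead of the per-DP-rank TP size. An illustrative client call (host, port, and route path are assumptions, not taken from the diff):

```python
import requests

resp = requests.get(
    "http://prefill-host:8998/route",  # hypothetical bootstrap-server endpoint
    params={"engine_rank": -1, "target_dp_group": -1},
)
print(resp.json())  # e.g. {"prefill_tp_size": 4, "prefill_dp_size": 2}
```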