sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -51,6 +51,7 @@ def group_concurrent_contiguous(
51
51
  return src_groups, dst_groups
52
52
 
53
53
 
54
+ # prefill
54
55
  @dataclasses.dataclass
55
56
  class TransferKVChunk:
56
57
  room: int
@@ -60,6 +61,7 @@ class TransferKVChunk:
60
61
  prefill_aux_index: Optional[int]
61
62
 
62
63
 
64
+ # decode
63
65
  @dataclasses.dataclass
64
66
  class TransferInfo:
65
67
  room: int
@@ -68,19 +70,32 @@ class TransferInfo:
68
70
  mooncake_session_id: str
69
71
  dst_kv_indices: npt.NDArray[np.int64]
70
72
  dst_aux_index: int
73
+ required_dst_info_num: int
74
+ is_dummy: bool
71
75
 
72
76
  @classmethod
73
77
  def from_zmq(cls, msg: List[bytes]):
78
+ if msg[4] == b"" and msg[5] == b"":
79
+ is_dummy = True
80
+ dst_kv_indices = np.array([], dtype=np.int64)
81
+ dst_aux_index = None
82
+ else:
83
+ dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
84
+ dst_aux_index = int(msg[5].decode("ascii"))
85
+ is_dummy = False
74
86
  return cls(
75
87
  room=int(msg[0].decode("ascii")),
76
88
  endpoint=msg[1].decode("ascii"),
77
89
  dst_port=int(msg[2].decode("ascii")),
78
90
  mooncake_session_id=msg[3].decode("ascii"),
79
- dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
80
- dst_aux_index=int(msg[5].decode("ascii")),
91
+ dst_kv_indices=dst_kv_indices,
92
+ dst_aux_index=dst_aux_index,
93
+ required_dst_info_num=int(msg[6].decode("ascii")),
94
+ is_dummy=is_dummy,
81
95
  )
82
96
 
83
97
 
98
+ # decode
84
99
  @dataclasses.dataclass
85
100
  class KVArgsRegisterInfo:
86
101
  room: str
@@ -108,6 +123,7 @@ class MooncakeKVManager(BaseKVManager):
108
123
  args: KVArgs,
109
124
  disaggregation_mode: DisaggregationMode,
110
125
  server_args: ServerArgs,
126
+ is_mla_backend: Optional[bool] = False,
111
127
  ):
112
128
  self.kv_args = args
113
129
  self.engine = MooncakeTransferEngine(
@@ -115,6 +131,7 @@ class MooncakeKVManager(BaseKVManager):
115
131
  gpu_id=self.kv_args.gpu_id,
116
132
  ib_device=self.kv_args.ib_device,
117
133
  )
134
+ self.is_mla_backend = is_mla_backend
118
135
  self.disaggregation_mode = disaggregation_mode
119
136
  # for p/d multi node infer
120
137
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
@@ -132,7 +149,7 @@ class MooncakeKVManager(BaseKVManager):
132
149
  self.register_buffer_to_engine()
133
150
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
134
151
  self.transfer_queue = queue.Queue()
135
- self.transfer_infos: Dict[int, TransferInfo] = {}
152
+ self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
136
153
  self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
137
154
  self.start_prefill_thread()
138
155
  self._register_to_bootstrap()
@@ -145,6 +162,7 @@ class MooncakeKVManager(BaseKVManager):
145
162
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
146
163
  self.start_decode_thread()
147
164
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
165
+ self.prefill_tp_size_table: Dict[str, int] = {}
148
166
  self.prefill_dp_size_table: Dict[str, int] = {}
149
167
  else:
150
168
  raise ValueError(
@@ -218,7 +236,7 @@ class MooncakeKVManager(BaseKVManager):
218
236
  status = future.result()
219
237
  if status != 0:
220
238
  # Immediate shutdown on first error (existing tasks will finish)
221
- executor.shutdown(wait=False)
239
+ self.executor.shutdown(wait=False)
222
240
  for f in futures:
223
241
  f.cancel()
224
242
  return status
@@ -250,7 +268,7 @@ class MooncakeKVManager(BaseKVManager):
250
268
  self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
251
269
  [
252
270
  str(room).encode("ascii"),
253
- str(self.request_status[room]).encode("ascii"),
271
+ str(self.check_status(room)).encode("ascii"),
254
272
  ]
255
273
  )
256
274
 
@@ -264,8 +282,8 @@ class MooncakeKVManager(BaseKVManager):
264
282
  while True:
265
283
  waiting_req_bytes = self.server_socket.recv_multipart()
266
284
  room = waiting_req_bytes[0].decode("ascii")
285
+ mooncake_session_id = waiting_req_bytes[3].decode("ascii")
267
286
  if room == "None":
268
- mooncake_session_id = waiting_req_bytes[3].decode("ascii")
269
287
  self.decode_kv_args_table[mooncake_session_id] = (
270
288
  KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
271
289
  )
@@ -273,53 +291,84 @@ class MooncakeKVManager(BaseKVManager):
273
291
  f"Register KVArgs from {mooncake_session_id} successfully"
274
292
  )
275
293
  continue
276
- room = int(room)
277
- self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
278
-
279
- # NOTE: after bootstrapping we can mark the req as waiting for input
280
- self.request_status[room] = KVPoll.WaitingForInput
294
+ else:
295
+ required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
296
+ room = int(room)
297
+ if room not in self.transfer_infos:
298
+ self.transfer_infos[room] = {}
299
+
300
+ self.transfer_infos[room][mooncake_session_id] = (
301
+ TransferInfo.from_zmq(waiting_req_bytes)
302
+ )
303
+ # NOTE: after bootstrapping we can mark the req as waiting for input
304
+ if len(self.transfer_infos[room]) == required_dst_info_num:
305
+ self.update_status(room, KVPoll.WaitingForInput)
281
306
 
282
307
  def transfer_thread():
283
308
  # TODO: Shall we use KVPoll.Transferring state?
284
309
  while True:
285
310
  try:
286
311
  kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
287
- req = self.transfer_infos[kv_chunk.room]
288
- chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
289
- assert len(chunked_dst_kv_indice) == len(
290
- kv_chunk.prefill_kv_indices
291
- ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
292
-
293
- ret = self.send_kvcache(
294
- req.mooncake_session_id,
295
- kv_chunk.prefill_kv_indices,
296
- self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
297
- chunked_dst_kv_indice,
298
- )
299
- if ret != 0:
300
- self.request_status[kv_chunk.room] = KVPoll.Failed
301
- self.sync_status_to_decode_endpoint(
302
- req.endpoint, req.dst_port, req.room
303
- )
304
- continue
305
-
306
- if kv_chunk.is_last:
307
- # Only the last chunk we need to send the aux data
308
- ret = self.send_aux(
309
- req.mooncake_session_id,
310
- kv_chunk.prefill_aux_index,
311
- self.decode_kv_args_table[
312
- req.mooncake_session_id
313
- ].dst_aux_ptrs,
314
- req.dst_aux_index,
315
- )
316
- self.request_status[req.room] = (
317
- KVPoll.Success if ret == 0 else KVPoll.Failed
318
- )
319
- self.sync_status_to_decode_endpoint(
320
- req.endpoint, req.dst_port, req.room
321
- )
322
- self.transfer_infos.pop(req.room)
312
+ reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
313
+ polls = []
314
+ dst_ranks_infos = []
315
+ for req in reqs_to_be_processed:
316
+ if not req.is_dummy:
317
+ chunked_dst_kv_indice = req.dst_kv_indices[
318
+ kv_chunk.index_slice
319
+ ]
320
+ assert len(chunked_dst_kv_indice) == len(
321
+ kv_chunk.prefill_kv_indices
322
+ ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
323
+
324
+ ret = self.send_kvcache(
325
+ req.mooncake_session_id,
326
+ kv_chunk.prefill_kv_indices,
327
+ self.decode_kv_args_table[
328
+ req.mooncake_session_id
329
+ ].dst_kv_ptrs,
330
+ chunked_dst_kv_indice,
331
+ )
332
+ if ret != 0:
333
+ self.update_status(kv_chunk.room, KVPoll.Failed)
334
+ self.sync_status_to_decode_endpoint(
335
+ req.endpoint, req.dst_port, req.room
336
+ )
337
+ continue
338
+
339
+ if kv_chunk.is_last:
340
+ # Only the last chunk we need to send the aux data
341
+ ret = self.send_aux(
342
+ req.mooncake_session_id,
343
+ kv_chunk.prefill_aux_index,
344
+ self.decode_kv_args_table[
345
+ req.mooncake_session_id
346
+ ].dst_aux_ptrs,
347
+ req.dst_aux_index,
348
+ )
349
+ polls.append(True if ret == 0 else False)
350
+ dst_ranks_infos.append(
351
+ (req.endpoint, req.dst_port, req.room)
352
+ )
353
+
354
+ # Only sync status when all the dst ranks have received the kvcache
355
+ if len(polls) == req.required_dst_info_num:
356
+ self.update_status(
357
+ req.room,
358
+ KVPoll.Success if all(polls) else KVPoll.Failed,
359
+ )
360
+ for endpoint, dst_port, room in dst_ranks_infos:
361
+ self.sync_status_to_decode_endpoint(
362
+ endpoint, dst_port, room
363
+ )
364
+ else:
365
+ # Dummy request means the decode instance is not used, so its status can be marked as success directly
366
+ # Dummy request does not need to sync status to decode endpoint
367
+ if kv_chunk.is_last:
368
+ self.update_status(req.room, KVPoll.Success)
369
+
370
+ if self.check_status(kv_chunk.room) == KVPoll.Success:
371
+ self.transfer_infos.pop(kv_chunk.room)
323
372
 
324
373
  except queue.Empty:
325
374
  continue
@@ -336,7 +385,7 @@ class MooncakeKVManager(BaseKVManager):
336
385
  (bootstrap_room, status) = self.server_socket.recv_multipart()
337
386
  status = int(status.decode("ascii"))
338
387
  bootstrap_room = int(bootstrap_room.decode("ascii"))
339
- self.request_status[bootstrap_room] = status
388
+ self.update_status(bootstrap_room, status)
340
389
 
341
390
  threading.Thread(target=decode_thread).start()
342
391
 
@@ -360,11 +409,9 @@ class MooncakeKVManager(BaseKVManager):
360
409
  prefill_aux_index=aux_index,
361
410
  )
362
411
  )
363
- self.request_status[bootstrap_room] = KVPoll.WaitingForInput
412
+ self.update_status(bootstrap_room, KVPoll.WaitingForInput)
364
413
 
365
414
  def check_status(self, bootstrap_room: int):
366
- # TOOD: do we really need the poll()?
367
-
368
415
  return self.request_status[bootstrap_room]
369
416
 
370
417
  def update_status(self, bootstrap_room: int, status: KVPoll):
@@ -420,6 +467,8 @@ class MooncakeKVSender(BaseKVSender):
420
467
  self.aux_index = None
421
468
  self.bootstrap_server_url = bootstrap_addr
422
469
  self.session_id = self.kv_mgr.get_session_id()
470
+ # inner state
471
+ self.curr_idx = 0
423
472
 
424
473
  def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
425
474
  self.num_kv_indices = num_kv_indices
@@ -428,9 +477,11 @@ class MooncakeKVSender(BaseKVSender):
428
477
  def send(
429
478
  self,
430
479
  kv_indices: npt.NDArray[np.int64],
431
- index_slice: slice,
432
- is_last: bool,
433
480
  ):
481
+ index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
482
+ self.curr_idx += len(kv_indices)
483
+ is_last = self.curr_idx == self.num_kv_indices
484
+
434
485
  if not is_last:
435
486
  self.kv_mgr.add_transfer_request(
436
487
  self.bootstrap_room, kv_indices, index_slice, False
@@ -448,6 +499,7 @@ class MooncakeKVSender(BaseKVSender):
448
499
  return self.kv_mgr.check_status(self.bootstrap_room)
449
500
 
450
501
  def failure_exception(self):
502
+ # TODO: raise a real exception
451
503
  raise Exception("Fake KVSender Exception")
452
504
 
453
505
 
@@ -469,54 +521,111 @@ class MooncakeKVReceiver(BaseKVReceiver):
469
521
  self.session_id = self.kv_mgr.get_session_id()
470
522
  self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
471
523
 
472
- if not self.kv_mgr.enable_dp_attention:
473
- # We assume dp_attention should be activated simultaneously for
474
- # both prefill role and decode role. If the decode instance does
475
- # not enable dp_attention, then dp_attention is not enabled on the
476
- # prefill instance as well. Therefore, we should skip questioning
477
- # the prefill dp size to reduce bootstrap overhead.
478
- self.prefill_dp_size = 1
479
- elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
480
- self.prefill_dp_size, tp_size_per_dp_rank = (
524
+ if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
525
+ self.prefill_tp_size, self.prefill_dp_size = (
481
526
  self._get_prefill_dp_size_from_server()
482
527
  )
483
- # Currently, we don't allow prefill instance and decode instance to
484
- # have different TP sizes per DP rank.
485
- assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
486
- if self.prefill_dp_size is None:
528
+ if self.prefill_tp_size is None or self.prefill_dp_size is None:
487
529
  logger.error(
488
- f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
530
+ f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
489
531
  )
490
532
  else:
533
+ self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
534
+ self.prefill_tp_size
535
+ )
491
536
  self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
492
537
  self.prefill_dp_size
493
538
  )
494
539
  else:
540
+ self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
541
+ self.bootstrap_addr
542
+ ]
495
543
  self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
496
544
  self.bootstrap_addr
497
545
  ]
498
546
 
499
- # NOTE: key distinguished by bootstrap_addr and engine_rank
547
+ # Currently, we don't allow prefill instance and decode instance to
548
+ # have different TP sizes per DP rank, except for models using MLA.
549
+ local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
550
+ prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
551
+ if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
552
+ self.target_tp_rank = (
553
+ self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
554
+ )
555
+ self.required_dst_info_num = 1
556
+ self.target_tp_ranks = [self.target_tp_rank]
557
+ elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
558
+ assert (
559
+ self.kv_mgr.is_mla_backend
560
+ ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
561
+ self.target_tp_rank = (
562
+ self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
563
+ ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
564
+ self.required_dst_info_num = (
565
+ local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
566
+ )
567
+ self.target_tp_ranks = [self.target_tp_rank]
568
+ else:
569
+ assert (
570
+ self.kv_mgr.is_mla_backend
571
+ ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
572
+
573
+ # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
574
+ self.target_tp_ranks = [
575
+ rank
576
+ for rank in range(
577
+ (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
578
+ * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
579
+ (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
580
+ * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
581
+ )
582
+ ]
583
+
584
+ # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
585
+ # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
586
+ # or the KVPoll will never be set correctly
587
+ self.target_tp_rank = self.target_tp_ranks[0]
588
+ self.required_dst_info_num = 1
589
+
500
590
  self.target_dp_group = bootstrap_room % self.prefill_dp_size
501
- bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
591
+
592
+ # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
593
+ bootstrap_key = (
594
+ f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
595
+ )
502
596
 
503
597
  if bootstrap_key not in self.kv_mgr.connection_pool:
504
- self.bootstrap_info = self._get_bootstrap_info_from_server(
505
- self.kv_mgr.kv_args.engine_rank,
506
- self.target_dp_group,
507
- )
508
- if self.bootstrap_info is None:
598
+ bootstrap_infos = []
599
+ for target_tp_rank in self.target_tp_ranks:
600
+ bootstrap_info = self._get_bootstrap_info_from_server(
601
+ target_tp_rank,
602
+ self.target_dp_group,
603
+ )
604
+ if bootstrap_info is not None:
605
+ # NOTE: only support MLA for now: select one prefill rank as real rank
606
+ bootstrap_info["is_dummy"] = not bool(
607
+ target_tp_rank == self.target_tp_rank
608
+ or self.target_tp_rank is None
609
+ )
610
+ bootstrap_infos.append(bootstrap_info)
611
+ else:
612
+ logger.error(
613
+ f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
614
+ )
615
+ self.bootstrap_infos = bootstrap_infos
616
+
617
+ if len(self.bootstrap_infos) == 0:
509
618
  logger.error(
510
619
  f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
511
620
  )
512
621
  else:
513
- self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
622
+ self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
514
623
  # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
515
624
  self._register_kv_args()
516
625
  else:
517
- self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
626
+ self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
518
627
 
519
- assert self.bootstrap_info is not None
628
+ assert len(self.bootstrap_infos) > 0
520
629
  self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
521
630
 
522
631
  def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
@@ -543,8 +652,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
543
652
  response = requests.get(url)
544
653
  if response.status_code == 200:
545
654
  prefill_parallel_info = response.json()
546
- return int(prefill_parallel_info["prefill_dp_size"]), int(
547
- prefill_parallel_info["tp_size_per_dp_rank"]
655
+ return int(prefill_parallel_info["prefill_tp_size"]), int(
656
+ prefill_parallel_info["prefill_dp_size"]
548
657
  )
549
658
  else:
550
659
  logger.error(
@@ -556,29 +665,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
556
665
  return None
557
666
 
558
667
  def _register_kv_args(self):
559
- self.prefill_server_url = (
560
- f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
561
- )
562
-
563
- packed_kv_data_ptrs = b"".join(
564
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
565
- )
566
- packed_aux_data_ptrs = b"".join(
567
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
568
- )
569
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
570
- with lock:
571
- sock.send_multipart(
572
- [
573
- "None".encode("ascii"),
574
- get_local_ip_by_remote().encode("ascii"),
575
- str(self.kv_mgr.rank_port).encode("ascii"),
576
- self.session_id.encode("ascii"),
577
- packed_kv_data_ptrs,
578
- packed_aux_data_ptrs,
579
- ]
668
+ for bootstrap_info in self.bootstrap_infos:
669
+ self.prefill_server_url = (
670
+ f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
671
+ )
672
+ packed_kv_data_ptrs = b"".join(
673
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
674
+ )
675
+ packed_aux_data_ptrs = b"".join(
676
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
580
677
  )
581
678
 
679
+ sock, lock = self._connect("tcp://" + self.prefill_server_url)
680
+ with lock:
681
+ sock.send_multipart(
682
+ [
683
+ "None".encode("ascii"),
684
+ get_local_ip_by_remote().encode("ascii"),
685
+ str(self.kv_mgr.rank_port).encode("ascii"),
686
+ self.session_id.encode("ascii"),
687
+ packed_kv_data_ptrs,
688
+ packed_aux_data_ptrs,
689
+ ]
690
+ )
691
+
582
692
  @classmethod
583
693
  def _connect(cls, endpoint: str):
584
694
  with cls._global_lock:
@@ -590,30 +700,34 @@ class MooncakeKVReceiver(BaseKVReceiver):
590
700
  return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
591
701
 
592
702
  def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
593
- self.prefill_server_url = (
594
- f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
595
- )
596
- logger.debug(
597
- f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
598
- )
599
-
600
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
601
- with lock:
602
- sock.send_multipart(
603
- [
604
- str(self.bootstrap_room).encode("ascii"),
605
- get_local_ip_by_remote().encode("ascii"),
606
- str(self.kv_mgr.rank_port).encode("ascii"),
607
- self.session_id.encode("ascii"),
608
- kv_indices.tobytes(),
609
- str(aux_index).encode("ascii"),
610
- ]
703
+ for bootstrap_info in self.bootstrap_infos:
704
+ self.prefill_server_url = (
705
+ f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
611
706
  )
707
+ logger.debug(
708
+ f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
709
+ )
710
+ is_dummy = bootstrap_info["is_dummy"]
711
+
712
+ sock, lock = self._connect("tcp://" + self.prefill_server_url)
713
+ with lock:
714
+ sock.send_multipart(
715
+ [
716
+ str(self.bootstrap_room).encode("ascii"),
717
+ get_local_ip_by_remote().encode("ascii"),
718
+ str(self.kv_mgr.rank_port).encode("ascii"),
719
+ self.session_id.encode("ascii"),
720
+ kv_indices.tobytes() if not is_dummy else b"",
721
+ str(aux_index).encode("ascii") if not is_dummy else b"",
722
+ str(self.required_dst_info_num).encode("ascii"),
723
+ ]
724
+ )
612
725
 
613
726
  def poll(self) -> KVPoll:
614
727
  return self.kv_mgr.check_status(self.bootstrap_room)
615
728
 
616
729
  def failure_exception(self):
730
+ # TODO: raise a real exception
617
731
  raise Exception("Fake KVReceiver Exception")
618
732
 
619
733
 
@@ -624,6 +738,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
624
738
  self.store = dict()
625
739
  self.lock = asyncio.Lock()
626
740
  self._setup_routes()
741
+ self.tp_size = None
627
742
  self.dp_size = None
628
743
  self.tp_size_per_dp_rank = None
629
744
  self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
@@ -658,6 +773,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
658
773
  rank_port = int(data["rank_port"])
659
774
  engine_rank = int(data["engine_rank"])
660
775
 
776
+ if self.tp_size is None:
777
+ self.tp_size = tp_size
778
+
661
779
  if self.dp_size is None:
662
780
  self.dp_size = dp_size
663
781
 
@@ -693,17 +811,15 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
693
811
  # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
694
812
  if int(engine_rank) == -1 and int(target_dp_group) == -1:
695
813
  prefill_parallel_info = {
814
+ "prefill_tp_size": self.tp_size,
696
815
  "prefill_dp_size": self.dp_size,
697
- "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
698
816
  }
699
817
  return web.json_response(prefill_parallel_info, status=200)
700
818
 
701
819
  # Find corresponding prefill info
702
- tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
703
-
704
820
  async with self.lock:
705
821
  bootstrap_info = self.prefill_port_table[int(target_dp_group)][
706
- tp_rank_in_dp_group
822
+ int(engine_rank)
707
823
  ]
708
824
 
709
825
  if bootstrap_info is not None:
@@ -61,7 +61,8 @@ class MooncakeTransferEngine:
61
61
  self, session_id: str, buffer: int, peer_buffer_address: int, length: int
62
62
  ) -> int:
63
63
  """Synchronously transfer data to the specified address."""
64
-
64
+ # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
65
+ # later: based on the cached queue pair to send data
65
66
  ret = self.engine.transfer_sync_write(
66
67
  session_id, buffer, peer_buffer_address, length
67
68
  )
@@ -35,29 +35,19 @@ logger = logging.getLogger(__name__)
35
35
  NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
36
36
 
37
37
 
38
- # From Mooncake backend.
39
38
  def group_concurrent_contiguous(
40
39
  src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
41
40
  ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
42
- src_groups = []
43
- dst_groups = []
44
- current_src = [src_indices[0]]
45
- current_dst = [dst_indices[0]]
46
-
47
- for i in range(1, len(src_indices)):
48
- src_contiguous = src_indices[i] == src_indices[i - 1] + 1
49
- dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
50
- if src_contiguous and dst_contiguous:
51
- current_src.append(src_indices[i])
52
- current_dst.append(dst_indices[i])
53
- else:
54
- src_groups.append(current_src)
55
- dst_groups.append(current_dst)
56
- current_src = [src_indices[i]]
57
- current_dst = [dst_indices[i]]
41
+ """Vectorised NumPy implementation."""
42
+ if src_indices.size == 0:
43
+ return [], []
44
+
45
+ brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
46
+ src_groups = np.split(src_indices, brk)
47
+ dst_groups = np.split(dst_indices, brk)
58
48
 
59
- src_groups.append(current_src)
60
- dst_groups.append(current_dst)
49
+ src_groups = [g.tolist() for g in src_groups]
50
+ dst_groups = [g.tolist() for g in dst_groups]
61
51
 
62
52
  return src_groups, dst_groups
63
53
 
@@ -132,6 +122,7 @@ class NixlKVManager(BaseKVManager):
132
122
  args: KVArgs,
133
123
  disaggregation_mode: DisaggregationMode,
134
124
  server_args: ServerArgs,
125
+ is_mla_backend: Optional[bool] = False,
135
126
  ):
136
127
  try:
137
128
  from nixl._api import nixl_agent