sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
File: sglang/srt/disaggregation/mooncake/conn.py (+301 -64)
@@ -35,12 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
35
35
  from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
36
36
  from sglang.srt.disaggregation.utils import DisaggregationMode
37
37
  from sglang.srt.server_args import ServerArgs
38
- from sglang.srt.utils import (
39
- get_free_port,
40
- get_int_env_var,
41
- get_ip,
42
- get_local_ip_by_remote,
43
- )
38
+ from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
44
39
 
45
40
  logger = logging.getLogger(__name__)
46
41
 
@@ -108,6 +103,9 @@ class KVArgsRegisterInfo:
108
103
  mooncake_session_id: str
109
104
  dst_kv_ptrs: list[int]
110
105
  dst_aux_ptrs: list[int]
106
+ dst_tp_rank: int
107
+ dst_tp_size: int
108
+ dst_kv_item_len: int
111
109
 
112
110
  @classmethod
113
111
  def from_zmq(cls, msg: List[bytes]):
@@ -118,6 +116,9 @@ class KVArgsRegisterInfo:
118
116
  mooncake_session_id=msg[3].decode("ascii"),
119
117
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
120
118
  dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
119
+ dst_tp_rank=int(msg[6].decode("ascii")),
120
+ dst_tp_size=int(msg[7].decode("ascii")),
121
+ dst_kv_item_len=int(msg[8].decode("ascii")),
121
122
  )
122
123
 
123
124
 
@@ -130,8 +131,9 @@ class MooncakeKVManager(BaseKVManager):
130
131
  is_mla_backend: Optional[bool] = False,
131
132
  ):
132
133
  self.kv_args = args
134
+ self.local_ip = get_local_ip_auto()
133
135
  self.engine = MooncakeTransferEngine(
134
- hostname=get_local_ip_by_remote(),
136
+ hostname=self.local_ip,
135
137
  gpu_id=self.kv_args.gpu_id,
136
138
  ib_device=self.kv_args.ib_device,
137
139
  )
@@ -185,7 +187,7 @@ class MooncakeKVManager(BaseKVManager):
185
187
  ).start()
186
188
 
187
189
  self.bootstrap_time_out = get_int_env_var(
188
- "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
190
+ "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
189
191
  )
190
192
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
191
193
  self.heartbeat_failures = {}
@@ -193,6 +195,8 @@ class MooncakeKVManager(BaseKVManager):
193
195
  self.session_pool_lock = threading.Lock()
194
196
  self.addr_to_rooms_tracker = defaultdict(set)
195
197
  self.connection_lock = threading.Lock()
198
+ self.required_prefill_response_num_table: Dict[int, int] = {}
199
+ self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
196
200
  # Heartbeat interval should be at least 2 seconds
197
201
  self.heartbeat_interval = max(
198
202
  float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
@@ -255,17 +259,19 @@ class MooncakeKVManager(BaseKVManager):
255
259
 
256
260
  # Worker function for processing a single layer
257
261
  def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
262
+ src_addr_list = []
263
+ dst_addr_list = []
264
+ length_list = []
258
265
  for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
259
266
  src_addr = src_ptr + int(prefill_index[0]) * item_len
260
267
  dst_addr = dst_ptr + int(decode_index[0]) * item_len
261
268
  length = item_len * len(prefill_index)
262
-
263
- status = self.engine.transfer_sync(
264
- mooncake_session_id, src_addr, dst_addr, length
265
- )
266
- if status != 0:
267
- return status
268
- return 0
269
+ src_addr_list.append(src_addr)
270
+ dst_addr_list.append(dst_addr)
271
+ length_list.append(length)
272
+ return self.engine.batch_transfer_sync(
273
+ mooncake_session_id, src_addr_list, dst_addr_list, length_list
274
+ )
269
275
 
270
276
  futures = [
271
277
  executor.submit(
@@ -286,6 +292,162 @@ class MooncakeKVManager(BaseKVManager):
286
292
 
287
293
  return 0
288
294
 
295
+ def send_kvcache_slice(
296
+ self,
297
+ mooncake_session_id: str,
298
+ prefill_kv_indices: npt.NDArray[np.int64],
299
+ dst_kv_ptrs: list[int],
300
+ dst_kv_indices: npt.NDArray[np.int64],
301
+ dst_tp_rank: int,
302
+ dst_tp_size: int,
303
+ dst_kv_item_len: int,
304
+ executor: concurrent.futures.ThreadPoolExecutor,
305
+ ):
306
+ """
307
+ Sends KV cache slices from this Prefill rank to a target Decode rank,
308
+ supporting generic M-to-N TP size configurations.
309
+
310
+ NOTE: This implementation calls the transfer engine for each token slot within
311
+ each page to ensure correctness for any page_size and head-slicing configuration.
312
+ This may introduce performance overhead (increased TTFT) for long sequences.
313
+ """
314
+ # Extract configuration
315
+ local_tp_rank = self.kv_args.engine_rank
316
+ local_tp_size = self.tp_size // self.dp_size
317
+ num_kv_heads = self.kv_args.kv_head_num
318
+ num_layers = len(self.kv_args.kv_data_ptrs)
319
+ page_size = self.kv_args.page_size
320
+
321
+ # Calculate head distribution
322
+ heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
323
+ heads_per_prefill_rank = num_kv_heads
324
+ decode_global_head_start = dst_tp_rank * heads_per_decode_rank
325
+ prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
326
+ bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
327
+
328
+ decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
329
+
330
+ # Determine slicing parameters based on TP configuration
331
+ if local_tp_size > dst_tp_size:
332
+ src_head_offset = 0
333
+ num_heads_to_send = heads_per_prefill_rank
334
+ dst_head_offset = prefill_global_head_start - decode_global_head_start
335
+ else:
336
+ src_head_offset = decode_global_head_start - prefill_global_head_start
337
+ num_heads_to_send = heads_per_decode_rank
338
+ dst_head_offset = 0
339
+
340
+ layer_transfer_params = []
341
+ for layer_id in range(num_layers):
342
+ item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
343
+
344
+ # Page stride on the target dst decode rank for its slice pages
345
+ item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
346
+
347
+ if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
348
+ logger.error(
349
+ f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
350
+ )
351
+ return -1
352
+
353
+ # Calculate precise byte offset and length for the sub-slice within the prefill page data
354
+ src_slice_offset = src_head_offset * bytes_per_head
355
+ dst_slice_offset = dst_head_offset * bytes_per_head
356
+ slice_lens_per_page = num_heads_to_send * bytes_per_head
357
+
358
+ # Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
359
+ # This means slice_lens_per_page <= item_len_of_decode_rank_page
360
+ if slice_lens_per_page > item_len_of_decode_rank_page:
361
+ logger.error(
362
+ f"[{mooncake_session_id}] Layer {layer_id}: "
363
+ f"slice size ({slice_lens_per_page}) exceeds "
364
+ f"target page size ({item_len_of_decode_rank_page})"
365
+ )
366
+ return -1
367
+ layer_transfer_params.append(
368
+ (
369
+ self.kv_args.kv_data_ptrs[layer_id],
370
+ dst_kv_ptrs[layer_id],
371
+ item_len_of_prefill_rank_page,
372
+ item_len_of_decode_rank_page,
373
+ src_slice_offset,
374
+ dst_slice_offset,
375
+ slice_lens_per_page,
376
+ )
377
+ )
378
+
379
+ def process_layer_tp_aware(layer_params):
380
+ (
381
+ src_ptr,
382
+ dst_ptr,
383
+ src_item_len,
384
+ dst_item_len,
385
+ src_offset,
386
+ dst_offset,
387
+ slice_lens_per_page,
388
+ ) = layer_params
389
+ src_addr_list = []
390
+ dst_addr_list = []
391
+ length_list = []
392
+
393
+ # Calculate strides for a single token slot
394
+ bytes_per_token_on_prefill = src_item_len // page_size
395
+ bytes_per_token_on_decode = dst_item_len // page_size
396
+
397
+ for i in range(len(prefill_kv_indices)):
398
+ prefill_page_idx = int(prefill_kv_indices[i])
399
+ decode_page_idx = int(dst_kv_indices[i])
400
+
401
+ # Get the starting addresses for the current src and dst pages
402
+ src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
403
+ dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
404
+
405
+ # Iterate through each valid token slot within the current page
406
+ for token_slot_in_page in range(page_size):
407
+ # Calculate the start address of the current token slot
408
+ src_token_slot_start_addr = (
409
+ src_page_start_addr
410
+ + token_slot_in_page * bytes_per_token_on_prefill
411
+ )
412
+ dst_token_slot_start_addr = (
413
+ dst_page_start_addr
414
+ + token_slot_in_page * bytes_per_token_on_decode
415
+ )
416
+
417
+ # Calculate final src and dst addresses by applying head-slice offsets
418
+ src_slice_addr = src_token_slot_start_addr + src_offset
419
+ dst_slice_addr = dst_token_slot_start_addr + dst_offset
420
+
421
+ src_addr_list.append(src_slice_addr)
422
+ dst_addr_list.append(dst_slice_addr)
423
+ length_list.append(slice_lens_per_page)
424
+
425
+ logger.debug(
426
+ f"SYNC: sid={mooncake_session_id}, "
427
+ f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
428
+ )
429
+
430
+ return self.engine.batch_transfer_sync(
431
+ mooncake_session_id, src_addr_list, dst_addr_list, length_list
432
+ )
433
+
434
+ futures = [
435
+ executor.submit(
436
+ process_layer_tp_aware,
437
+ layer_params,
438
+ )
439
+ for layer_params in layer_transfer_params
440
+ ]
441
+
442
+ for future in concurrent.futures.as_completed(futures):
443
+ status = future.result()
444
+ if status != 0:
445
+ for f in futures:
446
+ f.cancel()
447
+ return status
448
+
449
+ return 0
450
+
289
451
  def send_aux(
290
452
  self,
291
453
  mooncake_session_id: str,
@@ -293,18 +455,24 @@ class MooncakeKVManager(BaseKVManager):
293
455
  dst_aux_ptrs: list[int],
294
456
  dst_aux_index: int,
295
457
  ):
296
- aux_item_len = self.kv_args.aux_item_lens[0]
297
- prefill_aux_addr = (
298
- self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
299
- )
300
- decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
301
- status = self.engine.transfer_sync(
302
- mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
458
+ src_addr_list = []
459
+ dst_addr_list = []
460
+ length_list = []
461
+ prefill_aux_ptrs = self.kv_args.aux_data_ptrs
462
+ prefill_aux_item_lens = self.kv_args.aux_item_lens
463
+ for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
464
+ length = prefill_aux_item_lens[i]
465
+ src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
466
+ dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
467
+ src_addr_list.append(src_addr)
468
+ dst_addr_list.append(dst_addr)
469
+ length_list.append(length)
470
+ return self.engine.batch_transfer_sync(
471
+ mooncake_session_id, src_addr_list, dst_addr_list, length_list
303
472
  )
304
- return status
305
473
 
306
474
  def sync_status_to_decode_endpoint(
307
- self, remote: str, dst_port: int, room: int, status: int
475
+ self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
308
476
  ):
309
477
  if ":" in remote:
310
478
  remote = remote.split(":")[0]
@@ -312,6 +480,7 @@ class MooncakeKVManager(BaseKVManager):
312
480
  [
313
481
  str(room).encode("ascii"),
314
482
  str(status).encode("ascii"),
483
+ str(prefill_rank).encode("ascii"),
315
484
  ]
316
485
  )
317
486
 
@@ -328,6 +497,7 @@ class MooncakeKVManager(BaseKVManager):
328
497
  )
329
498
  polls = []
330
499
  dst_ranks_infos = []
500
+ local_rank = self.kv_args.engine_rank
331
501
  for req in reqs_to_be_processed:
332
502
  if not req.is_dummy:
333
503
  # Early exit if the request has failed
@@ -343,6 +513,7 @@ class MooncakeKVManager(BaseKVManager):
343
513
  req.dst_port,
344
514
  req.room,
345
515
  KVPoll.Failed,
516
+ local_rank,
346
517
  )
347
518
  break
348
519
 
@@ -360,15 +531,31 @@ class MooncakeKVManager(BaseKVManager):
360
531
  f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
361
532
  )
362
533
 
363
- ret = self.send_kvcache(
364
- req.mooncake_session_id,
365
- kv_chunk.prefill_kv_indices,
366
- self.decode_kv_args_table[
367
- req.mooncake_session_id
368
- ].dst_kv_ptrs,
369
- chunked_dst_kv_indice,
370
- executor,
534
+ target_rank_registration_info: KVArgsRegisterInfo = (
535
+ self.decode_kv_args_table[req.mooncake_session_id]
371
536
  )
537
+ local_tp_size = self.tp_size // self.dp_size
538
+ if self.is_mla_backend or (
539
+ local_tp_size == target_rank_registration_info.dst_tp_size
540
+ ):
541
+ ret = self.send_kvcache(
542
+ req.mooncake_session_id,
543
+ kv_chunk.prefill_kv_indices,
544
+ target_rank_registration_info.dst_kv_ptrs,
545
+ chunked_dst_kv_indice,
546
+ executor,
547
+ )
548
+ else:
549
+ ret = self.send_kvcache_slice(
550
+ req.mooncake_session_id,
551
+ kv_chunk.prefill_kv_indices,
552
+ target_rank_registration_info.dst_kv_ptrs,
553
+ chunked_dst_kv_indice,
554
+ target_rank_registration_info.dst_tp_rank,
555
+ target_rank_registration_info.dst_tp_size,
556
+ target_rank_registration_info.dst_kv_item_len,
557
+ executor,
558
+ )
372
559
  if ret != 0:
373
560
  with self.session_lock:
374
561
  self.session_failures[req.mooncake_session_id] += 1
@@ -384,7 +571,11 @@ class MooncakeKVManager(BaseKVManager):
384
571
  )
385
572
  self.update_status(kv_chunk.room, KVPoll.Failed)
386
573
  self.sync_status_to_decode_endpoint(
387
- req.endpoint, req.dst_port, req.room, KVPoll.Failed
574
+ req.endpoint,
575
+ req.dst_port,
576
+ req.room,
577
+ KVPoll.Failed,
578
+ local_rank,
388
579
  )
389
580
  break
390
581
 
@@ -393,9 +584,7 @@ class MooncakeKVManager(BaseKVManager):
393
584
  ret = self.send_aux(
394
585
  req.mooncake_session_id,
395
586
  kv_chunk.prefill_aux_index,
396
- self.decode_kv_args_table[
397
- req.mooncake_session_id
398
- ].dst_aux_ptrs,
587
+ target_rank_registration_info.dst_aux_ptrs,
399
588
  req.dst_aux_index,
400
589
  )
401
590
  polls.append(True if ret == 0 else False)
@@ -409,7 +598,7 @@ class MooncakeKVManager(BaseKVManager):
409
598
  self.update_status(req.room, status)
410
599
  for endpoint, dst_port, room in dst_ranks_infos:
411
600
  self.sync_status_to_decode_endpoint(
412
- endpoint, dst_port, room, status
601
+ endpoint, dst_port, room, status, local_rank
413
602
  )
414
603
  else:
415
604
  # Dummy request means the decode instance is not used, so its status can be marked as success directly
@@ -432,7 +621,7 @@ class MooncakeKVManager(BaseKVManager):
432
621
 
433
622
  def start_prefill_thread(self):
434
623
  self.rank_port = get_free_port()
435
- self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
624
+ self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
436
625
 
437
626
  def bootstrap_thread():
438
627
  """This thread recvs pre-alloc notification from the decode engine"""
@@ -471,19 +660,37 @@ class MooncakeKVManager(BaseKVManager):
471
660
 
472
661
  def start_decode_thread(self):
473
662
  self.rank_port = get_free_port()
474
- self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
663
+ self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
475
664
 
476
665
  def decode_thread():
477
666
  while True:
478
- (bootstrap_room, status) = self.server_socket.recv_multipart()
667
+ (bootstrap_room, status, prefill_rank) = (
668
+ self.server_socket.recv_multipart()
669
+ )
479
670
  status = int(status.decode("ascii"))
480
671
  bootstrap_room = int(bootstrap_room.decode("ascii"))
481
- if status == KVPoll.Failed:
672
+ prefill_rank = int(prefill_rank.decode("ascii"))
673
+
674
+ if status == KVPoll.Success:
675
+ if bootstrap_room in self.request_status:
676
+ self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
677
+ expected_response_num = (
678
+ self.required_prefill_response_num_table[bootstrap_room]
679
+ )
680
+ arrived_response_num = len(
681
+ self.prefill_response_tracker[bootstrap_room]
682
+ )
683
+ if (
684
+ self.is_mla_backend
685
+ or arrived_response_num == expected_response_num
686
+ ):
687
+ self.update_status(bootstrap_room, KVPoll.Success)
688
+ elif status == KVPoll.Failed:
482
689
  self.record_failure(
483
690
  bootstrap_room,
484
691
  f"Failed to get kvcache from prefill instance, it might be dead",
485
692
  )
486
- self.update_status(bootstrap_room, status)
693
+ self.update_status(bootstrap_room, status)
487
694
 
488
695
  def heartbeat_checker():
489
696
  while True:
@@ -620,7 +827,7 @@ class MooncakeKVManager(BaseKVManager):
620
827
  "role": "Prefill",
621
828
  "tp_size": self.tp_size,
622
829
  "dp_size": self.dp_size,
623
- "rank_ip": get_local_ip_by_remote(),
830
+ "rank_ip": self.local_ip,
624
831
  "rank_port": self.rank_port,
625
832
  "engine_rank": self.kv_args.engine_rank,
626
833
  }
@@ -690,14 +897,13 @@ class MooncakeKVSender(BaseKVSender):
690
897
  self.aux_index = None
691
898
  self.bootstrap_server_url = bootstrap_addr
692
899
  self.conclude_state = None
693
- self.init_time = None
900
+ self.init_time = time.time()
694
901
  # inner state
695
902
  self.curr_idx = 0
696
903
 
697
904
  def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
698
905
  self.num_kv_indices = num_kv_indices
699
906
  self.aux_index = aux_index
700
- self.init_time = time.time()
701
907
 
702
908
  def send(
703
909
  self,
@@ -709,7 +915,10 @@ class MooncakeKVSender(BaseKVSender):
709
915
 
710
916
  if not is_last:
711
917
  self.kv_mgr.add_transfer_request(
712
- self.bootstrap_room, kv_indices, index_slice, False
918
+ self.bootstrap_room,
919
+ kv_indices,
920
+ index_slice,
921
+ False,
713
922
  )
714
923
  else:
715
924
  self.kv_mgr.add_transfer_request(
@@ -746,12 +955,12 @@ class MooncakeKVSender(BaseKVSender):
746
955
  self.kv_mgr.request_status.pop(self.bootstrap_room)
747
956
 
748
957
  def failure_exception(self):
749
- self.clear()
750
-
751
958
  # Explicitly set the status to failure since this request has failed in another rank
752
959
  if self.conclude_state is None:
753
960
  self.conclude_state = KVPoll.Failed
754
961
 
962
+ self.clear()
963
+
755
964
  with self.kv_mgr.failure_lock:
756
965
  failure_reason = self.kv_mgr.failure_records.pop(
757
966
  self.bootstrap_room, "Failed due to an unknown reason from another rank"
@@ -818,23 +1027,26 @@ class MooncakeKVReceiver(BaseKVReceiver):
818
1027
  self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
819
1028
  )
820
1029
  self.required_dst_info_num = 1
1030
+ self.required_prefill_response_num = 1
821
1031
  self.target_tp_ranks = [self.target_tp_rank]
822
1032
  elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
823
- assert (
824
- self.kv_mgr.is_mla_backend
825
- ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
1033
+ if not self.kv_mgr.is_mla_backend:
1034
+ logger.warning_once(
1035
+ "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
1036
+ )
826
1037
  self.target_tp_rank = (
827
1038
  self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
828
1039
  ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
829
1040
  self.required_dst_info_num = (
830
1041
  local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
831
1042
  )
1043
+ self.required_prefill_response_num = 1
832
1044
  self.target_tp_ranks = [self.target_tp_rank]
833
1045
  else:
834
- assert (
835
- self.kv_mgr.is_mla_backend
836
- ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
837
-
1046
+ if not self.kv_mgr.is_mla_backend:
1047
+ logger.warning_once(
1048
+ "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
1049
+ )
838
1050
  # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
839
1051
  self.target_tp_ranks = [
840
1052
  rank
@@ -851,6 +1063,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
851
1063
  # or the KVPoll will never be set correctly
852
1064
  self.target_tp_rank = self.target_tp_ranks[0]
853
1065
  self.required_dst_info_num = 1
1066
+ self.required_prefill_response_num = (
1067
+ prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
1068
+ )
854
1069
 
855
1070
  if self.data_parallel_rank is not None:
856
1071
  logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
@@ -858,6 +1073,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
858
1073
  else:
859
1074
  self.target_dp_group = bootstrap_room % self.prefill_dp_size
860
1075
 
1076
+ self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
1077
+ self.required_prefill_response_num
1078
+ )
861
1079
  # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
862
1080
  bootstrap_key = (
863
1081
  f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
@@ -871,11 +1089,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
871
1089
  self.target_dp_group,
872
1090
  )
873
1091
  if bootstrap_info is not None:
874
- # NOTE: only support MLA for now: select one prefill rank as real rank
875
- bootstrap_info["is_dummy"] = not bool(
876
- target_tp_rank == self.target_tp_rank
877
- or self.target_tp_rank is None
878
- )
1092
+ if self.kv_mgr.is_mla_backend:
1093
+ # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
1094
+ bootstrap_info["is_dummy"] = not bool(
1095
+ target_tp_rank == self.target_tp_rank
1096
+ or self.target_tp_rank is None
1097
+ )
1098
+ else:
1099
+ # For non-MLA: all target_tp_ranks are selected real ranks
1100
+ bootstrap_info["is_dummy"] = False
879
1101
  logger.debug(
880
1102
  f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
881
1103
  )
@@ -947,17 +1169,26 @@ class MooncakeKVReceiver(BaseKVReceiver):
947
1169
  packed_aux_data_ptrs = b"".join(
948
1170
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
949
1171
  )
1172
+ tp_rank = self.kv_mgr.kv_args.engine_rank
1173
+ tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
1174
+ kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
1175
+ dst_tp_rank = str(tp_rank).encode("ascii")
1176
+ dst_tp_size = str(tp_size).encode("ascii")
1177
+ dst_kv_item_len = str(kv_item_len).encode("ascii")
950
1178
 
951
1179
  sock, lock = self._connect("tcp://" + self.prefill_server_url)
952
1180
  with lock:
953
1181
  sock.send_multipart(
954
1182
  [
955
1183
  "None".encode("ascii"),
956
- get_local_ip_by_remote().encode("ascii"),
1184
+ self.kv_mgr.local_ip.encode("ascii"),
957
1185
  str(self.kv_mgr.rank_port).encode("ascii"),
958
1186
  self.session_id.encode("ascii"),
959
1187
  packed_kv_data_ptrs,
960
1188
  packed_aux_data_ptrs,
1189
+ dst_tp_rank,
1190
+ dst_tp_size,
1191
+ dst_kv_item_len,
961
1192
  ]
962
1193
  )
963
1194
 
@@ -983,7 +1214,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
983
1214
  sock.send_multipart(
984
1215
  [
985
1216
  str(self.bootstrap_room).encode("ascii"),
986
- get_local_ip_by_remote().encode("ascii"),
1217
+ self.kv_mgr.local_ip.encode("ascii"),
987
1218
  str(self.kv_mgr.rank_port).encode("ascii"),
988
1219
  self.session_id.encode("ascii"),
989
1220
  kv_indices.tobytes() if not is_dummy else b"",
@@ -1006,13 +1237,19 @@ class MooncakeKVReceiver(BaseKVReceiver):
1006
1237
  if self.bootstrap_room in self.kv_mgr.request_status:
1007
1238
  self.kv_mgr.request_status.pop(self.bootstrap_room)
1008
1239
 
1009
- def failure_exception(self):
1010
- self.clear()
1240
+ if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:
1241
+ self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)
1242
+
1243
+ if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:
1244
+ self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)
1011
1245
 
1246
+ def failure_exception(self):
1012
1247
  # Explicitly set the status to failure since this request has failed in another rank
1013
1248
  if self.conclude_state is None:
1014
1249
  self.conclude_state = KVPoll.Failed
1015
1250
 
1251
+ self.clear()
1252
+
1016
1253
  with self.kv_mgr.failure_lock:
1017
1254
  failure_reason = self.kv_mgr.failure_records.pop(
1018
1255
  self.bootstrap_room, "Failed due to an unknown reason from another rank"
File: sglang/srt/disaggregation/mooncake/transfer_engine.py (+31 -1)
@@ -1,7 +1,7 @@
1
1
  import json
2
2
  import logging
3
3
  from dataclasses import dataclass
4
- from typing import Optional
4
+ from typing import List, Optional
5
5
 
6
6
  logger = logging.getLogger(__name__)
7
7
 
@@ -90,5 +90,35 @@ class MooncakeTransferEngine:
90
90
 
91
91
  return ret
92
92
 
93
+ def batch_transfer_sync(
94
+ self,
95
+ session_id: str,
96
+ buffers: List[int],
97
+ peer_buffer_addresses: List[int],
98
+ lengths: List[int],
99
+ ) -> int:
100
+ """Synchronously transfer data to the specified addresses in batches."""
101
+ try:
102
+ ret = self.engine.batch_transfer_sync_write(
103
+ session_id, buffers, peer_buffer_addresses, lengths
104
+ )
105
+ except Exception:
106
+ ret = -1
107
+ # Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
108
+ if not hasattr(self.engine, "batch_transfer_sync_write"):
109
+ raise RuntimeError(
110
+ "Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
111
+ "Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
112
+ )
113
+
114
+ if ret < 0:
115
+ logger.debug(
116
+ "Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s",
117
+ buffers,
118
+ session_id,
119
+ peer_buffer_addresses,
120
+ )
121
+ return ret
122
+
93
123
  def get_session_id(self):
94
124
  return self.session_id