sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/srt/configs/model_config.py +1 -0
  2. sglang/srt/conversation.py +1 -0
  3. sglang/srt/custom_op.py +7 -1
  4. sglang/srt/disaggregation/base/conn.py +2 -0
  5. sglang/srt/disaggregation/decode.py +1 -1
  6. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  8. sglang/srt/disaggregation/nixl/conn.py +94 -46
  9. sglang/srt/disaggregation/prefill.py +3 -2
  10. sglang/srt/disaggregation/utils.py +12 -11
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/openai/protocol.py +47 -4
  13. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  14. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  15. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  16. sglang/srt/layers/activation.py +7 -0
  17. sglang/srt/layers/attention/flashattention_backend.py +24 -14
  18. sglang/srt/layers/layernorm.py +15 -0
  19. sglang/srt/layers/linear.py +18 -1
  20. sglang/srt/layers/logits_processor.py +12 -3
  21. sglang/srt/layers/moe/ep_moe/layer.py +79 -12
  22. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  23. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
  25. sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
  26. sglang/srt/layers/moe/topk.py +26 -0
  27. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  28. sglang/srt/layers/rotary_embedding.py +103 -11
  29. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  30. sglang/srt/managers/expert_distribution.py +21 -0
  31. sglang/srt/managers/io_struct.py +10 -2
  32. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  33. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  34. sglang/srt/managers/schedule_batch.py +9 -1
  35. sglang/srt/managers/scheduler.py +42 -6
  36. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  37. sglang/srt/model_executor/model_runner.py +5 -2
  38. sglang/srt/model_loader/loader.py +45 -10
  39. sglang/srt/model_loader/weight_utils.py +89 -0
  40. sglang/srt/models/deepseek_nextn.py +7 -4
  41. sglang/srt/models/deepseek_v2.py +147 -4
  42. sglang/srt/models/gemma3n_audio.py +949 -0
  43. sglang/srt/models/gemma3n_causal.py +1009 -0
  44. sglang/srt/models/gemma3n_mm.py +511 -0
  45. sglang/srt/models/hunyuan.py +771 -0
  46. sglang/srt/server_args.py +16 -2
  47. sglang/srt/two_batch_overlap.py +4 -1
  48. sglang/srt/utils.py +71 -0
  49. sglang/version.py +1 -1
  50. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
  51. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
  52. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  54. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py CHANGED
@@ -565,6 +565,7 @@ multimodal_model_archs = [
     "CLIPModel",
     "DeepseekVL2ForCausalLM",
     "Gemma3ForConditionalGeneration",
+    "Gemma3nForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "LlavaLlamaForCausalLM",
sglang/srt/conversation.py CHANGED
@@ -823,6 +823,7 @@ register_conv_template(
         sep_style=SeparatorStyle.GEMMA3,
         stop_str=["<end_of_turn>"],
         image_token="<start_of_image>",
+        audio_token="<start_of_audio>",
     )
 )
 
sglang/srt/custom_op.py CHANGED
@@ -1,11 +1,12 @@
 from torch import nn
 
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_npu = is_npu()
 
 
 class CustomOp(nn.Module):
@@ -60,6 +61,9 @@ class CustomOp(nn.Module):
     def forward_cuda(self, *args, **kwargs):
         raise NotImplementedError
 
+    def forward_npu(self, *args, **kwargs):
+        raise NotImplementedError
+
     def forward_hip(self, *args, **kwargs):
         return self.forward_cuda(*args, **kwargs)
 
@@ -79,5 +83,7 @@ class CustomOp(nn.Module):
             return self.forward_hip
         elif _is_cpu and _is_cpu_amx_available:
             return self.forward_cpu
+        elif _is_npu:
+            return self.forward_npu
         else:
             return self.forward_native
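
The new NPU branch slots into CustomOp's existing per-platform dispatch: a subclass overrides whichever forward_* variants it supports, and the dispatcher binds one of them when the module is built. A minimal sketch of the pattern, assuming CustomOp.forward routes through the dispatch shown above (RMSNormSketch is a made-up op; both platform bodies reuse the native math so the sketch runs anywhere, where a real port would call a fused kernel, e.g. via torch_npu):

import torch
from sglang.srt.custom_op import CustomOp

class RMSNormSketch(CustomOp):
    # Hypothetical op: illustrates the dispatch pattern, not a real sglang kernel.
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Portable fallback, used when no platform-specific branch matches.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call its CUDA kernel; the sketch reuses native math.
        return self.forward_native(x)

    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
        # An Ascend build would invoke its fused kernel here instead.
        return self.forward_native(x)

y = RMSNormSketch()(torch.randn(2, 8))  # dispatched per detected platform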
sglang/srt/disaggregation/base/conn.py CHANGED
@@ -27,6 +27,8 @@ class KVArgs:
     decode_tp_size: int
     # for pp prefill
     prefill_pp_size: int
+    kv_head_num: int
+    page_size: int
 
 
 class KVPoll:
sglang/srt/disaggregation/decode.py CHANGED
@@ -579,11 +579,11 @@ class DecodeTransferQueue:
             idx = decode_req.metadata_buffer_index
             (
                 output_id,
-                output_hidden_states,
                 output_token_logprobs_val,
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
                 output_top_logprobs_idx,
+                output_hidden_states,
             ) = self.metadata_buffers.get_buf(idx)
 
             decode_req.req.output_ids.append(output_id[0].item())
sglang/srt/disaggregation/mooncake/conn.py CHANGED
@@ -103,6 +103,9 @@ class KVArgsRegisterInfo:
     mooncake_session_id: str
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
+    dst_tp_rank: int
+    dst_tp_size: int
+    dst_kv_item_len: int
 
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -113,6 +116,9 @@ class KVArgsRegisterInfo:
             mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
+            dst_tp_rank=int(msg[6].decode("ascii")),
+            dst_tp_size=int(msg[7].decode("ascii")),
+            dst_kv_item_len=int(msg[8].decode("ascii")),
         )
 
 
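For reference, the registration message that from_zmq parses is a flat list of ZMQ frames: pointer arrays are packed as 8-byte unsigned integers ("Q"), and the three new scalars travel as ASCII strings. A self-contained sketch of both directions of that framing, mirroring the msg[...] indices above (frames 0-3 and every concrete value are placeholders):

import struct

dst_kv_ptrs = [0x7F0000001000, 0x7F0000002000]  # made-up device addresses
msg = [
    b"room", b"127.0.0.1", b"17000", b"session-0",       # msg[0..3], placeholders
    b"".join(struct.pack("Q", p) for p in dst_kv_ptrs),  # msg[4]: dst_kv_ptrs
    b"".join(struct.pack("Q", p) for p in dst_kv_ptrs),  # msg[5]: dst_aux_ptrs
    str(1).encode("ascii"),                              # msg[6]: dst_tp_rank
    str(4).encode("ascii"),                              # msg[7]: dst_tp_size
    str(131072).encode("ascii"),                         # msg[8]: dst_kv_item_len
]

# The prefill side recovers the pointer list exactly as from_zmq does:
assert list(struct.unpack(f"{len(msg[4]) // 8}Q", msg[4])) == dst_kv_ptrs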
@@ -181,7 +187,7 @@ class MooncakeKVManager(BaseKVManager):
             ).start()
 
             self.bootstrap_time_out = get_int_env_var(
-                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
+                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
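
This hunk raises the default bootstrap timeout from 30 s to 120 s; deployments can still override it through the environment variable. A sketch of the equivalent lookup, using plain os.getenv in place of sglang's get_int_env_var helper:

import os

# e.g. export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=300 before launching
bootstrap_time_out = int(os.getenv("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120))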
@@ -189,6 +195,8 @@ class MooncakeKVManager(BaseKVManager):
             self.session_pool_lock = threading.Lock()
             self.addr_to_rooms_tracker = defaultdict(set)
             self.connection_lock = threading.Lock()
+            self.required_prefill_response_num_table: Dict[int, int] = {}
+            self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
             # Heartbeat interval should be at least 2 seconds
             self.heartbeat_interval = max(
                 float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
@@ -251,17 +259,19 @@ class MooncakeKVManager(BaseKVManager):
 
         # Worker function for processing a single layer
         def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            src_addr_list = []
+            dst_addr_list = []
+            length_list = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
-
-                status = self.engine.transfer_sync(
-                    mooncake_session_id, src_addr, dst_addr, length
-                )
-                if status != 0:
-                    return status
-            return 0
+                src_addr_list.append(src_addr)
+                dst_addr_list.append(dst_addr)
+                length_list.append(length)
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, src_addr_list, dst_addr_list, length_list
+            )
 
         futures = [
             executor.submit(
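
The rewritten process_layer assumes prefill_kv_blocks holds runs of contiguous page indices, so each run collapses into one (address, length) descriptor and the whole layer goes out in a single batch_transfer_sync call instead of one transfer_sync per run. A standalone illustration of that collapsing (group_contiguous is a hypothetical stand-in for whatever produces the blocks, and all values are invented):

import numpy as np

def group_contiguous(indices):
    # Split a sorted index array into runs of consecutive values.
    runs, start = [], 0
    for i in range(1, len(indices) + 1):
        if i == len(indices) or indices[i] != indices[i - 1] + 1:
            runs.append(indices[start:i])
            start = i
    return runs

item_len = 4096                      # bytes per KV page for one layer
src_ptr = 0x7F0000000000             # base address of that layer's pool
blocks = group_contiguous(np.array([0, 1, 2, 7, 8]))
# Each run becomes a single descriptor, as in process_layer above:
descs = [(src_ptr + int(b[0]) * item_len, item_len * len(b)) for b in blocks]
assert descs == [(0x7F0000000000, 12288), (0x7F0000007000, 8192)]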
@@ -282,6 +292,162 @@ class MooncakeKVManager(BaseKVManager):
 
         return 0
 
+    def send_kvcache_slice(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int64],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int64],
+        dst_tp_rank: int,
+        dst_tp_size: int,
+        dst_kv_item_len: int,
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        """
+        Sends KV cache slices from this Prefill rank to a target Decode rank,
+        supporting generic M-to-N TP size configurations.
+
+        NOTE: This implementation calls the transfer engine for each token slot within
+        each page to ensure correctness for any page_size and head-slicing configuration.
+        This may introduce performance overhead (increased TTFT) for long sequences.
+        """
+        # Extract configuration
+        local_tp_rank = self.kv_args.engine_rank
+        local_tp_size = self.tp_size // self.dp_size
+        num_kv_heads = self.kv_args.kv_head_num
+        num_layers = len(self.kv_args.kv_data_ptrs)
+        page_size = self.kv_args.page_size
+
+        # Calculate head distribution
+        heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
+        heads_per_prefill_rank = num_kv_heads
+        decode_global_head_start = dst_tp_rank * heads_per_decode_rank
+        prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
+        bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
+
+        decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
+
+        # Determine slicing parameters based on TP configuration
+        if local_tp_size > dst_tp_size:
+            src_head_offset = 0
+            num_heads_to_send = heads_per_prefill_rank
+            dst_head_offset = prefill_global_head_start - decode_global_head_start
+        else:
+            src_head_offset = decode_global_head_start - prefill_global_head_start
+            num_heads_to_send = heads_per_decode_rank
+            dst_head_offset = 0
+
+        layer_transfer_params = []
+        for layer_id in range(num_layers):
+            item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
+
+            # Page stride on the target dst decode rank for its slice pages
+            item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
+
+            if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
+                logger.error(
+                    f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
+                )
+                return -1
+
+            # Calculate precise byte offset and length for the sub-slice within the prefill page data
+            src_slice_offset = src_head_offset * bytes_per_head
+            dst_slice_offset = dst_head_offset * bytes_per_head
+            slice_lens_per_page = num_heads_to_send * bytes_per_head
+
+            # Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
+            # This means slice_lens_per_page <= item_len_of_decode_rank_page
+            if slice_lens_per_page > item_len_of_decode_rank_page:
+                logger.error(
+                    f"[{mooncake_session_id}] Layer {layer_id}: "
+                    f"slice size ({slice_lens_per_page}) exceeds "
+                    f"target page size ({item_len_of_decode_rank_page})"
+                )
+                return -1
+            layer_transfer_params.append(
+                (
+                    self.kv_args.kv_data_ptrs[layer_id],
+                    dst_kv_ptrs[layer_id],
+                    item_len_of_prefill_rank_page,
+                    item_len_of_decode_rank_page,
+                    src_slice_offset,
+                    dst_slice_offset,
+                    slice_lens_per_page,
+                )
+            )
+
+        def process_layer_tp_aware(layer_params):
+            (
+                src_ptr,
+                dst_ptr,
+                src_item_len,
+                dst_item_len,
+                src_offset,
+                dst_offset,
+                slice_lens_per_page,
+            ) = layer_params
+            src_addr_list = []
+            dst_addr_list = []
+            length_list = []
+
+            # Calculate strides for a single token slot
+            bytes_per_token_on_prefill = src_item_len // page_size
+            bytes_per_token_on_decode = dst_item_len // page_size
+
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting addresses for the current src and dst pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final src and dst addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_offset
+
+                    src_addr_list.append(src_slice_addr)
+                    dst_addr_list.append(dst_slice_addr)
+                    length_list.append(slice_lens_per_page)
+
+                    logger.debug(
+                        f"SYNC: sid={mooncake_session_id}, "
+                        f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
+                    )
+
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, src_addr_list, dst_addr_list, length_list
+            )
+
+        futures = [
+            executor.submit(
+                process_layer_tp_aware,
+                layer_params,
+            )
+            for layer_params in layer_transfer_params
+        ]
+
+        for future in concurrent.futures.as_completed(futures):
+            status = future.result()
+            if status != 0:
+                for f in futures:
+                    f.cancel()
+                return status
+
+        return 0
+
     def send_aux(
         self,
         mooncake_session_id: str,
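
The head-distribution arithmetic in send_kvcache_slice is easiest to follow with concrete numbers. A worked sketch for a hypothetical prefill TP=2 to decode TP=4 resharding (all sizes invented):

# Prefill rank 0 of 2 sending to decode rank 1 of 4.
local_tp_rank, local_tp_size = 0, 2
dst_tp_rank, dst_tp_size = 1, 4
num_kv_heads = 8           # KV heads held per prefill rank (16 model-wide)
page_size = 16
dst_kv_item_len = 65536    # bytes per page on the decode rank

heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size      # 4
decode_global_head_start = dst_tp_rank * heads_per_decode_rank           # 4
prefill_global_head_start = local_tp_rank * num_kv_heads                 # 0
bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size   # 1024

# local_tp_size < dst_tp_size, so the else-branch above applies:
src_head_offset = decode_global_head_start - prefill_global_head_start   # 4
num_heads_to_send = heads_per_decode_rank                                # 4
dst_head_offset = 0
# Prefill rank 0 thus sends its local heads 4..7 (global heads 4..7),
# 4 * 1024 = 4096 bytes per token slot, to offset 0 of decode rank 1's pages.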
@@ -289,18 +455,24 @@ class MooncakeKVManager(BaseKVManager):
         dst_aux_ptrs: list[int],
         dst_aux_index: int,
     ):
-        aux_item_len = self.kv_args.aux_item_lens[0]
-        prefill_aux_addr = (
-            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
-        )
-        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
-        status = self.engine.transfer_sync(
-            mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
+        src_addr_list = []
+        dst_addr_list = []
+        length_list = []
+        prefill_aux_ptrs = self.kv_args.aux_data_ptrs
+        prefill_aux_item_lens = self.kv_args.aux_item_lens
+        for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
+            length = prefill_aux_item_lens[i]
+            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
+            dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
+            src_addr_list.append(src_addr)
+            dst_addr_list.append(dst_addr)
+            length_list.append(length)
+        return self.engine.batch_transfer_sync(
+            mooncake_session_id, src_addr_list, dst_addr_list, length_list
         )
-        return status
 
     def sync_status_to_decode_endpoint(
-        self, remote: str, dst_port: int, room: int, status: int
+        self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
     ):
         if ":" in remote:
             remote = remote.split(":")[0]
@@ -308,6 +480,7 @@ class MooncakeKVManager(BaseKVManager):
             [
                 str(room).encode("ascii"),
                 str(status).encode("ascii"),
+                str(prefill_rank).encode("ascii"),
             ]
         )
 
@@ -324,6 +497,7 @@ class MooncakeKVManager(BaseKVManager):
         )
         polls = []
         dst_ranks_infos = []
+        local_rank = self.kv_args.engine_rank
         for req in reqs_to_be_processed:
             if not req.is_dummy:
                 # Early exit if the request has failed
@@ -339,6 +513,7 @@ class MooncakeKVManager(BaseKVManager):
                             req.dst_port,
                             req.room,
                             KVPoll.Failed,
+                            local_rank,
                         )
                         break
 
@@ -356,15 +531,31 @@ class MooncakeKVManager(BaseKVManager):
                             f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
                         )
 
-                    ret = self.send_kvcache(
-                        req.mooncake_session_id,
-                        kv_chunk.prefill_kv_indices,
-                        self.decode_kv_args_table[
-                            req.mooncake_session_id
-                        ].dst_kv_ptrs,
-                        chunked_dst_kv_indice,
-                        executor,
+                    target_rank_registration_info: KVArgsRegisterInfo = (
+                        self.decode_kv_args_table[req.mooncake_session_id]
                     )
+                    local_tp_size = self.tp_size // self.dp_size
+                    if self.is_mla_backend or (
+                        local_tp_size == target_rank_registration_info.dst_tp_size
+                    ):
+                        ret = self.send_kvcache(
+                            req.mooncake_session_id,
+                            kv_chunk.prefill_kv_indices,
+                            target_rank_registration_info.dst_kv_ptrs,
+                            chunked_dst_kv_indice,
+                            executor,
+                        )
+                    else:
+                        ret = self.send_kvcache_slice(
+                            req.mooncake_session_id,
+                            kv_chunk.prefill_kv_indices,
+                            target_rank_registration_info.dst_kv_ptrs,
+                            chunked_dst_kv_indice,
+                            target_rank_registration_info.dst_tp_rank,
+                            target_rank_registration_info.dst_tp_size,
+                            target_rank_registration_info.dst_kv_item_len,
+                            executor,
+                        )
                     if ret != 0:
                         with self.session_lock:
                             self.session_failures[req.mooncake_session_id] += 1
@@ -380,7 +571,11 @@ class MooncakeKVManager(BaseKVManager):
                         )
                         self.update_status(kv_chunk.room, KVPoll.Failed)
                         self.sync_status_to_decode_endpoint(
-                            req.endpoint, req.dst_port, req.room, KVPoll.Failed
+                            req.endpoint,
+                            req.dst_port,
+                            req.room,
+                            KVPoll.Failed,
+                            local_rank,
                         )
                         break
 
@@ -389,9 +584,7 @@ class MooncakeKVManager(BaseKVManager):
                     ret = self.send_aux(
                         req.mooncake_session_id,
                         kv_chunk.prefill_aux_index,
-                        self.decode_kv_args_table[
-                            req.mooncake_session_id
-                        ].dst_aux_ptrs,
+                        target_rank_registration_info.dst_aux_ptrs,
                         req.dst_aux_index,
                     )
                     polls.append(True if ret == 0 else False)
@@ -405,7 +598,7 @@ class MooncakeKVManager(BaseKVManager):
                     self.update_status(req.room, status)
                     for endpoint, dst_port, room in dst_ranks_infos:
                         self.sync_status_to_decode_endpoint(
-                            endpoint, dst_port, room, status
+                            endpoint, dst_port, room, status, local_rank
                         )
             else:
                 # Dummy request means the decode instance is not used, so its status can be marked as success directly
@@ -471,15 +664,33 @@ class MooncakeKVManager(BaseKVManager):
 
         def decode_thread():
             while True:
-                (bootstrap_room, status) = self.server_socket.recv_multipart()
+                (bootstrap_room, status, prefill_rank) = (
+                    self.server_socket.recv_multipart()
+                )
                 status = int(status.decode("ascii"))
                 bootstrap_room = int(bootstrap_room.decode("ascii"))
-                if status == KVPoll.Failed:
+                prefill_rank = int(prefill_rank.decode("ascii"))
+
+                if status == KVPoll.Success:
+                    if bootstrap_room in self.request_status:
+                        self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
+                        expected_response_num = (
+                            self.required_prefill_response_num_table[bootstrap_room]
+                        )
+                        arrived_response_num = len(
+                            self.prefill_response_tracker[bootstrap_room]
+                        )
+                        if (
+                            self.is_mla_backend
+                            or arrived_response_num == expected_response_num
+                        ):
+                            self.update_status(bootstrap_room, KVPoll.Success)
+                elif status == KVPoll.Failed:
                     self.record_failure(
                         bootstrap_room,
                         f"Failed to get kvcache from prefill instance, it might be dead",
                     )
-                self.update_status(bootstrap_room, status)
+                    self.update_status(bootstrap_room, status)
 
         def heartbeat_checker():
             while True:
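
The decode_thread change turns completion into a counting problem: a room is marked successful only after every contributing prefill rank has reported in (MLA models still complete on the first success, since one prefill rank holds the full KV). A standalone sketch of that bookkeeping, with illustrative room and rank numbers:

from collections import defaultdict

required_prefill_response_num_table = {1001: 2}  # room -> expected responses
prefill_response_tracker = defaultdict(set)      # room -> prefill ranks seen

def on_success(room: int, prefill_rank: int) -> str:
    # Mirrors the KVPoll.Success path above for a non-MLA model.
    prefill_response_tracker[room].add(prefill_rank)
    expected = required_prefill_response_num_table[room]
    if len(prefill_response_tracker[room]) == expected:
        return "success"
    return "waiting"

# e.g. prefill TP=4 feeding decode TP=2: each decode rank needs 2 responses.
assert on_success(1001, prefill_rank=0) == "waiting"
assert on_success(1001, prefill_rank=1) == "success"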
@@ -686,14 +897,13 @@ class MooncakeKVSender(BaseKVSender):
         self.aux_index = None
         self.bootstrap_server_url = bootstrap_addr
         self.conclude_state = None
-        self.init_time = None
+        self.init_time = time.time()
         # inner state
         self.curr_idx = 0
 
     def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
         self.num_kv_indices = num_kv_indices
         self.aux_index = aux_index
-        self.init_time = time.time()
 
     def send(
         self,
@@ -705,7 +915,10 @@ class MooncakeKVSender(BaseKVSender):
 
         if not is_last:
             self.kv_mgr.add_transfer_request(
-                self.bootstrap_room, kv_indices, index_slice, False
+                self.bootstrap_room,
+                kv_indices,
+                index_slice,
+                False,
             )
         else:
             self.kv_mgr.add_transfer_request(
@@ -814,23 +1027,26 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             )
             self.required_dst_info_num = 1
+            self.required_prefill_response_num = 1
            self.target_tp_ranks = [self.target_tp_rank]
         elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
-            assert (
-                self.kv_mgr.is_mla_backend
-            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             self.target_tp_rank = (
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
             self.required_dst_info_num = (
                 local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
             )
+            self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
         else:
-            assert (
-                self.kv_mgr.is_mla_backend
-            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
-
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
             self.target_tp_ranks = [
                 rank
@@ -847,6 +1063,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
             # or the KVPoll will never be set correctly
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
+            self.required_prefill_response_num = (
+                prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
+            )
 
         if self.data_parallel_rank is not None:
             logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
@@ -854,6 +1073,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
         else:
             self.target_dp_group = bootstrap_room % self.prefill_dp_size
 
+        self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
+            self.required_prefill_response_num
+        )
         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
             f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
@@ -867,11 +1089,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 self.target_dp_group,
             )
             if bootstrap_info is not None:
-                # NOTE: only support MLA for now: select one prefill rank as real rank
-                bootstrap_info["is_dummy"] = not bool(
-                    target_tp_rank == self.target_tp_rank
-                    or self.target_tp_rank is None
-                )
+                if self.kv_mgr.is_mla_backend:
+                    # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
+                    bootstrap_info["is_dummy"] = not bool(
+                        target_tp_rank == self.target_tp_rank
+                        or self.target_tp_rank is None
+                    )
+                else:
+                    # For non-MLA: all target_tp_ranks are selected real ranks
+                    bootstrap_info["is_dummy"] = False
                 logger.debug(
                     f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
                 )
@@ -943,6 +1169,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
             packed_aux_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )
+            tp_rank = self.kv_mgr.kv_args.engine_rank
+            tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
+            kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
+            dst_tp_rank = str(tp_rank).encode("ascii")
+            dst_tp_size = str(tp_size).encode("ascii")
+            dst_kv_item_len = str(kv_item_len).encode("ascii")
 
             sock, lock = self._connect("tcp://" + self.prefill_server_url)
             with lock:
@@ -954,6 +1186,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
                         self.session_id.encode("ascii"),
                         packed_kv_data_ptrs,
                         packed_aux_data_ptrs,
+                        dst_tp_rank,
+                        dst_tp_size,
+                        dst_kv_item_len,
                     ]
                 )
 
@@ -1002,6 +1237,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
         if self.bootstrap_room in self.kv_mgr.request_status:
             self.kv_mgr.request_status.pop(self.bootstrap_room)
 
+        if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:
+            self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)
+
+        if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:
+            self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)
+
     def failure_exception(self):
         # Explicitly set the status to failure since this request has failed in another rank
         if self.conclude_state is None:
sglang/srt/disaggregation/mooncake/transfer_engine.py CHANGED
@@ -1,7 +1,7 @@
 import json
 import logging
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional
 
 logger = logging.getLogger(__name__)
 
@@ -90,5 +90,35 @@ class MooncakeTransferEngine:
 
         return ret
 
+    def batch_transfer_sync(
+        self,
+        session_id: str,
+        buffers: List[int],
+        peer_buffer_addresses: List[int],
+        lengths: List[int],
+    ) -> int:
+        """Synchronously transfer data to the specified addresses in batches."""
+        try:
+            ret = self.engine.batch_transfer_sync_write(
+                session_id, buffers, peer_buffer_addresses, lengths
+            )
+        except Exception:
+            ret = -1
+            # Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
+            if not hasattr(self.engine, "batch_transfer_sync_write"):
+                raise RuntimeError(
+                    "Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
+                    "Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
+                )
+
+        if ret < 0:
+            logger.debug(
+                "Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s",
+                buffers,
+                session_id,
+                peer_buffer_addresses,
+            )
+        return ret
+
     def get_session_id(self):
         return self.session_id
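
A usage sketch for the new wrapper, assuming an already-constructed MooncakeTransferEngine with both sides' buffers registered (the session id and addresses below are placeholders):

# engine = MooncakeTransferEngine(...)   # initialized and registered elsewhere
src_addrs = [0x7F0000000000, 0x7F0000100000]
dst_addrs = [0x7E0000000000, 0x7E0000100000]
lengths = [4096, 4096]

# One call moves every chunk; returns 0 on success, negative on failure.
# Requires mooncake-transfer-engine >= 0.3.4.post2, since older releases
# lack batch_transfer_sync_write and the wrapper raises RuntimeError.
ret = engine.batch_transfer_sync("session-0", src_addrs, dst_addrs, lengths)
assert ret == 0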