sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +59 -11
  10. sglang/srt/disaggregation/mini_lb.py +45 -8
  11. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  12. sglang/srt/disaggregation/prefill.py +24 -9
  13. sglang/srt/entrypoints/http_server.py +8 -2
  14. sglang/srt/function_call_parser.py +77 -5
  15. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  16. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  17. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  18. sglang/srt/layers/attention/vision.py +2 -0
  19. sglang/srt/layers/layernorm.py +38 -16
  20. sglang/srt/layers/logits_processor.py +2 -2
  21. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  25. sglang/srt/layers/pooler.py +6 -0
  26. sglang/srt/layers/quantization/awq.py +5 -1
  27. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  28. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  29. sglang/srt/layers/radix_attention.py +13 -3
  30. sglang/srt/layers/rotary_embedding.py +170 -126
  31. sglang/srt/managers/data_parallel_controller.py +10 -3
  32. sglang/srt/managers/io_struct.py +7 -0
  33. sglang/srt/managers/mm_utils.py +85 -28
  34. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  35. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  36. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  37. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  38. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  40. sglang/srt/managers/schedule_batch.py +29 -12
  41. sglang/srt/managers/scheduler.py +31 -20
  42. sglang/srt/managers/tokenizer_manager.py +5 -1
  43. sglang/srt/mem_cache/memory_pool.py +87 -0
  44. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  45. sglang/srt/model_executor/forward_batch_info.py +51 -95
  46. sglang/srt/model_executor/model_runner.py +11 -24
  47. sglang/srt/models/deepseek.py +12 -2
  48. sglang/srt/models/deepseek_nextn.py +101 -6
  49. sglang/srt/models/deepseek_v2.py +144 -70
  50. sglang/srt/models/deepseek_vl2.py +9 -4
  51. sglang/srt/models/gemma3_causal.py +1 -1
  52. sglang/srt/models/llama4.py +0 -1
  53. sglang/srt/models/minicpmo.py +5 -1
  54. sglang/srt/models/mllama4.py +2 -2
  55. sglang/srt/models/qwen2_5_vl.py +3 -6
  56. sglang/srt/models/qwen2_vl.py +3 -7
  57. sglang/srt/models/roberta.py +178 -0
  58. sglang/srt/openai_api/adapter.py +18 -8
  59. sglang/srt/server_args.py +15 -22
  60. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  61. sglang/srt/torch_memory_saver_adapter.py +10 -1
  62. sglang/srt/utils.py +2 -1
  63. sglang/test/runners.py +6 -13
  64. sglang/test/test_utils.py +36 -18
  65. sglang/version.py +1 -1
  66. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
  67. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
  68. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  69. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  70. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
@@ -137,7 +137,7 @@ class DecodePreallocQueue:
137
137
  kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
138
138
  kv_receiver = kv_receiver_class(
139
139
  mgr=self.kv_manager,
140
- bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
140
+ bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
141
141
  bootstrap_room=req.bootstrap_room,
142
142
  )
143
143
  self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
@@ -307,7 +307,7 @@ class DecodeTransferQueue:
307
307
  def extend(self, req_conns) -> None:
308
308
  self.queue.extend(req_conns)
309
309
 
310
- def pop_transferred(self) -> List[Req]:
310
+ def pop_transferred(self) -> List[DecodeRequest]:
311
311
  if not self.queue:
312
312
  return []
313
313
 
@@ -330,7 +330,7 @@ class DecodeTransferQueue:
330
330
  assert len(decode_req.req.output_ids) == 0
331
331
  assert decode_req.req.transferred_output_id is None
332
332
  decode_req.req.transferred_output_id = output_id
333
- transferred_reqs.append(decode_req.req)
333
+ transferred_reqs.append(decode_req)
334
334
  indices_to_remove.add(i)
335
335
  elif poll in [
336
336
  KVPoll.Bootstrapping,
@@ -444,8 +444,17 @@ class ScheduleBatchDisaggregationDecodeMixin:
444
444
 
445
445
  class SchedulerDisaggregationDecodeMixin:
446
446
 
447
+ def _prepare_idle_batch_and_run(self, batch, delay_process=False):
448
+ batch, _ = self.prepare_dp_attn_batch(batch)
449
+ result = None
450
+ if batch:
451
+ result = self.run_batch(batch)
452
+ if not delay_process:
453
+ self.process_batch_result(batch, result)
454
+ return batch, result
455
+
447
456
  @torch.no_grad()
448
- def event_loop_normal_disagg_decode(self):
457
+ def event_loop_normal_disagg_decode(self: Scheduler):
449
458
  """A normal scheduler loop for decode worker in disaggregation mode."""
450
459
 
451
460
  while True:
@@ -456,14 +465,25 @@ class SchedulerDisaggregationDecodeMixin:
456
465
  batch = self.get_next_disagg_decode_batch_to_run()
457
466
  self.cur_batch = batch
458
467
 
468
+ prepare_dp_attn_flag = (
469
+ self.server_args.enable_dp_attention
470
+ or self.server_args.enable_sp_layernorm
471
+ )
472
+
459
473
  if batch:
460
474
  # Generate fake extend output.
461
475
  if batch.forward_mode.is_extend():
462
476
  # Note: Logprobs should be handled on the prefill engine.
463
477
  self.stream_output(batch.reqs, False)
478
+ if prepare_dp_attn_flag:
479
+ self._prepare_idle_batch_and_run(None)
464
480
  else:
481
+ if prepare_dp_attn_flag:
482
+ self.prepare_dp_attn_batch(batch)
465
483
  result = self.run_batch(batch)
466
484
  self.process_batch_result(batch, result)
485
+ elif prepare_dp_attn_flag:
486
+ batch, _ = self._prepare_idle_batch_and_run(None)
467
487
 
468
488
  if batch is None and (
469
489
  len(self.disagg_decode_transfer_queue.queue)
@@ -477,10 +497,10 @@ class SchedulerDisaggregationDecodeMixin:
477
497
  self.last_batch = batch
478
498
 
479
499
  @torch.no_grad()
480
- def event_loop_overlap_disagg_decode(self):
500
+ def event_loop_overlap_disagg_decode(self: Scheduler):
481
501
  result_queue = deque()
482
502
  self.last_batch: Optional[ScheduleBatch] = None
483
- self.last_batch_is_extend = False # last batch is modifed in-place, so we need another variable to track if it's extend
503
+ self.last_batch_in_queue = False # last batch is modifed in-place, so we need another variable to track if it's extend
484
504
 
485
505
  while True:
486
506
  recv_reqs = self.recv_requests()
@@ -489,20 +509,41 @@ class SchedulerDisaggregationDecodeMixin:
489
509
  self.process_decode_queue()
490
510
  batch = self.get_next_disagg_decode_batch_to_run()
491
511
  self.cur_batch = batch
492
- last_batch_is_extend = False
512
+ last_batch_in_queue = False
513
+
514
+ prepare_dp_attn_flag = (
515
+ self.server_args.enable_dp_attention
516
+ or self.server_args.enable_sp_layernorm
517
+ )
493
518
 
494
519
  if batch:
495
520
  # Generate fake extend output.
496
521
  if batch.forward_mode.is_extend():
497
522
  # Note: Logprobs should be handled on the prefill engine.
498
523
  self.stream_output(batch.reqs, False)
499
- last_batch_is_extend = True
524
+ if prepare_dp_attn_flag:
525
+ batch_, result = self._prepare_idle_batch_and_run(
526
+ None, delay_process=True
527
+ )
528
+ if batch_:
529
+ result_queue.append((batch_.copy(), result))
530
+ last_batch_in_queue = True
500
531
  else:
532
+ if prepare_dp_attn_flag:
533
+ self.prepare_dp_attn_batch(batch)
501
534
  result = self.run_batch(batch)
502
535
  result_queue.append((batch.copy(), result))
536
+ last_batch_in_queue = True
537
+ elif prepare_dp_attn_flag:
538
+ batch, result = self._prepare_idle_batch_and_run(
539
+ None, delay_process=True
540
+ )
541
+ if batch:
542
+ result_queue.append((batch.copy(), result))
543
+ last_batch_in_queue = True
503
544
 
504
545
  # Process the results of the previous batch but skip if the last batch is extend
505
- if self.last_batch and not self.last_batch_is_extend:
546
+ if self.last_batch and self.last_batch_in_queue:
506
547
  tmp_batch, tmp_result = result_queue.popleft()
507
548
  self.process_batch_result(tmp_batch, tmp_result)
508
549
 
@@ -516,7 +557,7 @@ class SchedulerDisaggregationDecodeMixin:
516
557
  self.new_token_ratio = self.init_new_token_ratio
517
558
 
518
559
  self.last_batch = batch
519
- self.last_batch_is_extend = last_batch_is_extend
560
+ self.last_batch_in_queue = last_batch_in_queue
520
561
 
521
562
  def get_next_disagg_decode_batch_to_run(
522
563
  self: Scheduler,
@@ -600,8 +641,15 @@ class SchedulerDisaggregationDecodeMixin:
600
641
 
601
642
  def process_decode_queue(self: Scheduler):
602
643
  req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
644
+
645
+ def _num_pre_alloc(req):
646
+ return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
647
+
648
+ self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
603
649
  self.disagg_decode_transfer_queue.extend(req_conns)
604
650
  alloc_reqs = (
605
651
  self.disagg_decode_transfer_queue.pop_transferred()
606
652
  ) # the requests which kv has arrived
607
- self.waiting_queue.extend(alloc_reqs)
653
+ self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
654
+
655
+ self.waiting_queue.extend([req.req for req in alloc_reqs])
@@ -6,6 +6,7 @@ import asyncio
6
6
  import random
7
7
  import urllib
8
8
  from itertools import chain
9
+ from typing import List
9
10
 
10
11
  import aiohttp
11
12
  import orjson
@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
14
15
  from fastapi.responses import ORJSONResponse, Response, StreamingResponse
15
16
 
16
17
 
18
class PrefillConfig:
    """Connection details for a single prefill server.

    Attributes:
        url: HTTP base URL of the prefill server.
        bootstrap_port: KV bootstrap port paired with that server; the load
            balancer copies it into forwarded requests as ``bootstrap_port``.
    """

    def __init__(self, url: str, bootstrap_port: int):
        self.url, self.bootstrap_port = url, bootstrap_port
22
+
23
+
17
24
  class MiniLoadBalancer:
18
- def __init__(self, prefill_servers, decode_servers):
19
- self.prefill_servers = prefill_servers
25
+ def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
26
+ self.prefill_configs = prefill_configs
27
+ self.prefill_servers = [p.url for p in prefill_configs]
20
28
  self.decode_servers = decode_servers
21
29
 
22
30
  def select_pair(self):
23
- return random.choice(self.prefill_servers), random.choice(self.decode_servers)
31
+ prefill_config = random.choice(self.prefill_configs)
32
+ decode_server = random.choice(self.decode_servers)
33
+ return prefill_config.url, prefill_config.bootstrap_port, decode_server
24
34
 
25
35
  async def generate(
26
36
  self, modified_request, prefill_server, decode_server, endpoint
@@ -160,7 +170,7 @@ async def get_model_info():
160
170
 
161
171
  @app.post("/generate")
162
172
  async def handle_generate_request(request_data: dict):
163
- prefill_server, decode_server = load_balancer.select_pair()
173
+ prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
164
174
 
165
175
  # Parse and transform prefill_server for bootstrap data
166
176
  parsed_url = urllib.parse.urlparse(prefill_server)
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
172
182
  modified_request.update(
173
183
  {
174
184
  "bootstrap_host": [hostname] * batch_size,
185
+ "bootstrap_port": [bootstrap_port] * batch_size,
175
186
  "bootstrap_room": [
176
187
  _generate_bootstrap_room() for _ in range(batch_size)
177
188
  ],
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
181
192
  modified_request.update(
182
193
  {
183
194
  "bootstrap_host": hostname,
195
+ "bootstrap_port": bootstrap_port,
184
196
  "bootstrap_room": _generate_bootstrap_room(),
185
197
  }
186
198
  )
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
197
209
 
198
210
  @app.post("/v1/chat/completions")
199
211
  async def handle_completion_request(request_data: dict):
200
- prefill_server, decode_server = load_balancer.select_pair()
212
+ prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
201
213
 
202
214
  # Parse and transform prefill_server for bootstrap data
203
215
  parsed_url = urllib.parse.urlparse(prefill_server)
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
206
218
  modified_request.update(
207
219
  {
208
220
  "bootstrap_host": hostname,
221
+ "bootstrap_port": bootstrap_port,
209
222
  "bootstrap_room": random.randint(0, 2**63 - 1),
210
223
  }
211
224
  )
@@ -255,9 +268,9 @@ async def get_models():
255
268
  raise HTTPException(status_code=500, detail=str(e))
256
269
 
257
270
 
258
def run(prefill_configs, decode_addrs, host, port):
    """Install the global ``load_balancer`` and serve ``app`` with uvicorn.

    Args:
        prefill_configs: one ``PrefillConfig`` per prefill server.
        decode_addrs: decode server URLs.
        host: interface to bind.
        port: port to bind.
    """
    global load_balancer
    balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
    load_balancer = balancer
    uvicorn.run(app, host=host, port=port)
262
275
 
263
276
 
@@ -268,6 +281,11 @@ if __name__ == "__main__":
268
281
  parser.add_argument(
269
282
  "--prefill", required=True, help="Comma-separated URLs for prefill servers"
270
283
  )
284
+ parser.add_argument(
285
+ "--prefill-bootstrap-ports",
286
+ help="Comma-separated bootstrap ports for prefill servers",
287
+ default="8998",
288
+ )
271
289
  parser.add_argument(
272
290
  "--decode", required=True, help="Comma-separated URLs for decode servers"
273
291
  )
@@ -278,4 +296,23 @@ if __name__ == "__main__":
278
296
  "--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
279
297
  )
280
298
  args = parser.parse_args()
281
- run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
299
+
300
+ prefill_urls = args.prefill.split(",")
301
+ bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
302
+
303
+ if len(bootstrap_ports) == 1:
304
+ bootstrap_ports = bootstrap_ports * len(prefill_urls)
305
+ else:
306
+ if len(bootstrap_ports) != len(prefill_urls):
307
+ raise ValueError(
308
+ "Number of prefill URLs must match number of bootstrap ports"
309
+ )
310
+ exit(1)
311
+
312
+ prefill_configs = []
313
+ for url, port in zip(prefill_urls, bootstrap_ports):
314
+ prefill_configs.append(PrefillConfig(url, port))
315
+
316
+ decode_addrs = args.decode.split(",")
317
+
318
+ run(prefill_configs, decode_addrs, args.host, args.port)
@@ -1,8 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import concurrent.futures
4
5
  import dataclasses
5
6
  import logging
7
+ import os
6
8
  import queue
7
9
  import socket
8
10
  import struct
@@ -73,9 +75,7 @@ class TransferInfo:
73
75
  endpoint: str
74
76
  dst_port: int
75
77
  mooncake_session_id: str
76
- dst_kv_ptrs: list[int]
77
78
  dst_kv_indices: npt.NDArray[np.int64]
78
- dst_aux_ptrs: list[int]
79
79
  dst_aux_index: int
80
80
 
81
81
  @classmethod
@@ -85,10 +85,29 @@ class TransferInfo:
85
85
  endpoint=msg[1].decode("ascii"),
86
86
  dst_port=int(msg[2].decode("ascii")),
87
87
  mooncake_session_id=msg[3].decode("ascii"),
88
+ dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
89
+ dst_aux_index=int(msg[5].decode("ascii")),
90
+ )
91
+
92
+
93
@dataclasses.dataclass
class KVArgsRegisterInfo:
    """One-time registration of a decode rank's KV/aux buffer pointers.

    Sent over the same ZMQ channel as per-request transfer messages, but with
    the literal room ``"None"``.
    """

    room: str  # always the sentinel "None" for registration messages
    endpoint: str  # decode rank IP
    dst_port: int  # decode rank port
    mooncake_session_id: str
    dst_kv_ptrs: list[int]  # decoded from packed unsigned 64-bit ("Q") values
    dst_aux_ptrs: list[int]  # decoded from packed unsigned 64-bit ("Q") values

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Build an instance from a six-frame multipart message.

        Frames: room, endpoint, dst_port, session id, packed KV pointers,
        packed aux pointers. Text frames are ASCII.
        """

        def _text(index: int) -> str:
            return msg[index].decode("ascii")

        def _u64_list(frame: bytes) -> list[int]:
            # Each pointer occupies 8 bytes; struct's "Q" uses native byte order.
            return list(struct.unpack(f"{len(frame) // 8}Q", frame))

        return cls(
            room=_text(0),
            endpoint=_text(1),
            dst_port=int(_text(2)),
            mooncake_session_id=_text(3),
            dst_kv_ptrs=_u64_list(msg[4]),
            dst_aux_ptrs=_u64_list(msg[5]),
        )
93
112
 
94
113
 
@@ -109,6 +128,13 @@ class MooncakeKVManager(BaseKVManager):
109
128
  # for p/d multi node infer
110
129
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
111
130
  self.dist_init_addr = server_args.dist_init_addr
131
+ self.tp_size = server_args.tp_size
132
+ self.dp_size = server_args.dp_size
133
+ self.enable_dp_attention = server_args.enable_dp_attention
134
+ if not server_args.enable_dp_attention and server_args.dp_size != 1:
135
+ raise ValueError(
136
+ "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
137
+ )
112
138
  self.request_status: Dict[int, KVPoll] = {}
113
139
  self.rank_port = None
114
140
  self.server_socket = zmq.Context().socket(zmq.PULL)
@@ -116,11 +142,19 @@ class MooncakeKVManager(BaseKVManager):
116
142
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
117
143
  self.transfer_queue = queue.Queue()
118
144
  self.transfer_infos: Dict[int, TransferInfo] = {}
145
+ self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
119
146
  self.start_prefill_thread()
120
147
  self._register_to_bootstrap()
148
+
149
+ # Determine the number of threads to use for kv sender
150
+ cpu_count = os.cpu_count()
151
+ self.executor = concurrent.futures.ThreadPoolExecutor(
152
+ min(cpu_count // 4, 16)
153
+ )
121
154
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
122
155
  self.start_decode_thread()
123
156
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
157
+ self.prefill_dp_size_table: Dict[str, int] = {}
124
158
  else:
125
159
  raise ValueError(
126
160
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
@@ -150,28 +184,53 @@ class MooncakeKVManager(BaseKVManager):
150
184
  dst_kv_ptrs: list[int],
151
185
  dst_kv_indices: npt.NDArray[np.int64],
152
186
  ):
153
- # group by indices
187
+ # Group by indices
154
188
  prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
155
189
  prefill_kv_indices, dst_kv_indices
156
190
  )
157
191
 
158
192
  num_layers = len(self.kv_args.kv_data_ptrs)
159
- for layer_id in range(num_layers):
160
- src_ptr = self.kv_args.kv_data_ptrs[layer_id]
161
- dst_ptr = dst_kv_ptrs[layer_id]
162
- item_len = self.kv_args.kv_item_lens[layer_id]
193
+ layers_params = [
194
+ (
195
+ self.kv_args.kv_data_ptrs[layer_id],
196
+ dst_kv_ptrs[layer_id],
197
+ self.kv_args.kv_item_lens[layer_id],
198
+ )
199
+ for layer_id in range(num_layers)
200
+ ]
163
201
 
202
+ # Worker function for processing a single layer
203
+ def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
164
204
  for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
165
205
  src_addr = src_ptr + int(prefill_index[0]) * item_len
166
206
  dst_addr = dst_ptr + int(decode_index[0]) * item_len
167
207
  length = item_len * len(prefill_index)
168
208
 
169
- # TODO: make async later
170
209
  status = self.engine.transfer_sync(
171
210
  mooncake_session_id, src_addr, dst_addr, length
172
211
  )
173
212
  if status != 0:
174
213
  return status
214
+ return 0
215
+
216
+ futures = [
217
+ self.executor.submit(
218
+ process_layer,
219
+ src_ptr,
220
+ dst_ptr,
221
+ item_len,
222
+ )
223
+ for (src_ptr, dst_ptr, item_len) in layers_params
224
+ ]
225
+
226
+ for future in concurrent.futures.as_completed(futures):
227
+ status = future.result()
228
+ if status != 0:
229
+ # Immediate shutdown on first error (existing tasks will finish)
230
+ executor.shutdown(wait=False)
231
+ for f in futures:
232
+ f.cancel()
233
+ return status
175
234
 
176
235
  return 0
177
236
 
@@ -215,6 +274,13 @@ class MooncakeKVManager(BaseKVManager):
215
274
  waiting_req_bytes = self.server_socket.recv_multipart()
216
275
  room = waiting_req_bytes[0].decode("ascii")
217
276
  if room == "None":
277
+ mooncake_session_id = waiting_req_bytes[3].decode("ascii")
278
+ self.decode_kv_args_table[mooncake_session_id] = (
279
+ KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
280
+ )
281
+ logger.debug(
282
+ f"Register KVArgs from {mooncake_session_id} successfully"
283
+ )
218
284
  continue
219
285
  room = int(room)
220
286
  self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
@@ -236,7 +302,7 @@ class MooncakeKVManager(BaseKVManager):
236
302
  ret = self.send_kvcache(
237
303
  req.mooncake_session_id,
238
304
  kv_chunk.prefill_kv_indices,
239
- req.dst_kv_ptrs,
305
+ self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
240
306
  chunked_dst_kv_indice,
241
307
  )
242
308
  if ret != 0:
@@ -251,7 +317,9 @@ class MooncakeKVManager(BaseKVManager):
251
317
  ret = self.send_aux(
252
318
  req.mooncake_session_id,
253
319
  kv_chunk.prefill_aux_index,
254
- req.dst_aux_ptrs,
320
+ self.decode_kv_args_table[
321
+ req.mooncake_session_id
322
+ ].dst_aux_ptrs,
255
323
  req.dst_aux_index,
256
324
  )
257
325
  self.request_status[req.room] = (
@@ -331,6 +399,8 @@ class MooncakeKVManager(BaseKVManager):
331
399
  url = f"http://{bootstrap_server_url}/route"
332
400
  payload = {
333
401
  "role": "Prefill",
402
+ "tp_size": self.tp_size,
403
+ "dp_size": self.dp_size,
334
404
  "rank_ip": get_local_ip_by_remote(),
335
405
  "rank_port": self.rank_port,
336
406
  "engine_rank": self.kv_args.engine_rank,
@@ -408,12 +478,41 @@ class MooncakeKVReceiver(BaseKVReceiver):
408
478
  self.session_id = self.kv_mgr.get_session_id()
409
479
  self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
410
480
 
481
+ if not self.kv_mgr.enable_dp_attention:
482
+ # We assume dp_attention should be activated simultaneously for
483
+ # both prefill role and decode role. If the decode instance does
484
+ # not enable dp_attention, then dp_attention is not enabled on the
485
+ # prefill instance as well. Therefore, we should skip questioning
486
+ # the prefill dp size to reduce bootstrap overhead.
487
+ self.prefill_dp_size = 1
488
+ elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
489
+ self.prefill_dp_size, tp_size_per_dp_rank = (
490
+ self._get_prefill_dp_size_from_server()
491
+ )
492
+ # Currently, we don't allow prefill instance and decode instance to
493
+ # have different TP sizes per DP rank.
494
+ assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
495
+ if self.prefill_dp_size is None:
496
+ logger.error(
497
+ f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
498
+ )
499
+ else:
500
+ self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
501
+ self.prefill_dp_size
502
+ )
503
+ else:
504
+ self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
505
+ self.bootstrap_addr
506
+ ]
507
+
411
508
  # NOTE: key distinguished by bootstrap_addr and engine_rank
509
+ self.target_dp_group = bootstrap_room % self.prefill_dp_size
412
510
  bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
413
511
 
414
512
  if bootstrap_key not in self.kv_mgr.connection_pool:
415
513
  self.bootstrap_info = self._get_bootstrap_info_from_server(
416
- self.kv_mgr.kv_args.engine_rank
514
+ self.kv_mgr.kv_args.engine_rank,
515
+ self.target_dp_group,
417
516
  )
418
517
  if self.bootstrap_info is None:
419
518
  logger.error(
@@ -421,16 +520,18 @@ class MooncakeKVReceiver(BaseKVReceiver):
421
520
  )
422
521
  else:
423
522
  self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
523
+ # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
524
+ self._register_kv_args()
424
525
  else:
425
526
  self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
426
527
 
427
528
  assert self.bootstrap_info is not None
428
529
  self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
429
530
 
430
- def _get_bootstrap_info_from_server(self, engine_rank):
531
+ def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
431
532
  """Fetch the bootstrap info from the bootstrap server."""
432
533
  try:
433
- url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
534
+ url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
434
535
  response = requests.get(url)
435
536
  if response.status_code == 200:
436
537
  bootstrap_info = response.json()
@@ -444,6 +545,49 @@ class MooncakeKVReceiver(BaseKVReceiver):
444
545
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
445
546
  return None
446
547
 
548
def _get_prefill_dp_size_from_server(self) -> tuple[int | None, int | None]:
    """Fetch the prefill parallel layout from the bootstrap server.

    Sends the ``engine_rank=-1&target_dp_group=-1`` sentinel query that the
    bootstrap server's ``/route`` GET handler answers with the prefill
    parallel info.

    Returns:
        ``(prefill_dp_size, tp_size_per_dp_rank)`` on success, or
        ``(None, None)`` on a non-200 status or any error. A 2-tuple is
        returned in every case because the caller unpacks the result
        directly. A bare ``None`` would raise TypeError there.
    """
    url = f"http://{self.bootstrap_addr}/route?engine_rank=-1&target_dp_group=-1"
    try:
        response = requests.get(url)
        if response.status_code == 200:
            prefill_parallel_info = response.json()
            return int(prefill_parallel_info["prefill_dp_size"]), int(
                prefill_parallel_info["tp_size_per_dp_rank"]
            )
        logger.error(
            f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
        )
    except Exception as e:
        # Includes int(None) when the server has no prefill registration yet.
        logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
    return None, None
566
+
567
def _register_kv_args(self):
    """Send this decode rank's KV/aux buffer pointers to its prefill rank.

    The target is ``rank_ip:rank_port`` from ``self.bootstrap_info``, which is
    also stored as ``self.prefill_server_url``. The message uses room
    ``"None"``, which the prefill manager treats as a registration rather than
    a transfer request. Frames: room sentinel, local IP, local rank port,
    session id, packed KV data pointers, packed aux data pointers.
    """
    info = self.bootstrap_info
    self.prefill_server_url = f"{info['rank_ip']}:{info['rank_port']}"

    def _pack_ptrs(ptrs):
        # One native-order unsigned 64-bit value per pointer.
        return b"".join(struct.pack("Q", ptr) for ptr in ptrs)

    kv_args = self.kv_mgr.kv_args
    packed_kv = _pack_ptrs(kv_args.kv_data_ptrs)
    packed_aux = _pack_ptrs(kv_args.aux_data_ptrs)

    sock, lock = self._connect("tcp://" + self.prefill_server_url)
    with lock:
        frames = [
            "None".encode("ascii"),
            get_local_ip_by_remote().encode("ascii"),
            str(self.kv_mgr.rank_port).encode("ascii"),
            self.session_id.encode("ascii"),
            packed_kv,
            packed_aux,
        ]
        sock.send_multipart(frames)
590
+
447
591
  @classmethod
448
592
  def _connect(cls, endpoint: str):
449
593
  with cls._global_lock:
@@ -462,12 +606,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
462
606
  f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
463
607
  )
464
608
 
465
- packed_kv_data_ptrs = b"".join(
466
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
467
- )
468
- packed_aux_data_ptrs = b"".join(
469
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
470
- )
471
609
  sock, lock = self._connect("tcp://" + self.prefill_server_url)
472
610
  with lock:
473
611
  sock.send_multipart(
@@ -476,9 +614,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
476
614
  get_local_ip_by_remote().encode("ascii"),
477
615
  str(self.kv_mgr.rank_port).encode("ascii"),
478
616
  self.session_id.encode("ascii"),
479
- packed_kv_data_ptrs,
480
617
  kv_indices.tobytes(),
481
- packed_aux_data_ptrs,
482
618
  str(aux_index).encode("ascii"),
483
619
  ]
484
620
  )
@@ -497,7 +633,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
497
633
  self.store = dict()
498
634
  self.lock = asyncio.Lock()
499
635
  self._setup_routes()
500
- self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
636
+ self.dp_size = None
637
+ self.tp_size_per_dp_rank = None
638
+ self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
501
639
 
502
640
  # Start bootstrap server
503
641
  self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -523,35 +661,64 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
523
661
  async def _handle_route_put(self, request: web.Request):
524
662
  data = await request.json()
525
663
  role = data["role"]
664
+ tp_size = data["tp_size"]
665
+ dp_size = data["dp_size"]
526
666
  rank_ip = data["rank_ip"]
527
667
  rank_port = int(data["rank_port"])
528
668
  engine_rank = int(data["engine_rank"])
529
669
 
670
+ if self.dp_size is None:
671
+ self.dp_size = dp_size
672
+
673
+ tp_size_per_dp_rank = tp_size // dp_size
674
+ if self.tp_size_per_dp_rank == None:
675
+ self.tp_size_per_dp_rank = tp_size_per_dp_rank
676
+
530
677
  # Add lock to make sure thread-safe
531
678
  if role == "Prefill":
532
- self.prefill_port_table[engine_rank] = {
679
+ dp_group = engine_rank // tp_size_per_dp_rank
680
+ tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
681
+
682
+ async with self.lock:
683
+ if dp_group not in self.prefill_port_table:
684
+ self.prefill_port_table[dp_group] = {}
685
+
686
+ self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
533
687
  "rank_ip": rank_ip,
534
688
  "rank_port": rank_port,
535
689
  }
536
690
  logger.debug(
537
- f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
691
+ f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
538
692
  )
539
693
 
540
694
  return web.Response(text="OK", status=200)
541
695
 
542
696
  async def _handle_route_get(self, request: web.Request):
543
697
  engine_rank = request.query.get("engine_rank")
544
- if not engine_rank:
545
- return web.Response(text="Missing rank", status=400)
698
+ target_dp_group = request.query.get("target_dp_group")
699
+ if not engine_rank or not target_dp_group:
700
+ return web.Response(text="Missing inputs for bootstrap server.", status=400)
701
+
702
+ # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
703
+ if int(engine_rank) == -1 and int(target_dp_group) == -1:
704
+ prefill_parallel_info = {
705
+ "prefill_dp_size": self.dp_size,
706
+ "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
707
+ }
708
+ return web.json_response(prefill_parallel_info, status=200)
546
709
 
547
710
  # Find corresponding prefill info
711
+ tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
712
+
548
713
  async with self.lock:
549
- bootstrap_info = self.prefill_port_table.get(int(engine_rank))
714
+ bootstrap_info = self.prefill_port_table[int(target_dp_group)][
715
+ tp_rank_in_dp_group
716
+ ]
550
717
 
551
718
  if bootstrap_info is not None:
552
719
  return web.json_response(bootstrap_info, status=200)
553
720
  else:
554
- return web.Response(text="Not Found", status=404)
721
+ return web.Response(text="Bootstrap info not Found", status=404)
555
722
 
556
723
  def _run_server(self):
557
724
  try: