sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py

@@ -34,6 +34,12 @@ from sglang.srt.disaggregation.common.utils import (
 )
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,

@@ -113,7 +119,7 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     dst_tp_rank: int
-    dst_tp_size: int
+    dst_attn_tp_size: int
     dst_kv_item_len: int

     @classmethod

@@ -126,7 +132,7 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_tp_rank=int(msg[6].decode("ascii")),
-            dst_tp_size=int(msg[7].decode("ascii")),
+            dst_attn_tp_size=int(msg[7].decode("ascii")),
             dst_kv_item_len=int(msg[8].decode("ascii")),
         )
@@ -147,13 +153,18 @@ class MooncakeKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
-        self.
-        self.
-        self.
-
-
-
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_dp_size = get_attention_dp_size()
+        self.attn_dp_rank = get_attention_dp_rank()
+        self.system_dp_size = (
+            1 if server_args.enable_dp_attention else server_args.dp_size
+        )
+        self.system_dp_rank = (
+            self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
+        )
+        self.pp_size = server_args.pp_size
+        self.pp_rank = self.kv_args.pp_rank
         self.request_status: Dict[int, KVPoll] = {}
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
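The manager now derives its attention-level parallel coordinates from `sglang.srt.layers.dp_attention` and keeps them separate from system-level data parallelism: with `--enable-dp-attention`, DP is handled inside the attention layers, so the system-level DP size collapses to 1. A minimal standalone sketch of that selection (argument names mirror the `ServerArgs` fields; values are illustrative):

```python
# Sketch: system-level DP degenerates to a single group when DP attention
# is enabled, because data parallelism then lives inside the attention layers.
def resolve_system_dp_size(enable_dp_attention: bool, dp_size: int) -> int:
    return 1 if enable_dp_attention else dp_size

assert resolve_system_dp_size(True, 4) == 1   # DP handled by attention
assert resolve_system_dp_size(False, 4) == 4  # plain system-level DP
```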
@@ -221,8 +232,9 @@ class MooncakeKVManager(BaseKVManager):
             )
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
-            self.prefill_tp_size_table: Dict[str, int] = {}
+            self.prefill_attn_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
+            self.prefill_pp_size_table: Dict[str, int] = {}
             # If a timeout happens on the decode side, it means decode instances
             # fail to receive the KV Cache transfer done signal after bootstrapping.
             # These timeout requests should be aborted to release the tree cache.

@@ -296,15 +308,53 @@ class MooncakeKVManager(BaseKVManager):
             prefill_kv_indices, dst_kv_indices
         )

-
-
-
-
-
-
-
-
-
+        layers_params = None
+
+        # pp is not supported on the decode side yet
+        if self.is_mla_backend:
+            src_kv_ptrs = self.kv_args.kv_data_ptrs
+            layers_per_pp_stage = len(src_kv_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            kv_item_len = self.kv_args.kv_item_lens[0]
+            layers_params = [
+                (
+                    src_kv_ptrs[layer_id],
+                    dst_kv_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        else:
+            num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+            src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+            src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+            layers_per_pp_stage = len(src_k_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            dst_v_ptrs = dst_kv_ptrs[
+                num_kv_layers + start_layer : num_kv_layers + end_layer
+            ]
+            kv_item_len = self.kv_args.kv_item_lens[0]
+
+            layers_params = [
+                (
+                    src_k_ptrs[layer_id],
+                    dst_k_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ] + [
+                (
+                    src_v_ptrs[layer_id],
+                    dst_v_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        assert layers_params is not None

         # Worker function for processing a single layer
         def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
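The new `layers_params` construction above is where pipeline parallelism enters the transfer path: the decode side registers destination pointers for every layer of the model, while each prefill PP stage only owns a contiguous window of layers, and for non-MLA caches the pointer list is laid out as all K pointers followed by all V pointers. A standalone sketch of that window selection (pointer values are illustrative integers, not real device addresses):

```python
# Sketch: select the destination K/V pointers owned by one prefill pp stage.
# Non-MLA layout of dst_kv_ptrs: [K_0..K_{L-1}, V_0..V_{L-1}] for L layers.
def select_pp_window(dst_kv_ptrs: list, pp_rank: int, layers_per_pp_stage: int):
    num_kv_layers = len(dst_kv_ptrs) // 2
    start = pp_rank * layers_per_pp_stage
    end = start + layers_per_pp_stage
    dst_k_ptrs = dst_kv_ptrs[start:end]
    dst_v_ptrs = dst_kv_ptrs[num_kv_layers + start : num_kv_layers + end]
    return dst_k_ptrs, dst_v_ptrs

# 4 layers split across 2 pp stages; stage 1 owns layers 2-3.
k, v = select_pp_window(list(range(8)), pp_rank=1, layers_per_pp_stage=2)
assert (k, v) == ([2, 3], [6, 7])
```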
@@ -343,7 +393,7 @@ class MooncakeKVManager(BaseKVManager):
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
         dst_tp_rank: int,
-        dst_tp_size: int,
+        dst_attn_tp_size: int,
         dst_kv_item_len: int,
         executor: concurrent.futures.ThreadPoolExecutor,
     ):

@@ -356,23 +406,22 @@ class MooncakeKVManager(BaseKVManager):
         This may introduce performance overhead (increased TTFT) for long sequences.
         """
         # Extract configuration
-        local_tp_size = self.tp_size // self.dp_size
-        local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
+        local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size
         src_kv_item_len = self.kv_args.kv_item_lens[0]
-        dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
+        dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size
         num_kv_heads = self.kv_args.kv_head_num
         num_layers = len(self.kv_args.kv_data_ptrs)
         page_size = self.kv_args.page_size

         # Calculate head distribution
         src_heads_per_rank = num_kv_heads
-        dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
+        dst_heads_per_rank = num_kv_heads * self.attn_tp_size // dst_attn_tp_size
         bytes_per_head_slice_to_send = (
             dst_kv_item_len // page_size // dst_heads_per_rank
         )

         # Determine slicing parameters based on TP configuration
-        if local_tp_size > dst_tp_size:
+        if self.attn_tp_size > dst_attn_tp_size:
             # Send KVCache from multiple prefill instances to 1 decode instance
             src_head_start_offset = 0
             num_heads_to_send = src_heads_per_rank

@@ -383,35 +432,55 @@ class MooncakeKVManager(BaseKVManager):
             num_heads_to_send = dst_heads_per_rank
             dst_head_start_offset = 0

-
-
-
-
-
-
-
+        # pp is not supported on the decode side yet
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        layers_per_pp_stage = len(src_k_ptrs)
+        start_layer = self.pp_rank * layers_per_pp_stage
+        end_layer = start_layer + layers_per_pp_stage
+        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+        dst_v_ptrs = dst_kv_ptrs[
+            num_kv_layers + start_layer : num_kv_layers + end_layer
+        ]
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
+        # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
+        if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
+            logger.error(
+                f"[{mooncake_session_id}] slice size ({heads_bytes_per_token_to_send}) exceeds "
+                f"target token slot size ({dst_kv_item_len // page_size})"
             )
+            return -1

-
-
-
-
-
-
-
-
-
-            layers_params.append(
-                (
-                    self.kv_args.kv_data_ptrs[layer_id],
-                    dst_kv_ptrs[layer_id],
-                    src_kv_item_len,
-                    dst_kv_item_len,
-                    src_head_slice_offset,
-                    dst_head_slice_offset,
-                    heads_bytes_per_token_to_send,
-                )
+        layers_params = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
             )
+            for layer_id in range(layers_per_pp_stage)
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
+            )
+            for layer_id in range(layers_per_pp_stage)
+        ]

         def process_layer_tp_aware(layer_params):
             (
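The head-slicing arithmetic in `send_kvcache_slice` splits each destination token slot into equal per-head byte bands, so a prefill rank with a different attention-TP size than the decode side copies only its band. A worked example with illustrative sizes (not taken from any real model config):

```python
# Sketch: byte math for prefill attn_tp_size=4 sending to decode attn_tp_size=2.
num_kv_heads = 2        # KV heads held by each prefill rank
attn_tp_size = 4        # prefill attention TP size
dst_attn_tp_size = 2    # decode attention TP size
page_size = 1
dst_kv_item_len = 1024  # bytes per token slot on the decode side

# Each decode rank holds proportionally more heads than a prefill rank.
dst_heads_per_rank = num_kv_heads * attn_tp_size // dst_attn_tp_size       # 4
bytes_per_head_slice = dst_kv_item_len // page_size // dst_heads_per_rank  # 256

# A prefill rank sends all of its heads into a sub-slice of the decode slot,
# which must fit inside the destination token slot (the sanity check above).
heads_bytes_per_token = num_kv_heads * bytes_per_head_slice                # 512
assert heads_bytes_per_token <= dst_kv_item_len // page_size
```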
@@ -562,9 +631,9 @@ class MooncakeKVManager(BaseKVManager):
                 target_rank_registration_info: KVArgsRegisterInfo = (
                     self.decode_kv_args_table[req.mooncake_session_id]
                 )
-                local_tp_size = self.tp_size // self.dp_size
                 if self.is_mla_backend or (
-                    local_tp_size == target_rank_registration_info.dst_tp_size
+                    self.attn_tp_size
+                    == target_rank_registration_info.dst_attn_tp_size
                 ):
                     ret = self.send_kvcache(
                         req.mooncake_session_id,

@@ -580,7 +649,7 @@ class MooncakeKVManager(BaseKVManager):
                         target_rank_registration_info.dst_kv_ptrs,
                         chunked_dst_kv_indice,
                         target_rank_registration_info.dst_tp_rank,
-                        target_rank_registration_info.dst_tp_size,
+                        target_rank_registration_info.dst_attn_tp_size,
                         target_rank_registration_info.dst_kv_item_len,
                         executor,
                     )

@@ -863,11 +932,16 @@ class MooncakeKVManager(BaseKVManager):
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
-            "tp_size": self.tp_size,
-            "dp_size": self.dp_size,
+            "attn_tp_size": self.attn_tp_size,
+            "attn_tp_rank": self.attn_tp_rank,
+            "attn_dp_size": self.attn_dp_size,
+            "attn_dp_rank": self.attn_dp_rank,
+            "pp_size": self.pp_size,
+            "pp_rank": self.pp_rank,
+            "system_dp_size": self.system_dp_size,
+            "system_dp_rank": self.system_dp_rank,
             "rank_ip": self.local_ip,
             "rank_port": self.rank_port,
-            "engine_rank": self.kv_args.engine_rank,
         }

         try:
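With this change a prefill rank registers its full parallel coordinates (attention TP/DP, PP, system DP) instead of a bare `engine_rank`. A sketch of the registration request a rank might send, with a hypothetical bootstrap address and placeholder values:

```python
# Sketch: registering one prefill rank with the bootstrap server.
# Address, port, and ranks are placeholders, not a real deployment.
import requests

payload = {
    "role": "Prefill",
    "attn_tp_size": 4, "attn_tp_rank": 1,
    "attn_dp_size": 1, "attn_dp_rank": 0,
    "pp_size": 2, "pp_rank": 0,
    "system_dp_size": 1, "system_dp_rank": 0,
    "rank_ip": "10.0.0.5",
    "rank_port": 8998,
}
requests.put("http://bootstrap-host:8998/route", json=payload, timeout=5)
```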
@@ -890,10 +964,12 @@ class MooncakeKVManager(BaseKVManager):
             ]
             for k in keys_to_remove:
                 del self.connection_pool[k]
-            if failed_bootstrap_addr in self.prefill_tp_size_table:
-                del self.prefill_tp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
+                del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
             if failed_bootstrap_addr in self.prefill_dp_size_table:
                 del self.prefill_dp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_pp_size_table:
+                del self.prefill_pp_size_table[failed_bootstrap_addr]

             possible_affected_rooms = self.addr_to_rooms_tracker.get(
                 failed_bootstrap_addr, []

@@ -915,7 +991,7 @@ class MooncakeKVManager(BaseKVManager):
                     self.update_status(room, KVPoll.Failed)
                     affected_rooms.append(room)
             logger.error(
-                f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}),
+                f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected"
             )


@@ -1042,10 +1118,16 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.data_parallel_rank = data_parallel_rank

         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
-
-            self.
-
-
+            (
+                self.prefill_attn_tp_size,
+                self.prefill_dp_size,
+                self.prefill_pp_size,
+            ) = self._get_prefill_parallel_info_from_server()
+            if (
+                self.prefill_attn_tp_size is None
+                or self.prefill_dp_size is None
+                or self.prefill_pp_size is None
+            ):
                 self.kv_mgr.record_failure(
                     self.bootstrap_room,
                     f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",

@@ -1054,43 +1136,47 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 return
         else:
             logger.debug(
-                f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.
+                f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
             )
-            self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
-                self.prefill_tp_size
+            self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
+                self.prefill_attn_tp_size
             )
             self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                 self.prefill_dp_size
             )
+            self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
+                self.prefill_pp_size
+            )
         else:
-            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
+            self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
                 self.bootstrap_addr
             ]
             self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                 self.bootstrap_addr
             ]
+            self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
+                self.bootstrap_addr
+            ]

         # Currently, we don't allow prefill instance and decode instance to
         # have different TP sizes per DP rank, except for models using MLA.
-        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
-        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
-        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
+        if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
             )
             self.required_dst_info_num = 1
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
-        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
+        elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
             if not self.kv_mgr.is_mla_backend:
                 logger.warning_once(
                     "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
                 )
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
-            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
+            ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
             self.required_dst_info_num = (
-                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
+                self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
             )
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]

@@ -1103,10 +1189,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_ranks = [
                 rank
                 for rank in range(
-                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
-                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
-                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
-                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
                 )
             ]

@@ -1116,7 +1202,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
             self.required_prefill_response_num = (
-                prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
+                self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
             )

         if self.data_parallel_rank is not None:
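The receiver's rank selection above reduces to integer-ratio arithmetic between the two attention-TP sizes: when decode TP is the larger side, several decode ranks collapse onto one prefill rank; when prefill TP is larger, one decode rank fans in over a contiguous band of prefill ranks and waits for that many responses. A standalone sketch of both directions (sizes are illustrative):

```python
# Sketch: which prefill attention-TP ranks a decode rank targets.
def target_prefill_ranks(decode_rank: int, decode_tp: int, prefill_tp: int):
    if decode_tp >= prefill_tp:
        # Several decode ranks share one prefill rank.
        return [decode_rank // (decode_tp // prefill_tp)]
    # One decode rank fans in a contiguous band of prefill ranks.
    ratio = prefill_tp // decode_tp
    return list(range(decode_rank * ratio, (decode_rank + 1) * ratio))

assert target_prefill_ranks(3, decode_tp=4, prefill_tp=2) == [1]
assert target_prefill_ranks(1, decode_tp=2, prefill_tp=4) == [2, 3]
```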
@@ -1136,31 +1222,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
         if bootstrap_key not in self.kv_mgr.connection_pool:
             bootstrap_infos = []
             for target_tp_rank in self.target_tp_ranks:
-
-
-
-
-
-
-
-
-
-
+                for target_pp_rank in range(self.prefill_pp_size):
+                    bootstrap_info = self._get_bootstrap_info_from_server(
+                        target_tp_rank, self.target_dp_group, target_pp_rank
+                    )
+                    if bootstrap_info is not None:
+                        if self.kv_mgr.is_mla_backend:
+                            # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
+                            bootstrap_info["is_dummy"] = not bool(
+                                target_tp_rank == self.target_tp_rank
+                                or self.target_tp_rank is None
+                            )
+                        else:
+                            # For non-MLA: all target_tp_ranks are selected real ranks
+                            bootstrap_info["is_dummy"] = False
+                        logger.debug(
+                            f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
                         )
+                        bootstrap_infos.append(bootstrap_info)
                     else:
-
-
-
-
-
-
-                else:
-                    self.kv_mgr.record_failure(
-                        self.bootstrap_room,
-                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}",
-                    )
-                    self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
-                    return
+                        self.kv_mgr.record_failure(
+                            self.bootstrap_room,
+                            f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
+                        )
+                        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+                        return

         self.bootstrap_infos = bootstrap_infos
         self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos

@@ -1174,10 +1260,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)

-    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
+    def _get_bootstrap_info_from_server(
+        self, engine_rank, target_dp_group, target_pp_rank
+    ):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
             response = requests.get(url, timeout=5)
             if response.status_code == 200:
                 bootstrap_info = response.json()

@@ -1191,24 +1279,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

-    def _get_prefill_parallel_info_from_server(self):
+    def _get_prefill_parallel_info_from_server(
+        self,
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
         """Fetch the prefill parallel info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
             response = requests.get(url)
             if response.status_code == 200:
                 prefill_parallel_info = response.json()
-                return int(prefill_parallel_info["prefill_tp_size"]), int(
-                    prefill_parallel_info["prefill_dp_size"]
+                return (
+                    int(prefill_parallel_info["prefill_attn_tp_size"]),
+                    int(prefill_parallel_info["prefill_dp_size"]),
+                    int(prefill_parallel_info["prefill_pp_size"]),
                 )
             else:
                 logger.error(
                     f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                 )
-                return None, None
+                return None, None, None
         except Exception as e:
             logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
-            return None, None
+            return None, None, None

     def _register_kv_args(self):
         for bootstrap_info in self.bootstrap_infos:

@@ -1218,11 +1310,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
             packed_aux_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )
+            # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
             tp_rank = self.kv_mgr.kv_args.engine_rank
-            tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
             kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
             dst_tp_rank = str(tp_rank).encode("ascii")
-            dst_tp_size = str(tp_size).encode("ascii")
+            dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii")
             dst_kv_item_len = str(kv_item_len).encode("ascii")

             sock, lock = self._connect_to_bootstrap_server(bootstrap_info)

@@ -1236,7 +1328,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
                     packed_kv_data_ptrs,
                     packed_aux_data_ptrs,
                     dst_tp_rank,
-                    dst_tp_size,
+                    dst_attn_tp_size,
                     dst_kv_item_len,
                 ]
             )

@@ -1347,10 +1439,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
-        self.tp_size = None
+        self.pp_size = None
+        self.attn_tp_size = None
         self.dp_size = None
-        self.tp_size_per_dp_rank = None
-
+        self.prefill_port_table: Dict[
+            int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
+        ] = {}

         # Start bootstrap server
         self.thread = threading.Thread(target=self._run_server, daemon=True)

@@ -1380,37 +1474,45 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_put(self, request: web.Request):
         data = await request.json()
         role = data["role"]
-        tp_size = data["tp_size"]
-        dp_size = data["dp_size"]
+        attn_tp_size = data["attn_tp_size"]
+        attn_tp_rank = data["attn_tp_rank"]
+        attn_dp_size = data["attn_dp_size"]
+        attn_dp_rank = data["attn_dp_rank"]
+        pp_size = data["pp_size"]
+        pp_rank = data["pp_rank"]
+        system_dp_size = data["system_dp_size"]
+        system_dp_rank = data["system_dp_rank"]
         rank_ip = data["rank_ip"]
         rank_port = int(data["rank_port"])
-        engine_rank = int(data["engine_rank"])

-        if self.tp_size is None:
-            self.tp_size = tp_size
+        if self.attn_tp_size is None:
+            self.attn_tp_size = attn_tp_size

         if self.dp_size is None:
-            self.dp_size = dp_size
+            self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size

-        tp_size_per_dp_rank = tp_size // dp_size
-        if self.tp_size_per_dp_rank is None:
-            self.tp_size_per_dp_rank = tp_size_per_dp_rank
+        if self.pp_size is None:
+            self.pp_size = pp_size

         if role == "Prefill":
-
-
+            if system_dp_size == 1:
+                dp_group = attn_dp_rank
+            else:
+                dp_group = system_dp_rank

             # Add lock to make sure thread-safe
             async with self.lock:
                 if dp_group not in self.prefill_port_table:
                     self.prefill_port_table[dp_group] = {}
+                if attn_tp_rank not in self.prefill_port_table[dp_group]:
+                    self.prefill_port_table[dp_group][attn_tp_rank] = {}

-                self.prefill_port_table[dp_group][
+                self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
                     "rank_ip": rank_ip,
                     "rank_port": rank_port,
                 }
                 logger.debug(
-                    f"Register prefill bootstrap: {
+                    f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
                 )

         return web.Response(text="OK", status=200)
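After this change the bootstrap server's routing table is keyed three levels deep, DP group → attention TP rank → PP rank → `{rank_ip, rank_port}`, which is what the extra `target_pp_rank` query parameter indexes into. A sketch of registration and lookup against that structure (addresses and ranks are placeholders):

```python
# Sketch: the three-level routing table kept by the bootstrap server.
# dp_group -> attn_tp_rank -> pp_rank -> {"rank_ip", "rank_port"}
from typing import Dict, Union

prefill_port_table: Dict[int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]] = {}

def register(dp_group: int, attn_tp_rank: int, pp_rank: int, ip: str, port: int):
    tp_table = prefill_port_table.setdefault(dp_group, {})
    tp_table.setdefault(attn_tp_rank, {})[pp_rank] = {
        "rank_ip": ip,
        "rank_port": port,
    }

register(0, 1, 0, "10.0.0.5", 8998)
# A decode rank resolving DP group 0, TP rank 1, PP rank 0:
assert prefill_port_table[0][1][0]["rank_port"] == 8998
```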
@@ -1418,14 +1520,20 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         target_dp_group = request.query.get("target_dp_group")
-        if not engine_rank or not target_dp_group:
+        target_pp_rank = request.query.get("target_pp_rank")
+        if not engine_rank or not target_dp_group or not target_pp_rank:
             return web.Response(text="Missing inputs for bootstrap server.", status=400)

         # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
-        if int(engine_rank) == -1 and int(target_dp_group) == -1:
+        if (
+            int(engine_rank) == -1
+            and int(target_dp_group) == -1
+            and int(target_pp_rank) == -1
+        ):
             prefill_parallel_info = {
-                "prefill_tp_size": self.tp_size,
+                "prefill_attn_tp_size": self.attn_tp_size,
                 "prefill_dp_size": self.dp_size,
+                "prefill_pp_size": self.pp_size,
             }
             return web.json_response(prefill_parallel_info, status=200)

@@ -1433,7 +1541,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         async with self.lock:
             bootstrap_info = self.prefill_port_table[int(target_dp_group)][
                 int(engine_rank)
-            ]
+            ][int(target_pp_rank)]

         if bootstrap_info is not None:
             return web.json_response(bootstrap_info, status=200)

sglang/srt/disaggregation/prefill.py

@@ -103,6 +103,8 @@ class PrefillBootstrapQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
         kv_args.engine_rank = self.tp_rank
+        kv_args.pp_rank = self.pp_rank
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
         kv_args.prefill_pp_size = self.pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
|