sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py

@@ -34,6 +34,12 @@ from sglang.srt.disaggregation.common.utils import (
 )
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,
@@ -113,7 +119,7 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     dst_tp_rank: int
-
+    dst_attn_tp_size: int
     dst_kv_item_len: int

     @classmethod
@@ -126,7 +132,7 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_tp_rank=int(msg[6].decode("ascii")),
-
+            dst_attn_tp_size=int(msg[7].decode("ascii")),
             dst_kv_item_len=int(msg[8].decode("ascii")),
         )

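The register message packs each pointer list as a flat array of unsigned 64-bit integers (`"Q"`), so the receiver recovers the element count as `len(msg)//8`. A minimal round-trip sketch of that encoding, independent of the surrounding classes:

```python
import struct

# Pack a list of 64-bit buffer pointers into one bytes payload,
# mirroring how the decode side serializes kv_data_ptrs.
ptrs = [0x7F0000001000, 0x7F0000002000, 0x7F0000003000]
packed = b"".join(struct.pack("Q", p) for p in ptrs)

# Unpack on the receiving side: each "Q" occupies 8 bytes, so the
# element count is len(packed) // 8.
unpacked = list(struct.unpack(f"{len(packed)//8}Q", packed))
assert unpacked == ptrs
```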
@@ -147,13 +153,18 @@ class MooncakeKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
-        self.
-        self.
-        self.
-
-
-
-
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_dp_size = get_attention_dp_size()
+        self.attn_dp_rank = get_attention_dp_rank()
+        self.system_dp_size = (
+            1 if server_args.enable_dp_attention else server_args.dp_size
+        )
+        self.system_dp_rank = (
+            self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
+        )
+        self.pp_size = server_args.pp_size
+        self.pp_rank = self.kv_args.pp_rank
         self.request_status: Dict[int, KVPoll] = {}
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
@@ -221,8 +232,9 @@ class MooncakeKVManager(BaseKVManager):
         )
         self.start_decode_thread()
         self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
-        self.
+        self.prefill_attn_tp_size_table: Dict[str, int] = {}
         self.prefill_dp_size_table: Dict[str, int] = {}
+        self.prefill_pp_size_table: Dict[str, int] = {}
         # If a timeout happens on the decode side, it means decode instances
         # fail to receive the KV Cache transfer done signal after bootstrapping.
         # These timeout requests should be aborted to release the tree cache.
@@ -245,15 +257,17 @@ class MooncakeKVManager(BaseKVManager):
         )

     def register_buffer_to_engine(self):
-
-
-
-
+        # Batch register KV data buffers
+        if self.kv_args.kv_data_ptrs and self.kv_args.kv_data_lens:
+            self.engine.batch_register(
+                self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
+            )

-
-
-
-
+        # Batch register auxiliary data buffers
+        if self.kv_args.aux_data_ptrs and self.kv_args.aux_data_lens:
+            self.engine.batch_register(
+                self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
+            )

     @cache
     def _connect(self, endpoint: str, is_ipv6: bool = False):
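The rewritten `register_buffer_to_engine` hands whole pointer/length lists to `batch_register` in one call per buffer class, instead of registering buffers one at a time (the `batch_register` helper itself lands in the `transfer_engine.py` change listed above and is not shown here). A rough sketch of the behavior it replaces, assuming a per-buffer `register(ptr, length)` primitive (an assumption, not confirmed by this diff):

```python
# Hypothetical per-buffer fallback, for comparison with batch_register:
# one engine round-trip per (ptr, length) pair instead of one per list.
def register_one_by_one(engine, ptrs, lens):
    for ptr, length in zip(ptrs, lens):
        engine.register(ptr, length)
```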
@@ -296,43 +310,97 @@ class MooncakeKVManager(BaseKVManager):
             prefill_kv_indices, dst_kv_indices
         )

-
-
-
-
-
-
-
-
-
+        layers_params = None
+
+        # pp is not supported on the decode side yet
+        if self.is_mla_backend:
+            src_kv_ptrs = self.kv_args.kv_data_ptrs
+            layers_per_pp_stage = len(src_kv_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            kv_item_len = self.kv_args.kv_item_lens[0]
+            layers_params = [
+                (
+                    src_kv_ptrs[layer_id],
+                    dst_kv_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        else:
+            num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+            src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+            src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+            layers_per_pp_stage = len(src_k_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            dst_v_ptrs = dst_kv_ptrs[
+                num_kv_layers + start_layer : num_kv_layers + end_layer
+            ]
+            kv_item_len = self.kv_args.kv_item_lens[0]

-
-
+            layers_params = [
+                (
+                    src_k_ptrs[layer_id],
+                    dst_k_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ] + [
+                (
+                    src_v_ptrs[layer_id],
+                    dst_v_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        assert layers_params is not None
+
+        def set_transfer_blocks(
+            src_ptr: int, dst_ptr: int, item_len: int
+        ) -> List[Tuple[int, int, int]]:
             transfer_blocks = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
                 transfer_blocks.append((src_addr, dst_addr, length))
+            return transfer_blocks

+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
             return self._transfer_data(mooncake_session_id, transfer_blocks)

-
-
-
-
-                dst_ptr,
-
-            )
-            for (src_ptr, dst_ptr, item_len) in layers_params
-        ]
+        # Worker function for processing all layers in a batch
+        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+            transfer_blocks = []
+            for src_ptr, dst_ptr, item_len in layers_params:
+                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+            return self._transfer_data(mooncake_session_id, transfer_blocks)

-
-
-
-
-
-
+        if self.enable_custom_mem_pool:
+            futures = [
+                executor.submit(
+                    process_layer,
+                    src_ptr,
+                    dst_ptr,
+                    item_len,
+                )
+                for (src_ptr, dst_ptr, item_len) in layers_params
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                status = future.result()
+                if status != 0:
+                    for f in futures:
+                        f.cancel()
+                    return status
+        else:
+            # Combining all layers' params in one batch transfer is more efficient
+            # compared to using multiple threads
+            return process_layers(layers_params)

         return 0

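For non-MLA models, `kv_data_ptrs` stores every K-layer pointer followed by every V-layer pointer, and each prefill pipeline stage owns a contiguous slice of layers. A worked sketch of the slicing arithmetic above, using hypothetical sizes and a single pipeline stage:

```python
# Hypothetical layout: 8 layers, so 8 K pointers followed by 8 V pointers.
kv_data_ptrs = list(range(16))

num_kv_layers = len(kv_data_ptrs) // 2          # 8
src_k_ptrs = kv_data_ptrs[:num_kv_layers]       # K_0 .. K_7
src_v_ptrs = kv_data_ptrs[num_kv_layers:]       # V_0 .. V_7

layers_per_pp_stage = len(src_k_ptrs)           # 8 (pp_size == 1 here)
pp_rank = 0
start_layer = pp_rank * layers_per_pp_stage     # 0
end_layer = start_layer + layers_per_pp_stage   # 8

# Each layers_params tuple then pairs layer i's K (or V) pointer on the
# prefill side with the matching destination pointer on the decode side.
assert src_k_ptrs[3] == kv_data_ptrs[start_layer + 3]
assert src_v_ptrs[3] == kv_data_ptrs[num_kv_layers + start_layer + 3]
```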
@@ -343,7 +411,7 @@ class MooncakeKVManager(BaseKVManager):
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
         dst_tp_rank: int,
-
+        dst_attn_tp_size: int,
         dst_kv_item_len: int,
         executor: concurrent.futures.ThreadPoolExecutor,
     ):
@@ -356,23 +424,22 @@ class MooncakeKVManager(BaseKVManager):
         This may introduce performance overhead (increased TTFT) for long sequences.
         """
         # Extract configuration
-
-        local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
+        local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size
         src_kv_item_len = self.kv_args.kv_item_lens[0]
-        dst_tp_rank_in_group = dst_tp_rank %
+        dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size
         num_kv_heads = self.kv_args.kv_head_num
         num_layers = len(self.kv_args.kv_data_ptrs)
         page_size = self.kv_args.page_size

         # Calculate head distribution
         src_heads_per_rank = num_kv_heads
-        dst_heads_per_rank = num_kv_heads *
+        dst_heads_per_rank = num_kv_heads * self.attn_tp_size // dst_attn_tp_size
         bytes_per_head_slice_to_send = (
             dst_kv_item_len // page_size // dst_heads_per_rank
         )

         # Determine slicing parameters based on TP configuration
-        if
+        if self.attn_tp_size > dst_attn_tp_size:
             # Send KVCache from multiple prefill instances to 1 decode instance
             src_head_start_offset = 0
             num_heads_to_send = src_heads_per_rank
@@ -383,35 +450,55 @@ class MooncakeKVManager(BaseKVManager):
             num_heads_to_send = dst_heads_per_rank
             dst_head_start_offset = 0

-
-
-
-
-
-
-
+        # pp is not supported on the decode side yet
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        layers_per_pp_stage = len(src_k_ptrs)
+        start_layer = self.pp_rank * layers_per_pp_stage
+        end_layer = start_layer + layers_per_pp_stage
+        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+        dst_v_ptrs = dst_kv_ptrs[
+            num_kv_layers + start_layer : num_kv_layers + end_layer
+        ]
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
+        # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
+        if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
+            logger.error(
+                f"[{mooncake_session_id}] slice size ({heads_bytes_per_token_to_send}) exceeds "
+                f"target token slot size ({dst_kv_item_len // page_size})"
             )
+            return -1

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        layers_params = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
+            )
+            for layer_id in range(layers_per_pp_stage)
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
             )
+            for layer_id in range(layers_per_pp_stage)
+        ]

         def process_layer_tp_aware(layer_params):
             (
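The slicing math divides each destination token slot (`dst_kv_item_len // page_size` bytes) evenly across the heads the decode rank holds, then sends only the sub-slice this prefill rank owns. A worked example under hypothetical sizes, following the formulas above:

```python
# Hypothetical configuration: prefill attention TP 4, decode attention TP 2.
attn_tp_size = 4          # prefill side (self.attn_tp_size)
dst_attn_tp_size = 2      # decode side
num_kv_heads = 2          # KV heads held per prefill rank
page_size = 1
dst_kv_item_len = 1024    # bytes per token slot on the decode side

src_heads_per_rank = num_kv_heads                                     # 2
dst_heads_per_rank = num_kv_heads * attn_tp_size // dst_attn_tp_size  # 4
bytes_per_head_slice_to_send = (
    dst_kv_item_len // page_size // dst_heads_per_rank
)                                                                     # 256

# attn_tp_size > dst_attn_tp_size: several prefill ranks feed one decode
# rank, so each prefill rank sends all of its heads (2 heads = 512 bytes
# per token) into the matching offset of the decode token slot.
num_heads_to_send = src_heads_per_rank                                # 2
heads_bytes_per_token_to_send = (
    num_heads_to_send * bytes_per_head_slice_to_send
)                                                                     # 512
assert heads_bytes_per_token_to_send <= dst_kv_item_len // page_size
```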
@@ -562,9 +649,9 @@ class MooncakeKVManager(BaseKVManager):
                 target_rank_registration_info: KVArgsRegisterInfo = (
                     self.decode_kv_args_table[req.mooncake_session_id]
                 )
-                local_tp_size = self.tp_size // self.dp_size
                 if self.is_mla_backend or (
-
+                    self.attn_tp_size
+                    == target_rank_registration_info.dst_attn_tp_size
                 ):
                     ret = self.send_kvcache(
                         req.mooncake_session_id,
@@ -580,7 +667,7 @@ class MooncakeKVManager(BaseKVManager):
                         target_rank_registration_info.dst_kv_ptrs,
                         chunked_dst_kv_indice,
                         target_rank_registration_info.dst_tp_rank,
-                        target_rank_registration_info.
+                        target_rank_registration_info.dst_attn_tp_size,
                         target_rank_registration_info.dst_kv_item_len,
                         executor,
                     )
@@ -863,11 +950,16 @@ class MooncakeKVManager(BaseKVManager):
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
-            "
-            "
+            "attn_tp_size": self.attn_tp_size,
+            "attn_tp_rank": self.attn_tp_rank,
+            "attn_dp_size": self.attn_dp_size,
+            "attn_dp_rank": self.attn_dp_rank,
+            "pp_size": self.pp_size,
+            "pp_rank": self.pp_rank,
+            "system_dp_size": self.system_dp_size,
+            "system_dp_rank": self.system_dp_rank,
             "rank_ip": self.local_ip,
             "rank_port": self.rank_port,
-            "engine_rank": self.kv_args.engine_rank,
         }

         try:
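Each prefill rank now advertises its full parallel topology (attention TP/DP, PP, system DP) to the bootstrap server rather than a bare `engine_rank`. The actual request sits in the `try:` block that follows; a hedged client-side sketch of the exchange, with made-up addresses and ranks:

```python
import requests

# Hypothetical values; in the manager these come from the get_attention_*()
# helpers and server_args, as shown in the __init__ hunk above.
payload = {
    "role": "Prefill",
    "attn_tp_size": 4, "attn_tp_rank": 1,
    "attn_dp_size": 1, "attn_dp_rank": 0,
    "pp_size": 2, "pp_rank": 0,
    "system_dp_size": 1, "system_dp_rank": 0,
    "rank_ip": "10.0.0.1", "rank_port": 9000,
}
response = requests.put("http://bootstrap-host:8998/route", json=payload)
assert response.status_code == 200  # the server replies "OK"
```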
@@ -890,10 +982,12 @@ class MooncakeKVManager(BaseKVManager):
             ]
             for k in keys_to_remove:
                 del self.connection_pool[k]
-            if failed_bootstrap_addr in self.
-                del self.
+            if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
+                del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
             if failed_bootstrap_addr in self.prefill_dp_size_table:
                 del self.prefill_dp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_pp_size_table:
+                del self.prefill_pp_size_table[failed_bootstrap_addr]

             possible_affected_rooms = self.addr_to_rooms_tracker.get(
                 failed_bootstrap_addr, []
@@ -915,7 +1009,7 @@ class MooncakeKVManager(BaseKVManager):
                 self.update_status(room, KVPoll.Failed)
                 affected_rooms.append(room)
         logger.error(
-            f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}),
+            f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected"
         )


@@ -1042,10 +1136,16 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.data_parallel_rank = data_parallel_rank

         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
-
-            self.
-
-
+            (
+                self.prefill_attn_tp_size,
+                self.prefill_dp_size,
+                self.prefill_pp_size,
+            ) = self._get_prefill_parallel_info_from_server()
+            if (
+                self.prefill_attn_tp_size is None
+                or self.prefill_dp_size is None
+                or self.prefill_pp_size is None
+            ):
                 self.kv_mgr.record_failure(
                     self.bootstrap_room,
                     f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
@@ -1054,43 +1154,47 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 return
             else:
                 logger.debug(
-                    f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.
+                    f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
                 )
-                self.kv_mgr.
-                self.
+                self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
+                    self.prefill_attn_tp_size
                 )
                 self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                     self.prefill_dp_size
                 )
+                self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
+                    self.prefill_pp_size
+                )
         else:
-            self.
+            self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
                 self.bootstrap_addr
             ]
             self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                 self.bootstrap_addr
             ]
+            self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
+                self.bootstrap_addr
+            ]

         # Currently, we don't allow prefill instance and decode instance to
         # have different TP sizes per DP rank, except for models using MLA.
-
-        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
-        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
+        if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank %
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
             )
             self.required_dst_info_num = 1
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
-        elif
+        elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
             if not self.kv_mgr.is_mla_backend:
                 logger.warning_once(
                     "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
                 )
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank %
-            ) // (
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
+            ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
             self.required_dst_info_num = (
-
+                self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
             )
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
@@ -1103,10 +1207,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_ranks = [
                 rank
                 for rank in range(
-                    (self.kv_mgr.kv_args.engine_rank %
-                    * (
-                    (self.kv_mgr.kv_args.engine_rank %
-                    * (
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
                 )
             ]

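These branches map each decode attention-TP rank onto one or more prefill ranks. A worked sketch of the branch above (prefill attention TP larger than decode attention TP), with hypothetical sizes:

```python
# Hypothetical: prefill attention TP 4, decode attention TP 2, so each
# decode rank must collect KV slices from 4 // 2 = 2 prefill ranks.
prefill_attn_tp_size = 4
decode_attn_tp_size = 2
engine_rank = 1  # this decode rank

ratio = prefill_attn_tp_size // decode_attn_tp_size  # 2
target_tp_ranks = list(
    range(
        (engine_rank % decode_attn_tp_size) * ratio,
        (engine_rank % decode_attn_tp_size + 1) * ratio,
    )
)
assert target_tp_ranks == [2, 3]
required_prefill_response_num = ratio  # waits on both prefill ranks
```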
@@ -1116,7 +1220,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
             self.required_prefill_response_num = (
-
+                self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
             )

         if self.data_parallel_rank is not None:
@@ -1136,31 +1240,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
         if bootstrap_key not in self.kv_mgr.connection_pool:
             bootstrap_infos = []
             for target_tp_rank in self.target_tp_ranks:
-
-
-
-
-
-
-
-
-
-
+                for target_pp_rank in range(self.prefill_pp_size):
+                    bootstrap_info = self._get_bootstrap_info_from_server(
+                        target_tp_rank, self.target_dp_group, target_pp_rank
+                    )
+                    if bootstrap_info is not None:
+                        if self.kv_mgr.is_mla_backend:
+                            # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
+                            bootstrap_info["is_dummy"] = not bool(
+                                target_tp_rank == self.target_tp_rank
+                                or self.target_tp_rank is None
+                            )
+                        else:
+                            # For non-MLA: all target_tp_ranks are selected real ranks
+                            bootstrap_info["is_dummy"] = False
+                        logger.debug(
+                            f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
                         )
+                        bootstrap_infos.append(bootstrap_info)
                     else:
-
-
-
-
-
-
-                else:
-                    self.kv_mgr.record_failure(
-                        self.bootstrap_room,
-                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}",
-                    )
-                    self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
-                    return
+                        self.kv_mgr.record_failure(
+                            self.bootstrap_room,
+                            f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
+                        )
+                        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+                        return

         self.bootstrap_infos = bootstrap_infos
         self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
@@ -1174,10 +1278,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)

-    def _get_bootstrap_info_from_server(
+    def _get_bootstrap_info_from_server(
+        self, engine_rank, target_dp_group, target_pp_rank
+    ):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
             response = requests.get(url, timeout=5)
             if response.status_code == 200:
                 bootstrap_info = response.json()
@@ -1191,24 +1297,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

-    def _get_prefill_parallel_info_from_server(
+    def _get_prefill_parallel_info_from_server(
+        self,
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
         """Fetch the prefill parallel info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
             response = requests.get(url)
             if response.status_code == 200:
                 prefill_parallel_info = response.json()
-                return
-                prefill_parallel_info["
+                return (
+                    int(prefill_parallel_info["prefill_attn_tp_size"]),
+                    int(prefill_parallel_info["prefill_dp_size"]),
+                    int(prefill_parallel_info["prefill_pp_size"]),
                 )
             else:
                 logger.error(
                     f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                 )
-                return None, None
+                return None, None, None
         except Exception as e:
             logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
-            return None, None
+            return None, None, None

     def _register_kv_args(self):
         for bootstrap_info in self.bootstrap_infos:
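A decode rank discovers the prefill topology by sending the sentinel query `engine_rank=-1&target_dp_group=-1&target_pp_rank=-1` before requesting any concrete rank. A minimal client-side sketch of that exchange, with a hypothetical bootstrap address:

```python
import requests

bootstrap_addr = "10.0.0.1:8998"  # hypothetical address
url = (
    f"http://{bootstrap_addr}/route"
    "?engine_rank=-1&target_dp_group=-1&target_pp_rank=-1"
)
resp = requests.get(url)
if resp.status_code == 200:
    info = resp.json()
    # e.g. {"prefill_attn_tp_size": 4, "prefill_dp_size": 1, "prefill_pp_size": 2}
    attn_tp = int(info["prefill_attn_tp_size"])
```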
@@ -1218,11 +1328,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
             packed_aux_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )
+            # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
             tp_rank = self.kv_mgr.kv_args.engine_rank
-            tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
             kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
             dst_tp_rank = str(tp_rank).encode("ascii")
-
+            dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii")
             dst_kv_item_len = str(kv_item_len).encode("ascii")

             sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
@@ -1236,7 +1346,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
                     packed_kv_data_ptrs,
                     packed_aux_data_ptrs,
                     dst_tp_rank,
-
+                    dst_attn_tp_size,
                     dst_kv_item_len,
                 ]
             )
@@ -1347,10 +1457,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
-        self.
+        self.pp_size = None
+        self.attn_tp_size = None
         self.dp_size = None
-        self.
-
+        self.prefill_port_table: Dict[
+            int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
+        ] = {}

         # Start bootstrap server
         self.thread = threading.Thread(target=self._run_server, daemon=True)
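The bootstrap server's port table gains a pipeline dimension: lookups are keyed by DP group, then attention-TP rank, then PP rank. A sketch of the resulting shape with hypothetical entries:

```python
from typing import Dict, Union

# prefill_port_table[dp_group][attn_tp_rank][pp_rank] -> endpoint info
prefill_port_table: Dict[int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]] = {
    0: {  # dp_group 0
        0: {  # attn_tp_rank 0
            0: {"rank_ip": "10.0.0.1", "rank_port": 9000},  # pp_rank 0
            1: {"rank_ip": "10.0.0.2", "rank_port": 9000},  # pp_rank 1
        },
    },
}
bootstrap_info = prefill_port_table[0][0][1]  # DP 0, TP 0, PP 1
```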
@@ -1380,37 +1492,45 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_put(self, request: web.Request):
         data = await request.json()
         role = data["role"]
-
-
+        attn_tp_size = data["attn_tp_size"]
+        attn_tp_rank = data["attn_tp_rank"]
+        attn_dp_size = data["attn_dp_size"]
+        attn_dp_rank = data["attn_dp_rank"]
+        pp_size = data["pp_size"]
+        pp_rank = data["pp_rank"]
+        system_dp_size = data["system_dp_size"]
+        system_dp_rank = data["system_dp_rank"]
         rank_ip = data["rank_ip"]
         rank_port = int(data["rank_port"])
-        engine_rank = int(data["engine_rank"])

-        if self.
-            self.
+        if self.attn_tp_size is None:
+            self.attn_tp_size = attn_tp_size

         if self.dp_size is None:
-            self.dp_size =
+            self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size

-
-
-        self.tp_size_per_dp_rank = tp_size_per_dp_rank
+        if self.pp_size is None:
+            self.pp_size = pp_size

         if role == "Prefill":
-
-
+            if system_dp_size == 1:
+                dp_group = attn_dp_rank
+            else:
+                dp_group = system_dp_rank

             # Add lock to make sure thread-safe
             async with self.lock:
                 if dp_group not in self.prefill_port_table:
                     self.prefill_port_table[dp_group] = {}
+                if attn_tp_rank not in self.prefill_port_table[dp_group]:
+                    self.prefill_port_table[dp_group][attn_tp_rank] = {}

-                self.prefill_port_table[dp_group][
+                self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
                     "rank_ip": rank_ip,
                     "rank_port": rank_port,
                 }
                 logger.debug(
-                    f"Register prefill bootstrap: {
+                    f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
                 )

         return web.Response(text="OK", status=200)
@@ -1418,14 +1538,20 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         target_dp_group = request.query.get("target_dp_group")
-
+        target_pp_rank = request.query.get("target_pp_rank")
+        if not engine_rank or not target_dp_group or not target_pp_rank:
             return web.Response(text="Missing inputs for bootstrap server.", status=400)

         # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
-        if
+        if (
+            int(engine_rank) == -1
+            and int(target_dp_group) == -1
+            and int(target_pp_rank) == -1
+        ):
             prefill_parallel_info = {
-                "
+                "prefill_attn_tp_size": self.attn_tp_size,
                 "prefill_dp_size": self.dp_size,
+                "prefill_pp_size": self.pp_size,
             }
             return web.json_response(prefill_parallel_info, status=200)

@@ -1433,7 +1559,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         async with self.lock:
             bootstrap_info = self.prefill_port_table[int(target_dp_group)][
                 int(engine_rank)
-            ]
+            ][int(target_pp_rank)]

         if bootstrap_info is not None:
             return web.json_response(bootstrap_info, status=200)