sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -51,6 +51,7 @@ def group_concurrent_contiguous(
|
|
51
51
|
return src_groups, dst_groups
|
52
52
|
|
53
53
|
|
54
|
+
# prefill
|
54
55
|
@dataclasses.dataclass
|
55
56
|
class TransferKVChunk:
|
56
57
|
room: int
|
@@ -60,6 +61,7 @@ class TransferKVChunk:
|
|
60
61
|
prefill_aux_index: Optional[int]
|
61
62
|
|
62
63
|
|
64
|
+
# decode
|
63
65
|
@dataclasses.dataclass
|
64
66
|
class TransferInfo:
|
65
67
|
room: int
|
@@ -68,19 +70,32 @@ class TransferInfo:
|
|
68
70
|
mooncake_session_id: str
|
69
71
|
dst_kv_indices: npt.NDArray[np.int64]
|
70
72
|
dst_aux_index: int
|
73
|
+
required_dst_info_num: int
|
74
|
+
is_dummy: bool
|
71
75
|
|
72
76
|
@classmethod
|
73
77
|
def from_zmq(cls, msg: List[bytes]):
|
78
|
+
if msg[4] == b"" and msg[5] == b"":
|
79
|
+
is_dummy = True
|
80
|
+
dst_kv_indices = np.array([], dtype=np.int64)
|
81
|
+
dst_aux_index = None
|
82
|
+
else:
|
83
|
+
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
|
84
|
+
dst_aux_index = int(msg[5].decode("ascii"))
|
85
|
+
is_dummy = False
|
74
86
|
return cls(
|
75
87
|
room=int(msg[0].decode("ascii")),
|
76
88
|
endpoint=msg[1].decode("ascii"),
|
77
89
|
dst_port=int(msg[2].decode("ascii")),
|
78
90
|
mooncake_session_id=msg[3].decode("ascii"),
|
79
|
-
dst_kv_indices=
|
80
|
-
dst_aux_index=
|
91
|
+
dst_kv_indices=dst_kv_indices,
|
92
|
+
dst_aux_index=dst_aux_index,
|
93
|
+
required_dst_info_num=int(msg[6].decode("ascii")),
|
94
|
+
is_dummy=is_dummy,
|
81
95
|
)
|
82
96
|
|
83
97
|
|
98
|
+
# decode
|
84
99
|
@dataclasses.dataclass
|
85
100
|
class KVArgsRegisterInfo:
|
86
101
|
room: str
|
@@ -108,6 +123,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
108
123
|
args: KVArgs,
|
109
124
|
disaggregation_mode: DisaggregationMode,
|
110
125
|
server_args: ServerArgs,
|
126
|
+
is_mla_backend: Optional[bool] = False,
|
111
127
|
):
|
112
128
|
self.kv_args = args
|
113
129
|
self.engine = MooncakeTransferEngine(
|
@@ -115,6 +131,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
115
131
|
gpu_id=self.kv_args.gpu_id,
|
116
132
|
ib_device=self.kv_args.ib_device,
|
117
133
|
)
|
134
|
+
self.is_mla_backend = is_mla_backend
|
118
135
|
self.disaggregation_mode = disaggregation_mode
|
119
136
|
# for p/d multi node infer
|
120
137
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
@@ -132,7 +149,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
132
149
|
self.register_buffer_to_engine()
|
133
150
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
134
151
|
self.transfer_queue = queue.Queue()
|
135
|
-
self.transfer_infos: Dict[int, TransferInfo] = {}
|
152
|
+
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
136
153
|
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
137
154
|
self.start_prefill_thread()
|
138
155
|
self._register_to_bootstrap()
|
@@ -145,6 +162,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
145
162
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
146
163
|
self.start_decode_thread()
|
147
164
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
165
|
+
self.prefill_tp_size_table: Dict[str, int] = {}
|
148
166
|
self.prefill_dp_size_table: Dict[str, int] = {}
|
149
167
|
else:
|
150
168
|
raise ValueError(
|
@@ -218,7 +236,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
218
236
|
status = future.result()
|
219
237
|
if status != 0:
|
220
238
|
# Immediate shutdown on first error (existing tasks will finish)
|
221
|
-
executor.shutdown(wait=False)
|
239
|
+
self.executor.shutdown(wait=False)
|
222
240
|
for f in futures:
|
223
241
|
f.cancel()
|
224
242
|
return status
|
@@ -250,7 +268,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
250
268
|
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
|
251
269
|
[
|
252
270
|
str(room).encode("ascii"),
|
253
|
-
str(self.
|
271
|
+
str(self.check_status(room)).encode("ascii"),
|
254
272
|
]
|
255
273
|
)
|
256
274
|
|
@@ -264,8 +282,8 @@ class MooncakeKVManager(BaseKVManager):
|
|
264
282
|
while True:
|
265
283
|
waiting_req_bytes = self.server_socket.recv_multipart()
|
266
284
|
room = waiting_req_bytes[0].decode("ascii")
|
285
|
+
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
|
267
286
|
if room == "None":
|
268
|
-
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
|
269
287
|
self.decode_kv_args_table[mooncake_session_id] = (
|
270
288
|
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
|
271
289
|
)
|
@@ -273,53 +291,84 @@ class MooncakeKVManager(BaseKVManager):
|
|
273
291
|
f"Register KVArgs from {mooncake_session_id} successfully"
|
274
292
|
)
|
275
293
|
continue
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
294
|
+
else:
|
295
|
+
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
|
296
|
+
room = int(room)
|
297
|
+
if room not in self.transfer_infos:
|
298
|
+
self.transfer_infos[room] = {}
|
299
|
+
|
300
|
+
self.transfer_infos[room][mooncake_session_id] = (
|
301
|
+
TransferInfo.from_zmq(waiting_req_bytes)
|
302
|
+
)
|
303
|
+
# NOTE: after bootstrapping we can mark the req as waiting for input
|
304
|
+
if len(self.transfer_infos[room]) == required_dst_info_num:
|
305
|
+
self.update_status(room, KVPoll.WaitingForInput)
|
281
306
|
|
282
307
|
def transfer_thread():
|
283
308
|
# TODO: Shall we use KVPoll.Transferring state?
|
284
309
|
while True:
|
285
310
|
try:
|
286
311
|
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
312
|
+
reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
|
313
|
+
polls = []
|
314
|
+
dst_ranks_infos = []
|
315
|
+
for req in reqs_to_be_processed:
|
316
|
+
if not req.is_dummy:
|
317
|
+
chunked_dst_kv_indice = req.dst_kv_indices[
|
318
|
+
kv_chunk.index_slice
|
319
|
+
]
|
320
|
+
assert len(chunked_dst_kv_indice) == len(
|
321
|
+
kv_chunk.prefill_kv_indices
|
322
|
+
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
|
323
|
+
|
324
|
+
ret = self.send_kvcache(
|
325
|
+
req.mooncake_session_id,
|
326
|
+
kv_chunk.prefill_kv_indices,
|
327
|
+
self.decode_kv_args_table[
|
328
|
+
req.mooncake_session_id
|
329
|
+
].dst_kv_ptrs,
|
330
|
+
chunked_dst_kv_indice,
|
331
|
+
)
|
332
|
+
if ret != 0:
|
333
|
+
self.update_status(kv_chunk.room, KVPoll.Failed)
|
334
|
+
self.sync_status_to_decode_endpoint(
|
335
|
+
req.endpoint, req.dst_port, req.room
|
336
|
+
)
|
337
|
+
continue
|
338
|
+
|
339
|
+
if kv_chunk.is_last:
|
340
|
+
# Only the last chunk we need to send the aux data
|
341
|
+
ret = self.send_aux(
|
342
|
+
req.mooncake_session_id,
|
343
|
+
kv_chunk.prefill_aux_index,
|
344
|
+
self.decode_kv_args_table[
|
345
|
+
req.mooncake_session_id
|
346
|
+
].dst_aux_ptrs,
|
347
|
+
req.dst_aux_index,
|
348
|
+
)
|
349
|
+
polls.append(True if ret == 0 else False)
|
350
|
+
dst_ranks_infos.append(
|
351
|
+
(req.endpoint, req.dst_port, req.room)
|
352
|
+
)
|
353
|
+
|
354
|
+
# Only sync status when all the dst ranks have received the kvcache
|
355
|
+
if len(polls) == req.required_dst_info_num:
|
356
|
+
self.update_status(
|
357
|
+
req.room,
|
358
|
+
KVPoll.Success if all(polls) else KVPoll.Failed,
|
359
|
+
)
|
360
|
+
for endpoint, dst_port, room in dst_ranks_infos:
|
361
|
+
self.sync_status_to_decode_endpoint(
|
362
|
+
endpoint, dst_port, room
|
363
|
+
)
|
364
|
+
else:
|
365
|
+
# Dummy request means the decode instance is not used, so its status can be marked as success directly
|
366
|
+
# Dummy request does not need to sync status to decode endpoint
|
367
|
+
if kv_chunk.is_last:
|
368
|
+
self.update_status(req.room, KVPoll.Success)
|
369
|
+
|
370
|
+
if self.check_status(kv_chunk.room) == KVPoll.Success:
|
371
|
+
self.transfer_infos.pop(kv_chunk.room)
|
323
372
|
|
324
373
|
except queue.Empty:
|
325
374
|
continue
|
@@ -336,7 +385,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
336
385
|
(bootstrap_room, status) = self.server_socket.recv_multipart()
|
337
386
|
status = int(status.decode("ascii"))
|
338
387
|
bootstrap_room = int(bootstrap_room.decode("ascii"))
|
339
|
-
self.
|
388
|
+
self.update_status(bootstrap_room, status)
|
340
389
|
|
341
390
|
threading.Thread(target=decode_thread).start()
|
342
391
|
|
@@ -360,11 +409,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
360
409
|
prefill_aux_index=aux_index,
|
361
410
|
)
|
362
411
|
)
|
363
|
-
self.
|
412
|
+
self.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
364
413
|
|
365
414
|
def check_status(self, bootstrap_room: int):
|
366
|
-
# TOOD: do we really need the poll()?
|
367
|
-
|
368
415
|
return self.request_status[bootstrap_room]
|
369
416
|
|
370
417
|
def update_status(self, bootstrap_room: int, status: KVPoll):
|
@@ -420,6 +467,8 @@ class MooncakeKVSender(BaseKVSender):
|
|
420
467
|
self.aux_index = None
|
421
468
|
self.bootstrap_server_url = bootstrap_addr
|
422
469
|
self.session_id = self.kv_mgr.get_session_id()
|
470
|
+
# inner state
|
471
|
+
self.curr_idx = 0
|
423
472
|
|
424
473
|
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
425
474
|
self.num_kv_indices = num_kv_indices
|
@@ -428,9 +477,11 @@ class MooncakeKVSender(BaseKVSender):
|
|
428
477
|
def send(
|
429
478
|
self,
|
430
479
|
kv_indices: npt.NDArray[np.int64],
|
431
|
-
index_slice: slice,
|
432
|
-
is_last: bool,
|
433
480
|
):
|
481
|
+
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
482
|
+
self.curr_idx += len(kv_indices)
|
483
|
+
is_last = self.curr_idx == self.num_kv_indices
|
484
|
+
|
434
485
|
if not is_last:
|
435
486
|
self.kv_mgr.add_transfer_request(
|
436
487
|
self.bootstrap_room, kv_indices, index_slice, False
|
@@ -448,6 +499,7 @@ class MooncakeKVSender(BaseKVSender):
|
|
448
499
|
return self.kv_mgr.check_status(self.bootstrap_room)
|
449
500
|
|
450
501
|
def failure_exception(self):
|
502
|
+
# TODO: raise a real exception
|
451
503
|
raise Exception("Fake KVSender Exception")
|
452
504
|
|
453
505
|
|
@@ -469,54 +521,111 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
469
521
|
self.session_id = self.kv_mgr.get_session_id()
|
470
522
|
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
|
471
523
|
|
472
|
-
if not self.kv_mgr.
|
473
|
-
|
474
|
-
# both prefill role and decode role. If the decode instance does
|
475
|
-
# not enable dp_attention, then dp_attention is not enabled on the
|
476
|
-
# prefill instance as well. Therefore, we should skip questioning
|
477
|
-
# the prefill dp size to reduce bootstrap overhead.
|
478
|
-
self.prefill_dp_size = 1
|
479
|
-
elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
480
|
-
self.prefill_dp_size, tp_size_per_dp_rank = (
|
524
|
+
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
525
|
+
self.prefill_tp_size, self.prefill_dp_size = (
|
481
526
|
self._get_prefill_dp_size_from_server()
|
482
527
|
)
|
483
|
-
|
484
|
-
# have different TP sizes per DP rank.
|
485
|
-
assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
|
486
|
-
if self.prefill_dp_size is None:
|
528
|
+
if self.prefill_tp_size is None or self.prefill_dp_size is None:
|
487
529
|
logger.error(
|
488
|
-
f"Could not fetch prefill
|
530
|
+
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
|
489
531
|
)
|
490
532
|
else:
|
533
|
+
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
|
534
|
+
self.prefill_tp_size
|
535
|
+
)
|
491
536
|
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
492
537
|
self.prefill_dp_size
|
493
538
|
)
|
494
539
|
else:
|
540
|
+
self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
|
541
|
+
self.bootstrap_addr
|
542
|
+
]
|
495
543
|
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
496
544
|
self.bootstrap_addr
|
497
545
|
]
|
498
546
|
|
499
|
-
#
|
547
|
+
# Currently, we don't allow prefill instance and decode instance to
|
548
|
+
# have different TP sizes per DP rank, except for models using MLA.
|
549
|
+
local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
|
550
|
+
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
|
551
|
+
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
|
552
|
+
self.target_tp_rank = (
|
553
|
+
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
554
|
+
)
|
555
|
+
self.required_dst_info_num = 1
|
556
|
+
self.target_tp_ranks = [self.target_tp_rank]
|
557
|
+
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
|
558
|
+
assert (
|
559
|
+
self.kv_mgr.is_mla_backend
|
560
|
+
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
561
|
+
self.target_tp_rank = (
|
562
|
+
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
563
|
+
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
|
564
|
+
self.required_dst_info_num = (
|
565
|
+
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
|
566
|
+
)
|
567
|
+
self.target_tp_ranks = [self.target_tp_rank]
|
568
|
+
else:
|
569
|
+
assert (
|
570
|
+
self.kv_mgr.is_mla_backend
|
571
|
+
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
572
|
+
|
573
|
+
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
|
574
|
+
self.target_tp_ranks = [
|
575
|
+
rank
|
576
|
+
for rank in range(
|
577
|
+
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
|
578
|
+
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
|
579
|
+
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
|
580
|
+
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
|
581
|
+
)
|
582
|
+
]
|
583
|
+
|
584
|
+
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
|
585
|
+
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
|
586
|
+
# or the KVPoll will never be set correctly
|
587
|
+
self.target_tp_rank = self.target_tp_ranks[0]
|
588
|
+
self.required_dst_info_num = 1
|
589
|
+
|
500
590
|
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
501
|
-
|
591
|
+
|
592
|
+
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
593
|
+
bootstrap_key = (
|
594
|
+
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
|
595
|
+
)
|
502
596
|
|
503
597
|
if bootstrap_key not in self.kv_mgr.connection_pool:
|
504
|
-
|
505
|
-
|
506
|
-
self.
|
507
|
-
|
508
|
-
|
598
|
+
bootstrap_infos = []
|
599
|
+
for target_tp_rank in self.target_tp_ranks:
|
600
|
+
bootstrap_info = self._get_bootstrap_info_from_server(
|
601
|
+
target_tp_rank,
|
602
|
+
self.target_dp_group,
|
603
|
+
)
|
604
|
+
if bootstrap_info is not None:
|
605
|
+
# NOTE: only support MLA for now: select one prefill rank as real rank
|
606
|
+
bootstrap_info["is_dummy"] = not bool(
|
607
|
+
target_tp_rank == self.target_tp_rank
|
608
|
+
or self.target_tp_rank is None
|
609
|
+
)
|
610
|
+
bootstrap_infos.append(bootstrap_info)
|
611
|
+
else:
|
612
|
+
logger.error(
|
613
|
+
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
|
614
|
+
)
|
615
|
+
self.bootstrap_infos = bootstrap_infos
|
616
|
+
|
617
|
+
if len(self.bootstrap_infos) == 0:
|
509
618
|
logger.error(
|
510
619
|
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
511
620
|
)
|
512
621
|
else:
|
513
|
-
self.kv_mgr.connection_pool[bootstrap_key] = self.
|
622
|
+
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
514
623
|
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
515
624
|
self._register_kv_args()
|
516
625
|
else:
|
517
|
-
self.
|
626
|
+
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
|
518
627
|
|
519
|
-
assert self.
|
628
|
+
assert len(self.bootstrap_infos) > 0
|
520
629
|
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
521
630
|
|
522
631
|
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
|
@@ -543,8 +652,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
543
652
|
response = requests.get(url)
|
544
653
|
if response.status_code == 200:
|
545
654
|
prefill_parallel_info = response.json()
|
546
|
-
return int(prefill_parallel_info["
|
547
|
-
prefill_parallel_info["
|
655
|
+
return int(prefill_parallel_info["prefill_tp_size"]), int(
|
656
|
+
prefill_parallel_info["prefill_dp_size"]
|
548
657
|
)
|
549
658
|
else:
|
550
659
|
logger.error(
|
@@ -556,29 +665,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
556
665
|
return None
|
557
666
|
|
558
667
|
def _register_kv_args(self):
|
559
|
-
self.
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
)
|
569
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
570
|
-
with lock:
|
571
|
-
sock.send_multipart(
|
572
|
-
[
|
573
|
-
"None".encode("ascii"),
|
574
|
-
get_local_ip_by_remote().encode("ascii"),
|
575
|
-
str(self.kv_mgr.rank_port).encode("ascii"),
|
576
|
-
self.session_id.encode("ascii"),
|
577
|
-
packed_kv_data_ptrs,
|
578
|
-
packed_aux_data_ptrs,
|
579
|
-
]
|
668
|
+
for bootstrap_info in self.bootstrap_infos:
|
669
|
+
self.prefill_server_url = (
|
670
|
+
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
671
|
+
)
|
672
|
+
packed_kv_data_ptrs = b"".join(
|
673
|
+
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
674
|
+
)
|
675
|
+
packed_aux_data_ptrs = b"".join(
|
676
|
+
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
580
677
|
)
|
581
678
|
|
679
|
+
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
680
|
+
with lock:
|
681
|
+
sock.send_multipart(
|
682
|
+
[
|
683
|
+
"None".encode("ascii"),
|
684
|
+
get_local_ip_by_remote().encode("ascii"),
|
685
|
+
str(self.kv_mgr.rank_port).encode("ascii"),
|
686
|
+
self.session_id.encode("ascii"),
|
687
|
+
packed_kv_data_ptrs,
|
688
|
+
packed_aux_data_ptrs,
|
689
|
+
]
|
690
|
+
)
|
691
|
+
|
582
692
|
@classmethod
|
583
693
|
def _connect(cls, endpoint: str):
|
584
694
|
with cls._global_lock:
|
@@ -590,30 +700,34 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
590
700
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
591
701
|
|
592
702
|
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
593
|
-
self.
|
594
|
-
|
595
|
-
|
596
|
-
logger.debug(
|
597
|
-
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
598
|
-
)
|
599
|
-
|
600
|
-
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
601
|
-
with lock:
|
602
|
-
sock.send_multipart(
|
603
|
-
[
|
604
|
-
str(self.bootstrap_room).encode("ascii"),
|
605
|
-
get_local_ip_by_remote().encode("ascii"),
|
606
|
-
str(self.kv_mgr.rank_port).encode("ascii"),
|
607
|
-
self.session_id.encode("ascii"),
|
608
|
-
kv_indices.tobytes(),
|
609
|
-
str(aux_index).encode("ascii"),
|
610
|
-
]
|
703
|
+
for bootstrap_info in self.bootstrap_infos:
|
704
|
+
self.prefill_server_url = (
|
705
|
+
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
611
706
|
)
|
707
|
+
logger.debug(
|
708
|
+
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
709
|
+
)
|
710
|
+
is_dummy = bootstrap_info["is_dummy"]
|
711
|
+
|
712
|
+
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
713
|
+
with lock:
|
714
|
+
sock.send_multipart(
|
715
|
+
[
|
716
|
+
str(self.bootstrap_room).encode("ascii"),
|
717
|
+
get_local_ip_by_remote().encode("ascii"),
|
718
|
+
str(self.kv_mgr.rank_port).encode("ascii"),
|
719
|
+
self.session_id.encode("ascii"),
|
720
|
+
kv_indices.tobytes() if not is_dummy else b"",
|
721
|
+
str(aux_index).encode("ascii") if not is_dummy else b"",
|
722
|
+
str(self.required_dst_info_num).encode("ascii"),
|
723
|
+
]
|
724
|
+
)
|
612
725
|
|
613
726
|
def poll(self) -> KVPoll:
|
614
727
|
return self.kv_mgr.check_status(self.bootstrap_room)
|
615
728
|
|
616
729
|
def failure_exception(self):
|
730
|
+
# TODO: raise a real exception
|
617
731
|
raise Exception("Fake KVReceiver Exception")
|
618
732
|
|
619
733
|
|
@@ -624,6 +738,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|
624
738
|
self.store = dict()
|
625
739
|
self.lock = asyncio.Lock()
|
626
740
|
self._setup_routes()
|
741
|
+
self.tp_size = None
|
627
742
|
self.dp_size = None
|
628
743
|
self.tp_size_per_dp_rank = None
|
629
744
|
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
|
@@ -658,6 +773,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|
658
773
|
rank_port = int(data["rank_port"])
|
659
774
|
engine_rank = int(data["engine_rank"])
|
660
775
|
|
776
|
+
if self.tp_size is None:
|
777
|
+
self.tp_size = tp_size
|
778
|
+
|
661
779
|
if self.dp_size is None:
|
662
780
|
self.dp_size = dp_size
|
663
781
|
|
@@ -693,17 +811,15 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|
693
811
|
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
694
812
|
if int(engine_rank) == -1 and int(target_dp_group) == -1:
|
695
813
|
prefill_parallel_info = {
|
814
|
+
"prefill_tp_size": self.tp_size,
|
696
815
|
"prefill_dp_size": self.dp_size,
|
697
|
-
"tp_size_per_dp_rank": self.tp_size_per_dp_rank,
|
698
816
|
}
|
699
817
|
return web.json_response(prefill_parallel_info, status=200)
|
700
818
|
|
701
819
|
# Find corresponding prefill info
|
702
|
-
tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
|
703
|
-
|
704
820
|
async with self.lock:
|
705
821
|
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
706
|
-
|
822
|
+
int(engine_rank)
|
707
823
|
]
|
708
824
|
|
709
825
|
if bootstrap_info is not None:
|
@@ -61,7 +61,8 @@ class MooncakeTransferEngine:
|
|
61
61
|
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
|
62
62
|
) -> int:
|
63
63
|
"""Synchronously transfer data to the specified address."""
|
64
|
-
|
64
|
+
# the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
|
65
|
+
# later: based on the cached queue pair to send data
|
65
66
|
ret = self.engine.transfer_sync_write(
|
66
67
|
session_id, buffer, peer_buffer_address, length
|
67
68
|
)
|
@@ -35,29 +35,19 @@ logger = logging.getLogger(__name__)
|
|
35
35
|
NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
|
36
36
|
|
37
37
|
|
38
|
-
# From Mooncake backend.
|
39
38
|
def group_concurrent_contiguous(
|
40
39
|
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
41
40
|
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
|
50
|
-
if src_contiguous and dst_contiguous:
|
51
|
-
current_src.append(src_indices[i])
|
52
|
-
current_dst.append(dst_indices[i])
|
53
|
-
else:
|
54
|
-
src_groups.append(current_src)
|
55
|
-
dst_groups.append(current_dst)
|
56
|
-
current_src = [src_indices[i]]
|
57
|
-
current_dst = [dst_indices[i]]
|
41
|
+
"""Vectorised NumPy implementation."""
|
42
|
+
if src_indices.size == 0:
|
43
|
+
return [], []
|
44
|
+
|
45
|
+
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
46
|
+
src_groups = np.split(src_indices, brk)
|
47
|
+
dst_groups = np.split(dst_indices, brk)
|
58
48
|
|
59
|
-
src_groups.
|
60
|
-
dst_groups.
|
49
|
+
src_groups = [g.tolist() for g in src_groups]
|
50
|
+
dst_groups = [g.tolist() for g in dst_groups]
|
61
51
|
|
62
52
|
return src_groups, dst_groups
|
63
53
|
|
@@ -132,6 +122,7 @@ class NixlKVManager(BaseKVManager):
|
|
132
122
|
args: KVArgs,
|
133
123
|
disaggregation_mode: DisaggregationMode,
|
134
124
|
server_args: ServerArgs,
|
125
|
+
is_mla_backend: Optional[bool] = False,
|
135
126
|
):
|
136
127
|
try:
|
137
128
|
from nixl._api import nixl_agent
|