sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +3 -2
- sglang/srt/disaggregation/utils.py +12 -11
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/openai/protocol.py +47 -4
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/attention/flashattention_backend.py +24 -14
- sglang/srt/layers/layernorm.py +15 -0
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +12 -3
- sglang/srt/layers/moe/ep_moe/layer.py +79 -12
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
- sglang/srt/layers/moe/topk.py +26 -0
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/rotary_embedding.py +103 -11
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +10 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +9 -1
- sglang/srt/managers/scheduler.py +42 -6
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -2
- sglang/srt/model_loader/loader.py +45 -10
- sglang/srt/model_loader/weight_utils.py +89 -0
- sglang/srt/models/deepseek_nextn.py +7 -4
- sglang/srt/models/deepseek_v2.py +147 -4
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/server_args.py +16 -2
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +71 -0
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/conversation.py
CHANGED
sglang/srt/custom_op.py
CHANGED
@@ -1,11 +1,12 @@
 from torch import nn
 
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_npu = is_npu()
 
 
 class CustomOp(nn.Module):
@@ -60,6 +61,9 @@ class CustomOp(nn.Module):
     def forward_cuda(self, *args, **kwargs):
         raise NotImplementedError
 
+    def forward_npu(self, *args, **kwargs):
+        raise NotImplementedError
+
     def forward_hip(self, *args, **kwargs):
         return self.forward_cuda(*args, **kwargs)
 
@@ -79,5 +83,7 @@ class CustomOp(nn.Module):
             return self.forward_hip
         elif _is_cpu and _is_cpu_amx_available:
             return self.forward_cpu
+        elif _is_npu:
+            return self.forward_npu
         else:
             return self.forward_native
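To illustrate the dispatch path these hunks extend, here is a minimal, hypothetical CustomOp subclass (not part of the release). It assumes, as the last hunk suggests, that CustomOp binds the platform-specific forward_* method through dispatch_forward() and falls back to forward_native when no platform flag matches.

```python
# Hypothetical example, not sglang code: a CustomOp subclass that now gains an
# NPU path. On an NPU build dispatch_forward() (per the hunk above) returns
# forward_npu; elsewhere it falls back to forward_native.
import torch
from sglang.srt.custom_op import CustomOp


class ToyScale(CustomOp):
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Plain PyTorch fallback used when no platform-specific path applies.
        return x * self.factor

    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
        # A real operator would call a vendor NPU kernel here; the sketch just
        # reuses the native math so it stays runnable on any device.
        return self.forward_native(x)
```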
sglang/srt/disaggregation/decode.py
CHANGED
@@ -579,11 +579,11 @@ class DecodeTransferQueue:
             idx = decode_req.metadata_buffer_index
             (
                 output_id,
-                output_hidden_states,
                 output_token_logprobs_val,
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
                 output_top_logprobs_idx,
+                output_hidden_states,
             ) = self.metadata_buffers.get_buf(idx)
 
             decode_req.req.output_ids.append(output_id[0].item())
sglang/srt/disaggregation/mooncake/conn.py
CHANGED
@@ -103,6 +103,9 @@ class KVArgsRegisterInfo:
     mooncake_session_id: str
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
+    dst_tp_rank: int
+    dst_tp_size: int
+    dst_kv_item_len: int
 
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -113,6 +116,9 @@ class KVArgsRegisterInfo:
             mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
+            dst_tp_rank=int(msg[6].decode("ascii")),
+            dst_tp_size=int(msg[7].decode("ascii")),
+            dst_kv_item_len=int(msg[8].decode("ascii")),
         )
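The three new fields travel as extra ASCII ZMQ frames appended to the existing registration message (the decode-side packing appears in the MooncakeKVReceiver hunks further down). A standalone sketch of that round trip, with the frame layout simplified to just these fields:

```python
# Simplified, standalone sketch of the registration round trip. The real
# message carries more frames (room, endpoint, session id, ...); only the
# pointer arrays and the three new ASCII fields are modeled here.
import struct
from dataclasses import dataclass
from typing import List


@dataclass
class RegisterInfo:
    dst_kv_ptrs: List[int]
    dst_aux_ptrs: List[int]
    dst_tp_rank: int
    dst_tp_size: int
    dst_kv_item_len: int


def pack_frames(info: RegisterInfo) -> List[bytes]:
    # Pointer arrays as packed native-endian 64-bit integers, scalars as ASCII.
    return [
        b"".join(struct.pack("Q", p) for p in info.dst_kv_ptrs),
        b"".join(struct.pack("Q", p) for p in info.dst_aux_ptrs),
        str(info.dst_tp_rank).encode("ascii"),
        str(info.dst_tp_size).encode("ascii"),
        str(info.dst_kv_item_len).encode("ascii"),
    ]


def unpack_frames(frames: List[bytes]) -> RegisterInfo:
    return RegisterInfo(
        dst_kv_ptrs=list(struct.unpack(f"{len(frames[0]) // 8}Q", frames[0])),
        dst_aux_ptrs=list(struct.unpack(f"{len(frames[1]) // 8}Q", frames[1])),
        dst_tp_rank=int(frames[2].decode("ascii")),
        dst_tp_size=int(frames[3].decode("ascii")),
        dst_kv_item_len=int(frames[4].decode("ascii")),
    )


assert unpack_frames(pack_frames(RegisterInfo([1, 2], [3], 0, 2, 4096))).dst_tp_size == 2
```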
@@ -181,7 +187,7 @@ class MooncakeKVManager(BaseKVManager):
             ).start()
 
             self.bootstrap_time_out = get_int_env_var(
-                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT",
+                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
@@ -189,6 +195,8 @@ class MooncakeKVManager(BaseKVManager):
             self.session_pool_lock = threading.Lock()
             self.addr_to_rooms_tracker = defaultdict(set)
             self.connection_lock = threading.Lock()
+            self.required_prefill_response_num_table: Dict[int, int] = {}
+            self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
             # Heartbeat interval should be at least 2 seconds
             self.heartbeat_interval = max(
                 float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
@@ -251,17 +259,19 @@ class MooncakeKVManager(BaseKVManager):
 
         # Worker function for processing a single layer
         def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            src_addr_list = []
+            dst_addr_list = []
+            length_list = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
-
-
-
-
-
-
-            return 0
+                src_addr_list.append(src_addr)
+                dst_addr_list.append(dst_addr)
+                length_list.append(length)
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, src_addr_list, dst_addr_list, length_list
+            )
 
         futures = [
             executor.submit(
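The rewritten worker no longer issues one transfer per contiguous block: it accumulates one (src_addr, dst_addr, length) descriptor per block and hands the whole layer to a single batch_transfer_sync call. A hedged sketch of that descriptor-building step (the helper name and the pre-paired runs are illustrative, not from the diff):

```python
# Illustrative helper, not sglang code: turn pre-paired runs of contiguous KV
# pages into the three parallel lists consumed by one batched transfer call.
from typing import List, Sequence, Tuple


def build_layer_descriptors(
    src_ptr: int,
    dst_ptr: int,
    item_len: int,
    prefill_runs: Sequence[Sequence[int]],
    decode_runs: Sequence[Sequence[int]],
) -> Tuple[List[int], List[int], List[int]]:
    src_addrs, dst_addrs, lengths = [], [], []
    for prefill_run, decode_run in zip(prefill_runs, decode_runs):
        # One descriptor per contiguous run: address of the first page in the
        # run plus the run's total length in bytes.
        src_addrs.append(src_ptr + int(prefill_run[0]) * item_len)
        dst_addrs.append(dst_ptr + int(decode_run[0]) * item_len)
        lengths.append(item_len * len(prefill_run))
    return src_addrs, dst_addrs, lengths


# Pages [4, 5, 6] on prefill map to [10, 11, 12] on decode with 1 KiB pages:
# one 3 KiB write instead of three separate 1 KiB writes.
print(build_layer_descriptors(0x1000, 0x8000, 1024, [[4, 5, 6]], [[10, 11, 12]]))
```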
@@ -282,6 +292,162 @@ class MooncakeKVManager(BaseKVManager):
 
         return 0
 
+    def send_kvcache_slice(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int64],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int64],
+        dst_tp_rank: int,
+        dst_tp_size: int,
+        dst_kv_item_len: int,
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        """
+        Sends KV cache slices from this Prefill rank to a target Decode rank,
+        supporting generic M-to-N TP size configurations.
+
+        NOTE: This implementation calls the transfer engine for each token slot within
+        each page to ensure correctness for any page_size and head-slicing configuration.
+        This may introduce performance overhead (increased TTFT) for long sequences.
+        """
+        # Extract configuration
+        local_tp_rank = self.kv_args.engine_rank
+        local_tp_size = self.tp_size // self.dp_size
+        num_kv_heads = self.kv_args.kv_head_num
+        num_layers = len(self.kv_args.kv_data_ptrs)
+        page_size = self.kv_args.page_size
+
+        # Calculate head distribution
+        heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
+        heads_per_prefill_rank = num_kv_heads
+        decode_global_head_start = dst_tp_rank * heads_per_decode_rank
+        prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
+        bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
+
+        decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
+
+        # Determine slicing parameters based on TP configuration
+        if local_tp_size > dst_tp_size:
+            src_head_offset = 0
+            num_heads_to_send = heads_per_prefill_rank
+            dst_head_offset = prefill_global_head_start - decode_global_head_start
+        else:
+            src_head_offset = decode_global_head_start - prefill_global_head_start
+            num_heads_to_send = heads_per_decode_rank
+            dst_head_offset = 0
+
+        layer_transfer_params = []
+        for layer_id in range(num_layers):
+            item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
+
+            # Page stride on the target dst decode rank for its slice pages
+            item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
+
+            if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
+                logger.error(
+                    f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
+                )
+                return -1
+
+            # Calculate precise byte offset and length for the sub-slice within the prefill page data
+            src_slice_offset = src_head_offset * bytes_per_head
+            dst_slice_offset = dst_head_offset * bytes_per_head
+            slice_lens_per_page = num_heads_to_send * bytes_per_head
+
+            # Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
+            # This means slice_lens_per_page <= item_len_of_decode_rank_page
+            if slice_lens_per_page > item_len_of_decode_rank_page:
+                logger.error(
+                    f"[{mooncake_session_id}] Layer {layer_id}: "
+                    f"slice size ({slice_lens_per_page}) exceeds "
+                    f"target page size ({item_len_of_decode_rank_page})"
+                )
+                return -1
+            layer_transfer_params.append(
+                (
+                    self.kv_args.kv_data_ptrs[layer_id],
+                    dst_kv_ptrs[layer_id],
+                    item_len_of_prefill_rank_page,
+                    item_len_of_decode_rank_page,
+                    src_slice_offset,
+                    dst_slice_offset,
+                    slice_lens_per_page,
+                )
+            )
+
+        def process_layer_tp_aware(layer_params):
+            (
+                src_ptr,
+                dst_ptr,
+                src_item_len,
+                dst_item_len,
+                src_offset,
+                dst_offset,
+                slice_lens_per_page,
+            ) = layer_params
+            src_addr_list = []
+            dst_addr_list = []
+            length_list = []
+
+            # Calculate strides for a single token slot
+            bytes_per_token_on_prefill = src_item_len // page_size
+            bytes_per_token_on_decode = dst_item_len // page_size
+
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting addresses for the current src and dst pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final src and dst addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_offset
+
+                    src_addr_list.append(src_slice_addr)
+                    dst_addr_list.append(dst_slice_addr)
+                    length_list.append(slice_lens_per_page)
+
+                    logger.debug(
+                        f"SYNC: sid={mooncake_session_id}, "
+                        f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
+                    )
+
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, src_addr_list, dst_addr_list, length_list
+            )
+
+        futures = [
+            executor.submit(
+                process_layer_tp_aware,
+                layer_params,
+            )
+            for layer_params in layer_transfer_params
+        ]
+
+        for future in concurrent.futures.as_completed(futures):
+            status = future.result()
+            if status != 0:
+                for f in futures:
+                    f.cancel()
+                return status
+
+        return 0
+
     def send_aux(
         self,
         mooncake_session_id: str,
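To make the head-distribution arithmetic above concrete, here is a worked toy example (the numbers are illustrative, not from the diff): a prefill deployment with TP 2 feeding a decode deployment with TP 4, with 8 KV heads stored per prefill rank. The helper mirrors the branch logic of send_kvcache_slice.

```python
# Worked example of the head-slicing math: byte offsets into the source and
# destination pages plus the per-page slice length, mirroring the branches in
# send_kvcache_slice above.
def slicing_params(local_tp_rank, local_tp_size, dst_tp_rank, dst_tp_size,
                   num_kv_heads, dst_kv_item_len, page_size):
    heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
    decode_global_head_start = dst_tp_rank * heads_per_decode_rank
    prefill_global_head_start = local_tp_rank * num_kv_heads
    bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size

    if local_tp_size > dst_tp_size:
        # Prefill holds fewer heads per rank than decode: send everything we
        # have, offset into the decode page.
        src_head_offset, num_heads_to_send = 0, num_kv_heads
        dst_head_offset = prefill_global_head_start - decode_global_head_start
    else:
        # Prefill holds more heads per rank: send only the decode rank's share,
        # starting at its offset inside our page.
        src_head_offset = decode_global_head_start - prefill_global_head_start
        num_heads_to_send, dst_head_offset = heads_per_decode_rank, 0

    return (src_head_offset * bytes_per_head,
            dst_head_offset * bytes_per_head,
            num_heads_to_send * bytes_per_head)


# Prefill rank 1 (global heads 8..15) -> decode rank 3 (global heads 12..15),
# 128 bytes per head, page_size 1: skip 4 heads (512 bytes) on the source and
# write 4 heads (512 bytes) at offset 0 of the destination page.
print(slicing_params(local_tp_rank=1, local_tp_size=2, dst_tp_rank=3,
                     dst_tp_size=4, num_kv_heads=8, dst_kv_item_len=512,
                     page_size=1))  # -> (512, 0, 512)
```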
@@ -289,18 +455,24 @@ class MooncakeKVManager(BaseKVManager):
         dst_aux_ptrs: list[int],
         dst_aux_index: int,
     ):
-
-
-
-
-
-
-
+        src_addr_list = []
+        dst_addr_list = []
+        length_list = []
+        prefill_aux_ptrs = self.kv_args.aux_data_ptrs
+        prefill_aux_item_lens = self.kv_args.aux_item_lens
+        for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
+            length = prefill_aux_item_lens[i]
+            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
+            dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
+            src_addr_list.append(src_addr)
+            dst_addr_list.append(dst_addr)
+            length_list.append(length)
+        return self.engine.batch_transfer_sync(
+            mooncake_session_id, src_addr_list, dst_addr_list, length_list
         )
-        return status
 
     def sync_status_to_decode_endpoint(
-        self, remote: str, dst_port: int, room: int, status: int
+        self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
     ):
         if ":" in remote:
             remote = remote.split(":")[0]
@@ -308,6 +480,7 @@ class MooncakeKVManager(BaseKVManager):
             [
                 str(room).encode("ascii"),
                 str(status).encode("ascii"),
+                str(prefill_rank).encode("ascii"),
             ]
         )
 
@@ -324,6 +497,7 @@ class MooncakeKVManager(BaseKVManager):
         )
         polls = []
         dst_ranks_infos = []
+        local_rank = self.kv_args.engine_rank
        for req in reqs_to_be_processed:
             if not req.is_dummy:
                 # Early exit if the request has failed
@@ -339,6 +513,7 @@ class MooncakeKVManager(BaseKVManager):
                         req.dst_port,
                         req.room,
                         KVPoll.Failed,
+                        local_rank,
                     )
                     break
 
@@ -356,15 +531,31 @@ class MooncakeKVManager(BaseKVManager):
                         f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
                     )
 
-                ret = self.send_kvcache(
-                    req.mooncake_session_id,
-                    kv_chunk.prefill_kv_indices,
-                    self.decode_kv_args_table[
-                        req.mooncake_session_id
-                    ].dst_kv_ptrs,
-                    chunked_dst_kv_indice,
-                    executor,
+                target_rank_registration_info: KVArgsRegisterInfo = (
+                    self.decode_kv_args_table[req.mooncake_session_id]
                 )
+                local_tp_size = self.tp_size // self.dp_size
+                if self.is_mla_backend or (
+                    local_tp_size == target_rank_registration_info.dst_tp_size
+                ):
+                    ret = self.send_kvcache(
+                        req.mooncake_session_id,
+                        kv_chunk.prefill_kv_indices,
+                        target_rank_registration_info.dst_kv_ptrs,
+                        chunked_dst_kv_indice,
+                        executor,
+                    )
+                else:
+                    ret = self.send_kvcache_slice(
+                        req.mooncake_session_id,
+                        kv_chunk.prefill_kv_indices,
+                        target_rank_registration_info.dst_kv_ptrs,
+                        chunked_dst_kv_indice,
+                        target_rank_registration_info.dst_tp_rank,
+                        target_rank_registration_info.dst_tp_size,
+                        target_rank_registration_info.dst_kv_item_len,
+                        executor,
+                    )
                 if ret != 0:
                     with self.session_lock:
                         self.session_failures[req.mooncake_session_id] += 1
@@ -380,7 +571,11 @@ class MooncakeKVManager(BaseKVManager):
                     )
                     self.update_status(kv_chunk.room, KVPoll.Failed)
                     self.sync_status_to_decode_endpoint(
-                        req.endpoint,
+                        req.endpoint,
+                        req.dst_port,
+                        req.room,
+                        KVPoll.Failed,
+                        local_rank,
                     )
                     break
 
@@ -389,9 +584,7 @@ class MooncakeKVManager(BaseKVManager):
                    ret = self.send_aux(
                         req.mooncake_session_id,
                         kv_chunk.prefill_aux_index,
-                        self.decode_kv_args_table[
-                            req.mooncake_session_id
-                        ].dst_aux_ptrs,
+                        target_rank_registration_info.dst_aux_ptrs,
                         req.dst_aux_index,
                     )
                     polls.append(True if ret == 0 else False)
@@ -405,7 +598,7 @@ class MooncakeKVManager(BaseKVManager):
                 self.update_status(req.room, status)
                 for endpoint, dst_port, room in dst_ranks_infos:
                     self.sync_status_to_decode_endpoint(
-                        endpoint, dst_port, room, status
+                        endpoint, dst_port, room, status, local_rank
                     )
             else:
                 # Dummy request means the decode instance is not used, so its status can be marked as success directly
@@ -471,15 +664,33 @@ class MooncakeKVManager(BaseKVManager):
 
         def decode_thread():
             while True:
-                (bootstrap_room, status) =
+                (bootstrap_room, status, prefill_rank) = (
+                    self.server_socket.recv_multipart()
+                )
                 status = int(status.decode("ascii"))
                 bootstrap_room = int(bootstrap_room.decode("ascii"))
-
+                prefill_rank = int(prefill_rank.decode("ascii"))
+
+                if status == KVPoll.Success:
+                    if bootstrap_room in self.request_status:
+                        self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
+                        expected_response_num = (
+                            self.required_prefill_response_num_table[bootstrap_room]
+                        )
+                        arrived_response_num = len(
+                            self.prefill_response_tracker[bootstrap_room]
+                        )
+                        if (
+                            self.is_mla_backend
+                            or arrived_response_num == expected_response_num
+                        ):
+                            self.update_status(bootstrap_room, KVPoll.Success)
+                elif status == KVPoll.Failed:
                     self.record_failure(
                         bootstrap_room,
                         f"Failed to get kvcache from prefill instance, it might be dead",
                     )
-
+                    self.update_status(bootstrap_room, status)
 
         def heartbeat_checker():
             while True:
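The decode-side status thread now marks a room successful only after it has seen KVPoll.Success from every prefill rank it expects (one for MLA models, several when prefill TP exceeds decode TP). A self-contained sketch of that aggregation pattern, with names simplified relative to the diff:

```python
# Sketch of the per-room aggregation performed by the new decode_thread: each
# prefill rank reports success once, and the room completes only when the set
# of distinct reporting ranks reaches the expected count recorded at receiver
# creation time.
from collections import defaultdict


class ResponseAggregator:
    def __init__(self):
        self.expected = {}               # room -> required number of prefill responses
        self.arrived = defaultdict(set)  # room -> set of prefill ranks seen so far

    def register(self, room: int, required: int) -> None:
        self.expected[room] = required

    def on_success(self, room: int, prefill_rank: int) -> bool:
        """Record one success message; return True once the room is complete."""
        self.arrived[room].add(prefill_rank)
        return len(self.arrived[room]) == self.expected[room]


agg = ResponseAggregator()
agg.register(room=7, required=2)                     # e.g. prefill TP 4 -> decode TP 2
assert agg.on_success(7, prefill_rank=0) is False
assert agg.on_success(7, prefill_rank=0) is False    # duplicate rank doesn't count twice
assert agg.on_success(7, prefill_rank=1) is True
```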
@@ -686,14 +897,13 @@ class MooncakeKVSender(BaseKVSender):
         self.aux_index = None
         self.bootstrap_server_url = bootstrap_addr
         self.conclude_state = None
-        self.init_time =
+        self.init_time = time.time()
         # inner state
         self.curr_idx = 0
 
     def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
         self.num_kv_indices = num_kv_indices
         self.aux_index = aux_index
-        self.init_time = time.time()
 
     def send(
         self,
@@ -705,7 +915,10 @@ class MooncakeKVSender(BaseKVSender):
 
         if not is_last:
             self.kv_mgr.add_transfer_request(
-                self.bootstrap_room,
+                self.bootstrap_room,
+                kv_indices,
+                index_slice,
+                False,
             )
         else:
             self.kv_mgr.add_transfer_request(
@@ -814,23 +1027,26 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             )
             self.required_dst_info_num = 1
+            self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
         elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
-
-
-
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             self.target_tp_rank = (
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
             self.required_dst_info_num = (
                 local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
             )
+            self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
         else:
-
-
-
-
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
             self.target_tp_ranks = [
                 rank
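The new required_prefill_response_num bookkeeping boils down to how many prefill success responses a decode rank must collect before its KV cache for a request is complete. A toy illustration of that accounting (not library code, TP sizes are per DP group as in the diff):

```python
# When the prefill deployment runs more TP ranks per DP group than the decode
# deployment, each decode rank owns heads spread across several prefill ranks
# and must hear from all of them; otherwise a single response is enough.
def required_prefill_responses(prefill_tp_per_dp: int, decode_tp_per_dp: int) -> int:
    if prefill_tp_per_dp <= decode_tp_per_dp:
        # Equal TP sizes, or decode TP larger: one prefill rank covers every
        # head the decode rank needs, so one response suffices.
        return 1
    # Prefill TP larger: gather one slice (and one response) per contributing rank.
    return prefill_tp_per_dp // decode_tp_per_dp


assert required_prefill_responses(prefill_tp_per_dp=2, decode_tp_per_dp=2) == 1
assert required_prefill_responses(prefill_tp_per_dp=2, decode_tp_per_dp=4) == 1
assert required_prefill_responses(prefill_tp_per_dp=4, decode_tp_per_dp=2) == 2
```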
@@ -847,6 +1063,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
             # or the KVPoll will never be set correctly
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
+            self.required_prefill_response_num = (
+                prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
+            )
 
         if self.data_parallel_rank is not None:
             logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
@@ -854,6 +1073,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
         else:
             self.target_dp_group = bootstrap_room % self.prefill_dp_size
 
+        self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
+            self.required_prefill_response_num
+        )
         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
             f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
@@ -867,11 +1089,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 self.target_dp_group,
             )
             if bootstrap_info is not None:
-
-
-
-
-
+                if self.kv_mgr.is_mla_backend:
+                    # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
+                    bootstrap_info["is_dummy"] = not bool(
+                        target_tp_rank == self.target_tp_rank
+                        or self.target_tp_rank is None
+                    )
+                else:
+                    # For non-MLA: all target_tp_ranks are selected real ranks
+                    bootstrap_info["is_dummy"] = False
                 logger.debug(
                     f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
                 )
@@ -943,6 +1169,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
             packed_aux_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )
+            tp_rank = self.kv_mgr.kv_args.engine_rank
+            tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
+            kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
+            dst_tp_rank = str(tp_rank).encode("ascii")
+            dst_tp_size = str(tp_size).encode("ascii")
+            dst_kv_item_len = str(kv_item_len).encode("ascii")
 
             sock, lock = self._connect("tcp://" + self.prefill_server_url)
             with lock:
@@ -954,6 +1186,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
                         self.session_id.encode("ascii"),
                         packed_kv_data_ptrs,
                         packed_aux_data_ptrs,
+                        dst_tp_rank,
+                        dst_tp_size,
+                        dst_kv_item_len,
                     ]
                 )
 
@@ -1002,6 +1237,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
         if self.bootstrap_room in self.kv_mgr.request_status:
             self.kv_mgr.request_status.pop(self.bootstrap_room)
 
+        if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:
+            self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)
+
+        if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:
+            self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)
+
     def failure_exception(self):
         # Explicitly set the status to failure since this request has failed in another rank
         if self.conclude_state is None:
sglang/srt/disaggregation/mooncake/transfer_engine.py
CHANGED
@@ -1,7 +1,7 @@
 import json
 import logging
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional
 
 logger = logging.getLogger(__name__)
 
@@ -90,5 +90,35 @@ class MooncakeTransferEngine:
 
         return ret
 
+    def batch_transfer_sync(
+        self,
+        session_id: str,
+        buffers: List[int],
+        peer_buffer_addresses: List[int],
+        lengths: List[int],
+    ) -> int:
+        """Synchronously transfer data to the specified addresses in batches."""
+        try:
+            ret = self.engine.batch_transfer_sync_write(
+                session_id, buffers, peer_buffer_addresses, lengths
+            )
+        except Exception:
+            ret = -1
+            # Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
+            if not hasattr(self.engine, "batch_transfer_sync_write"):
+                raise RuntimeError(
+                    "Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
+                    "Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
+                )
+
+        if ret < 0:
+            logger.debug(
+                "Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s",
+                buffers,
+                session_id,
+                peer_buffer_addresses,
+            )
+        return ret
+
     def get_session_id(self):
         return self.session_id