sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py

@@ -35,12 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import (
-    get_free_port,
-    get_int_env_var,
-    get_ip,
-    get_local_ip_by_remote,
-)
+from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
 
 logger = logging.getLogger(__name__)
 
@@ -108,6 +103,9 @@ class KVArgsRegisterInfo:
     mooncake_session_id: str
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
+    dst_tp_rank: int
+    dst_tp_size: int
+    dst_kv_item_len: int
 
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -118,6 +116,9 @@ class KVArgsRegisterInfo:
             mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
+            dst_tp_rank=int(msg[6].decode("ascii")),
+            dst_tp_size=int(msg[7].decode("ascii")),
+            dst_kv_item_len=int(msg[8].decode("ascii")),
         )
 
 
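The decode rank's registration message now carries three extra trailing frames: its TP rank, its TP size per DP group, and its per-page KV item length, which the prefill side parses in KVArgsRegisterInfo.from_zmq above. An illustrative round-trip of that 9-frame wire format (not from the package; addresses and values are made up):

import struct

# Decode side: pack the registration frames (placeholder values, not real pointers).
kv_ptrs = [0x7F0000001000, 0x7F0000002000]
aux_ptrs = [0x7F0000003000]
msg = [
    b"None",                                   # room sentinel used for registration
    b"10.0.0.2",                               # rank_ip
    b"56001",                                  # rank_port
    b"session-abc",                            # mooncake_session_id
    b"".join(struct.pack("Q", p) for p in kv_ptrs),
    b"".join(struct.pack("Q", p) for p in aux_ptrs),
    b"1",                                      # dst_tp_rank    (new in 0.4.8)
    b"4",                                      # dst_tp_size    (new in 0.4.8)
    b"8192",                                   # dst_kv_item_len (new in 0.4.8)
]

# Prefill side: the same fields from_zmq() reads back out of the frames.
dst_kv_ptrs = list(struct.unpack(f"{len(msg[4])//8}Q", msg[4]))
dst_tp_rank = int(msg[6].decode("ascii"))
dst_tp_size = int(msg[7].decode("ascii"))
dst_kv_item_len = int(msg[8].decode("ascii"))
assert dst_kv_ptrs == kv_ptrs and (dst_tp_rank, dst_tp_size) == (1, 4)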
@@ -130,8 +131,9 @@ class MooncakeKVManager(BaseKVManager):
         is_mla_backend: Optional[bool] = False,
     ):
         self.kv_args = args
+        self.local_ip = get_local_ip_auto()
         self.engine = MooncakeTransferEngine(
-            hostname=get_local_ip_by_remote(),
+            hostname=self.local_ip,
             gpu_id=self.kv_args.gpu_id,
             ib_device=self.kv_args.ib_device,
         )
@@ -185,7 +187,7 @@ class MooncakeKVManager(BaseKVManager):
             ).start()
 
             self.bootstrap_time_out = get_int_env_var(
-                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT",
+                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
@@ -193,6 +195,8 @@
             self.session_pool_lock = threading.Lock()
             self.addr_to_rooms_tracker = defaultdict(set)
             self.connection_lock = threading.Lock()
+            self.required_prefill_response_num_table: Dict[int, int] = {}
+            self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
             # Heartbeat interval should be at least 2 seconds
             self.heartbeat_interval = max(
                 float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
@@ -255,17 +259,19 @@
 
         # Worker function for processing a single layer
         def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            src_addr_list = []
+            dst_addr_list = []
+            length_list = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
-
-                status = self.engine.transfer_sync(
-                    mooncake_session_id, src_addr, dst_addr, length
-                )
-                if status != 0:
-                    return status
-            return 0
+                src_addr_list.append(src_addr)
+                dst_addr_list.append(dst_addr)
+                length_list.append(length)
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, src_addr_list, dst_addr_list, length_list
+            )
 
         futures = [
             executor.submit(
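process_layer now accumulates one (src, dst, length) triple per contiguous block run and hands the whole batch to the engine in a single batch_transfer_sync call, instead of issuing one blocking transfer_sync per run. A self-contained sketch of the accumulation step (build_batch is a hypothetical helper, not package code):

from typing import List, Tuple

def build_batch(src_ptr: int, dst_ptr: int, item_len: int,
                blocks: List[Tuple[List[int], List[int]]]):
    """Turn (prefill_block, decode_block) index runs into one batched request."""
    src_addrs, dst_addrs, lengths = [], [], []
    for prefill_index, decode_index in blocks:
        src_addrs.append(src_ptr + prefill_index[0] * item_len)   # start of the run
        dst_addrs.append(dst_ptr + decode_index[0] * item_len)
        lengths.append(item_len * len(prefill_index))             # whole run in one copy
    return src_addrs, dst_addrs, lengths

# Two contiguous runs of 3 and 2 blocks collapse into two transfer descriptors.
src, dst, lens = build_batch(
    0x1000, 0x9000, 256, [([0, 1, 2], [4, 5, 6]), ([7, 8], [0, 1])]
)
assert lens == [768, 512]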
@@ -286,6 +292,162 @@
 
         return 0
 
+    def send_kvcache_slice(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int64],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int64],
+        dst_tp_rank: int,
+        dst_tp_size: int,
+        dst_kv_item_len: int,
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        """
+        Sends KV cache slices from this Prefill rank to a target Decode rank,
+        supporting generic M-to-N TP size configurations.
+
+        NOTE: This implementation calls the transfer engine for each token slot within
+        each page to ensure correctness for any page_size and head-slicing configuration.
+        This may introduce performance overhead (increased TTFT) for long sequences.
+        """
+        # Extract configuration
+        local_tp_rank = self.kv_args.engine_rank
+        local_tp_size = self.tp_size // self.dp_size
+        num_kv_heads = self.kv_args.kv_head_num
+        num_layers = len(self.kv_args.kv_data_ptrs)
+        page_size = self.kv_args.page_size
+
+        # Calculate head distribution
+        heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
+        heads_per_prefill_rank = num_kv_heads
+        decode_global_head_start = dst_tp_rank * heads_per_decode_rank
+        prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
+        bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
+
+        decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
+
+        # Determine slicing parameters based on TP configuration
+        if local_tp_size > dst_tp_size:
+            src_head_offset = 0
+            num_heads_to_send = heads_per_prefill_rank
+            dst_head_offset = prefill_global_head_start - decode_global_head_start
+        else:
+            src_head_offset = decode_global_head_start - prefill_global_head_start
+            num_heads_to_send = heads_per_decode_rank
+            dst_head_offset = 0
+
+        layer_transfer_params = []
+        for layer_id in range(num_layers):
+            item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
+
+            # Page stride on the target dst decode rank for its slice pages
+            item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
+
+            if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
+                logger.error(
+                    f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
+                )
+                return -1
+
+            # Calculate precise byte offset and length for the sub-slice within the prefill page data
+            src_slice_offset = src_head_offset * bytes_per_head
+            dst_slice_offset = dst_head_offset * bytes_per_head
+            slice_lens_per_page = num_heads_to_send * bytes_per_head
+
+            # Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
+            # This means slice_lens_per_page <= item_len_of_decode_rank_page
+            if slice_lens_per_page > item_len_of_decode_rank_page:
+                logger.error(
+                    f"[{mooncake_session_id}] Layer {layer_id}: "
+                    f"slice size ({slice_lens_per_page}) exceeds "
+                    f"target page size ({item_len_of_decode_rank_page})"
+                )
+                return -1
+            layer_transfer_params.append(
+                (
+                    self.kv_args.kv_data_ptrs[layer_id],
+                    dst_kv_ptrs[layer_id],
+                    item_len_of_prefill_rank_page,
+                    item_len_of_decode_rank_page,
+                    src_slice_offset,
+                    dst_slice_offset,
+                    slice_lens_per_page,
+                )
+            )
+
+        def process_layer_tp_aware(layer_params):
+            (
+                src_ptr,
+                dst_ptr,
+                src_item_len,
+                dst_item_len,
+                src_offset,
+                dst_offset,
+                slice_lens_per_page,
+            ) = layer_params
+            src_addr_list = []
+            dst_addr_list = []
+            length_list = []
+
+            # Calculate strides for a single token slot
+            bytes_per_token_on_prefill = src_item_len // page_size
+            bytes_per_token_on_decode = dst_item_len // page_size
+
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting addresses for the current src and dst pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final src and dst addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_offset
+
+                    src_addr_list.append(src_slice_addr)
+                    dst_addr_list.append(dst_slice_addr)
+                    length_list.append(slice_lens_per_page)
+
+                    logger.debug(
+                        f"SYNC: sid={mooncake_session_id}, "
+                        f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
+                    )
+
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, src_addr_list, dst_addr_list, length_list
+            )
+
+        futures = [
+            executor.submit(
+                process_layer_tp_aware,
+                layer_params,
+            )
+            for layer_params in layer_transfer_params
+        ]
+
+        for future in concurrent.futures.as_completed(futures):
+            status = future.result()
+            if status != 0:
+                for f in futures:
+                    f.cancel()
+                return status
+
+        return 0
+
     def send_aux(
         self,
         mooncake_session_id: str,
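To see the slicing arithmetic in numbers: with a prefill instance at TP=2 sending to a decode instance at TP=4, each decode rank owns half of one prefill rank's heads, so the prefill rank skips to the decode rank's global head offset and sends only that sub-slice of every token slot. A worked example with made-up sizes (not package code):

# Illustrative numbers: 8 KV heads per prefill rank, 2 prefill ranks, 4 decode ranks.
local_tp_size, dst_tp_size = 2, 4
num_kv_heads = 8                      # heads held by one prefill rank
page_size = 16
dst_kv_item_len = 4 * 16 * 128 * 2    # heads * page_size * head_dim * fp16 bytes on decode

heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size      # 4
bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size  # 256

# Prefill rank 0 (global heads 0..7) serving decode rank 1 (global heads 4..7):
local_tp_rank, dst_tp_rank = 0, 1
decode_global_head_start = dst_tp_rank * heads_per_decode_rank           # 4
prefill_global_head_start = local_tp_rank * num_kv_heads                 # 0

# local_tp_size < dst_tp_size, so the prefill rank sends a sub-slice of its heads.
src_head_offset = decode_global_head_start - prefill_global_head_start   # skip 4 heads
num_heads_to_send = heads_per_decode_rank                                 # send 4 heads
print(src_head_offset * bytes_per_head, num_heads_to_send * bytes_per_head)  # 1024 1024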
@@ -293,18 +455,24 @@
         dst_aux_ptrs: list[int],
         dst_aux_index: int,
     ):
-        aux_item_len = self.kv_args.aux_item_lens[0]
-        prefill_aux_addr = (
-            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
-        )
-        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
-        status = self.engine.transfer_sync(
-            mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
+        src_addr_list = []
+        dst_addr_list = []
+        length_list = []
+        prefill_aux_ptrs = self.kv_args.aux_data_ptrs
+        prefill_aux_item_lens = self.kv_args.aux_item_lens
+        for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
+            length = prefill_aux_item_lens[i]
+            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
+            dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
+            src_addr_list.append(src_addr)
+            dst_addr_list.append(dst_addr)
+            length_list.append(length)
+        return self.engine.batch_transfer_sync(
+            mooncake_session_id, src_addr_list, dst_addr_list, length_list
         )
-        return status
 
     def sync_status_to_decode_endpoint(
-        self, remote: str, dst_port: int, room: int, status: int
+        self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
     ):
         if ":" in remote:
             remote = remote.split(":")[0]
@@ -312,6 +480,7 @@
             [
                 str(room).encode("ascii"),
                 str(status).encode("ascii"),
+                str(prefill_rank).encode("ascii"),
             ]
         )
 
@@ -328,6 +497,7 @@
         )
         polls = []
         dst_ranks_infos = []
+        local_rank = self.kv_args.engine_rank
         for req in reqs_to_be_processed:
             if not req.is_dummy:
                 # Early exit if the request has failed
@@ -343,6 +513,7 @@
                         req.dst_port,
                         req.room,
                         KVPoll.Failed,
+                        local_rank,
                     )
                     break
 
@@ -360,15 +531,31 @@
                             f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
                         )
 
-                    ret = self.send_kvcache(
-                        req.mooncake_session_id,
-                        kv_chunk.prefill_kv_indices,
-                        self.decode_kv_args_table[
-                            req.mooncake_session_id
-                        ].dst_kv_ptrs,
-                        chunked_dst_kv_indice,
-                        executor,
+                    target_rank_registration_info: KVArgsRegisterInfo = (
+                        self.decode_kv_args_table[req.mooncake_session_id]
                     )
+                    local_tp_size = self.tp_size // self.dp_size
+                    if self.is_mla_backend or (
+                        local_tp_size == target_rank_registration_info.dst_tp_size
+                    ):
+                        ret = self.send_kvcache(
+                            req.mooncake_session_id,
+                            kv_chunk.prefill_kv_indices,
+                            target_rank_registration_info.dst_kv_ptrs,
+                            chunked_dst_kv_indice,
+                            executor,
+                        )
+                    else:
+                        ret = self.send_kvcache_slice(
+                            req.mooncake_session_id,
+                            kv_chunk.prefill_kv_indices,
+                            target_rank_registration_info.dst_kv_ptrs,
+                            chunked_dst_kv_indice,
+                            target_rank_registration_info.dst_tp_rank,
+                            target_rank_registration_info.dst_tp_size,
+                            target_rank_registration_info.dst_kv_item_len,
+                            executor,
+                        )
                     if ret != 0:
                         with self.session_lock:
                             self.session_failures[req.mooncake_session_id] += 1
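In short, the transfer worker keeps the fast whole-page path whenever the two layouts match, and only pays for per-token head slicing when they do not. A reduced sketch of just that dispatch decision (hypothetical names, logic mirroring the hunk above):

from dataclasses import dataclass

@dataclass
class RegisteredDecodeRank:
    dst_tp_rank: int
    dst_tp_size: int
    dst_kv_item_len: int

def pick_transfer_path(is_mla_backend: bool, local_tp_size: int,
                       dst: RegisteredDecodeRank) -> str:
    # MLA KV caches are replicated across TP ranks, so full pages always line up;
    # otherwise the layouts only match when both sides shard heads identically.
    if is_mla_backend or local_tp_size == dst.dst_tp_size:
        return "send_kvcache"        # whole-page batched copy
    return "send_kvcache_slice"      # per-token head-slice copy

assert pick_transfer_path(False, 4, RegisteredDecodeRank(0, 4, 8192)) == "send_kvcache"
assert pick_transfer_path(False, 2, RegisteredDecodeRank(1, 4, 8192)) == "send_kvcache_slice"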
@@ -384,7 +571,11 @@
                         )
                         self.update_status(kv_chunk.room, KVPoll.Failed)
                         self.sync_status_to_decode_endpoint(
-                            req.endpoint, req.dst_port, req.room, KVPoll.Failed
+                            req.endpoint,
+                            req.dst_port,
+                            req.room,
+                            KVPoll.Failed,
+                            local_rank,
                         )
                         break
 
@@ -393,9 +584,7 @@
                     ret = self.send_aux(
                         req.mooncake_session_id,
                         kv_chunk.prefill_aux_index,
-                        self.decode_kv_args_table[
-                            req.mooncake_session_id
-                        ].dst_aux_ptrs,
+                        target_rank_registration_info.dst_aux_ptrs,
                         req.dst_aux_index,
                     )
                 polls.append(True if ret == 0 else False)
@@ -409,7 +598,7 @@
                     self.update_status(req.room, status)
                     for endpoint, dst_port, room in dst_ranks_infos:
                         self.sync_status_to_decode_endpoint(
-                            endpoint, dst_port, room, status
+                            endpoint, dst_port, room, status, local_rank
                         )
             else:
                 # Dummy request means the decode instance is not used, so its status can be marked as success directly
@@ -432,7 +621,7 @@
 
     def start_prefill_thread(self):
         self.rank_port = get_free_port()
-        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
+        self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
 
         def bootstrap_thread():
            """This thread recvs pre-alloc notification from the decode engine"""
@@ -471,19 +660,37 @@
 
     def start_decode_thread(self):
        self.rank_port = get_free_port()
-        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
+        self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
 
        def decode_thread():
            while True:
-                (bootstrap_room, status) = self.server_socket.recv_multipart()
+                (bootstrap_room, status, prefill_rank) = (
+                    self.server_socket.recv_multipart()
+                )
                status = int(status.decode("ascii"))
                bootstrap_room = int(bootstrap_room.decode("ascii"))
-                if status == KVPoll.Failed:
+                prefill_rank = int(prefill_rank.decode("ascii"))
+
+                if status == KVPoll.Success:
+                    if bootstrap_room in self.request_status:
+                        self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
+                        expected_response_num = (
+                            self.required_prefill_response_num_table[bootstrap_room]
+                        )
+                        arrived_response_num = len(
+                            self.prefill_response_tracker[bootstrap_room]
+                        )
+                        if (
+                            self.is_mla_backend
+                            or arrived_response_num == expected_response_num
+                        ):
+                            self.update_status(bootstrap_room, KVPoll.Success)
+                elif status == KVPoll.Failed:
                    self.record_failure(
                        bootstrap_room,
                        f"Failed to get kvcache from prefill instance, it might be dead",
                    )
-                self.update_status(bootstrap_room, status)
+                    self.update_status(bootstrap_room, status)
 
        def heartbeat_checker():
            while True:
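On the decode side, success for a bootstrap room is now declared only once every expected prefill rank has reported KVPoll.Success (MLA models short-circuit, since any single rank's copy is complete). A standalone sketch of the bookkeeping behind required_prefill_response_num_table and prefill_response_tracker (not package code):

from collections import defaultdict

required = {}               # bootstrap_room -> expected number of prefill responses
tracker = defaultdict(set)  # bootstrap_room -> set of prefill ranks seen so far

# Decode TP=2 pulling from prefill TP=8: each decode rank waits on 8 // 2 = 4 responses.
room = 1234
required[room] = 8 // 2

def on_success(room: int, prefill_rank: int) -> bool:
    """Record one prefill rank's completion; True once the room is fully received."""
    tracker[room].add(prefill_rank)   # a set, so duplicate reports are idempotent
    return len(tracker[room]) == required[room]

assert not on_success(room, 0)
assert not on_success(room, 1)
assert not on_success(room, 1)   # duplicate report does not advance the count
assert not on_success(room, 2)
assert on_success(room, 3)       # fourth distinct rank completes the room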
@@ -620,7 +827,7 @@
                 "role": "Prefill",
                 "tp_size": self.tp_size,
                 "dp_size": self.dp_size,
-                "rank_ip": get_local_ip_by_remote(),
+                "rank_ip": self.local_ip,
                 "rank_port": self.rank_port,
                 "engine_rank": self.kv_args.engine_rank,
             }
@@ -690,14 +897,13 @@ class MooncakeKVSender(BaseKVSender):
         self.aux_index = None
         self.bootstrap_server_url = bootstrap_addr
         self.conclude_state = None
-        self.init_time = None
+        self.init_time = time.time()
         # inner state
         self.curr_idx = 0
 
     def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
         self.num_kv_indices = num_kv_indices
         self.aux_index = aux_index
-        self.init_time = time.time()
 
     def send(
         self,
@@ -709,7 +915,10 @@
 
         if not is_last:
             self.kv_mgr.add_transfer_request(
-                self.bootstrap_room, kv_indices, index_slice, False
+                self.bootstrap_room,
+                kv_indices,
+                index_slice,
+                False,
             )
         else:
             self.kv_mgr.add_transfer_request(
@@ -746,12 +955,12 @@
             self.kv_mgr.request_status.pop(self.bootstrap_room)
 
     def failure_exception(self):
-        self.clear()
-
         # Explicitly set the status to failure since this request has failed in another rank
         if self.conclude_state is None:
             self.conclude_state = KVPoll.Failed
 
+        self.clear()
+
         with self.kv_mgr.failure_lock:
             failure_reason = self.kv_mgr.failure_records.pop(
                 self.bootstrap_room, "Failed due to an unknown reason from another rank"
@@ -818,23 +1027,26 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             )
             self.required_dst_info_num = 1
+            self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
         elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
-            assert (
-                self.kv_mgr.is_mla_backend
-            ), "PD with different TP sizes is not yet supported for non-MLA models"
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             self.target_tp_rank = (
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
             self.required_dst_info_num = (
                 local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
             )
+            self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
         else:
-            assert (
-                self.kv_mgr.is_mla_backend
-            ), "PD with different TP sizes is not yet supported for non-MLA models"
-
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
             self.target_tp_ranks = [
                 rank
@@ -851,6 +1063,9 @@
             # or the KVPoll will never be set correctly
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
+            self.required_prefill_response_num = (
+                prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
+            )
 
         if self.data_parallel_rank is not None:
             logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
@@ -858,6 +1073,9 @@
         else:
             self.target_dp_group = bootstrap_room % self.prefill_dp_size
 
+        self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
+            self.required_prefill_response_num
+        )
         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
             f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
@@ -871,11 +1089,15 @@
                     self.target_dp_group,
                 )
                 if bootstrap_info is not None:
-                    # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
-                    bootstrap_info["is_dummy"] = not bool(
-                        target_tp_rank == self.target_tp_rank
-                        or self.target_tp_rank is None
-                    )
+                    if self.kv_mgr.is_mla_backend:
+                        # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
+                        bootstrap_info["is_dummy"] = not bool(
+                            target_tp_rank == self.target_tp_rank
+                            or self.target_tp_rank is None
+                        )
+                    else:
+                        # For non-MLA: all target_tp_ranks are selected real ranks
+                        bootstrap_info["is_dummy"] = False
                     logger.debug(
                         f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
                     )
@@ -947,17 +1169,26 @@
         packed_aux_data_ptrs = b"".join(
             struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
         )
+        tp_rank = self.kv_mgr.kv_args.engine_rank
+        tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
+        kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
+        dst_tp_rank = str(tp_rank).encode("ascii")
+        dst_tp_size = str(tp_size).encode("ascii")
+        dst_kv_item_len = str(kv_item_len).encode("ascii")
 
         sock, lock = self._connect("tcp://" + self.prefill_server_url)
         with lock:
             sock.send_multipart(
                 [
                     "None".encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
+                    self.kv_mgr.local_ip.encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
                     self.session_id.encode("ascii"),
                     packed_kv_data_ptrs,
                     packed_aux_data_ptrs,
+                    dst_tp_rank,
+                    dst_tp_size,
+                    dst_kv_item_len,
                 ]
             )
 
@@ -983,7 +1214,7 @@
             sock.send_multipart(
                 [
                     str(self.bootstrap_room).encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
+                    self.kv_mgr.local_ip.encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
                     self.session_id.encode("ascii"),
                     kv_indices.tobytes() if not is_dummy else b"",
@@ -1006,13 +1237,19 @@
         if self.bootstrap_room in self.kv_mgr.request_status:
             self.kv_mgr.request_status.pop(self.bootstrap_room)
 
-    def failure_exception(self):
-        self.clear()
+        if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:
+            self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)
+
+        if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:
+            self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)
 
+    def failure_exception(self):
         # Explicitly set the status to failure since this request has failed in another rank
         if self.conclude_state is None:
             self.conclude_state = KVPoll.Failed
 
+        self.clear()
+
         with self.kv_mgr.failure_lock:
             failure_reason = self.kv_mgr.failure_records.pop(
                 self.bootstrap_room, "Failed due to an unknown reason from another rank"
sglang/srt/disaggregation/mooncake/transfer_engine.py

@@ -1,7 +1,7 @@
 import json
 import logging
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional
 
 logger = logging.getLogger(__name__)
 
@@ -90,5 +90,35 @@ class MooncakeTransferEngine:
 
         return ret
 
+    def batch_transfer_sync(
+        self,
+        session_id: str,
+        buffers: List[int],
+        peer_buffer_addresses: List[int],
+        lengths: List[int],
+    ) -> int:
+        """Synchronously transfer data to the specified addresses in batches."""
+        try:
+            ret = self.engine.batch_transfer_sync_write(
+                session_id, buffers, peer_buffer_addresses, lengths
+            )
+        except Exception:
+            ret = -1
+            # Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
+            if not hasattr(self.engine, "batch_transfer_sync_write"):
+                raise RuntimeError(
+                    "Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
+                    "Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
+                )
+
+        if ret < 0:
+            logger.debug(
+                "Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s",
+                buffers,
+                session_id,
+                peer_buffer_addresses,
+            )
+        return ret
+
     def get_session_id(self):
         return self.session_id