sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +67 -13
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +36 -12
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +9 -0
- sglang/srt/entrypoints/http_server.py +35 -4
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +38 -12
- sglang/srt/managers/scheduler.py +41 -28
- sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +3 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +19 -25
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +50 -11
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +31 -24
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +5 -1
- sglang/test/runners.py +6 -13
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +74 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
|
|
32
32
|
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
|
33
33
|
from sglang.srt.disaggregation.utils import (
|
34
34
|
DisaggregationMode,
|
35
|
+
FakeBootstrapHost,
|
35
36
|
KVClassType,
|
36
37
|
ReqToMetadataIdxAllocator,
|
37
38
|
TransferBackend,
|
@@ -133,11 +134,16 @@ class DecodePreallocQueue:
|
|
133
134
|
|
134
135
|
def add(self, req: Req) -> None:
|
135
136
|
"""Add a request to the pending queue."""
|
136
|
-
|
137
|
-
|
137
|
+
if req.bootstrap_host == FakeBootstrapHost:
|
138
|
+
# Fake transfer for warmup reqs
|
139
|
+
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
|
140
|
+
else:
|
141
|
+
kv_receiver_class = get_kv_class(
|
142
|
+
self.transfer_backend, KVClassType.RECEIVER
|
143
|
+
)
|
138
144
|
kv_receiver = kv_receiver_class(
|
139
145
|
mgr=self.kv_manager,
|
140
|
-
bootstrap_addr=f"{req.bootstrap_host}:{
|
146
|
+
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
141
147
|
bootstrap_room=req.bootstrap_room,
|
142
148
|
)
|
143
149
|
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
@@ -307,7 +313,7 @@ class DecodeTransferQueue:
|
|
307
313
|
def extend(self, req_conns) -> None:
|
308
314
|
self.queue.extend(req_conns)
|
309
315
|
|
310
|
-
def pop_transferred(self) -> List[
|
316
|
+
def pop_transferred(self) -> List[DecodeRequest]:
|
311
317
|
if not self.queue:
|
312
318
|
return []
|
313
319
|
|
@@ -330,7 +336,7 @@ class DecodeTransferQueue:
|
|
330
336
|
assert len(decode_req.req.output_ids) == 0
|
331
337
|
assert decode_req.req.transferred_output_id is None
|
332
338
|
decode_req.req.transferred_output_id = output_id
|
333
|
-
transferred_reqs.append(decode_req
|
339
|
+
transferred_reqs.append(decode_req)
|
334
340
|
indices_to_remove.add(i)
|
335
341
|
elif poll in [
|
336
342
|
KVPoll.Bootstrapping,
|
@@ -444,8 +450,17 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|
444
450
|
|
445
451
|
class SchedulerDisaggregationDecodeMixin:
|
446
452
|
|
453
|
+
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
454
|
+
batch, _ = self.prepare_dp_attn_batch(batch)
|
455
|
+
result = None
|
456
|
+
if batch:
|
457
|
+
result = self.run_batch(batch)
|
458
|
+
if not delay_process:
|
459
|
+
self.process_batch_result(batch, result)
|
460
|
+
return batch, result
|
461
|
+
|
447
462
|
@torch.no_grad()
|
448
|
-
def event_loop_normal_disagg_decode(self):
|
463
|
+
def event_loop_normal_disagg_decode(self: Scheduler):
|
449
464
|
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
450
465
|
|
451
466
|
while True:
|
@@ -456,14 +471,25 @@ class SchedulerDisaggregationDecodeMixin:
|
|
456
471
|
batch = self.get_next_disagg_decode_batch_to_run()
|
457
472
|
self.cur_batch = batch
|
458
473
|
|
474
|
+
prepare_dp_attn_flag = (
|
475
|
+
self.server_args.enable_dp_attention
|
476
|
+
or self.server_args.enable_sp_layernorm
|
477
|
+
)
|
478
|
+
|
459
479
|
if batch:
|
460
480
|
# Generate fake extend output.
|
461
481
|
if batch.forward_mode.is_extend():
|
462
482
|
# Note: Logprobs should be handled on the prefill engine.
|
463
483
|
self.stream_output(batch.reqs, False)
|
484
|
+
if prepare_dp_attn_flag:
|
485
|
+
self._prepare_idle_batch_and_run(None)
|
464
486
|
else:
|
487
|
+
if prepare_dp_attn_flag:
|
488
|
+
self.prepare_dp_attn_batch(batch)
|
465
489
|
result = self.run_batch(batch)
|
466
490
|
self.process_batch_result(batch, result)
|
491
|
+
elif prepare_dp_attn_flag:
|
492
|
+
batch, _ = self._prepare_idle_batch_and_run(None)
|
467
493
|
|
468
494
|
if batch is None and (
|
469
495
|
len(self.disagg_decode_transfer_queue.queue)
|
@@ -477,10 +503,10 @@ class SchedulerDisaggregationDecodeMixin:
|
|
477
503
|
self.last_batch = batch
|
478
504
|
|
479
505
|
@torch.no_grad()
|
480
|
-
def event_loop_overlap_disagg_decode(self):
|
506
|
+
def event_loop_overlap_disagg_decode(self: Scheduler):
|
481
507
|
result_queue = deque()
|
482
508
|
self.last_batch: Optional[ScheduleBatch] = None
|
483
|
-
self.
|
509
|
+
self.last_batch_in_queue = False # last batch is modifed in-place, so we need another variable to track if it's extend
|
484
510
|
|
485
511
|
while True:
|
486
512
|
recv_reqs = self.recv_requests()
|
@@ -489,20 +515,41 @@ class SchedulerDisaggregationDecodeMixin:
|
|
489
515
|
self.process_decode_queue()
|
490
516
|
batch = self.get_next_disagg_decode_batch_to_run()
|
491
517
|
self.cur_batch = batch
|
492
|
-
|
518
|
+
last_batch_in_queue = False
|
519
|
+
|
520
|
+
prepare_dp_attn_flag = (
|
521
|
+
self.server_args.enable_dp_attention
|
522
|
+
or self.server_args.enable_sp_layernorm
|
523
|
+
)
|
493
524
|
|
494
525
|
if batch:
|
495
526
|
# Generate fake extend output.
|
496
527
|
if batch.forward_mode.is_extend():
|
497
528
|
# Note: Logprobs should be handled on the prefill engine.
|
498
529
|
self.stream_output(batch.reqs, False)
|
499
|
-
|
530
|
+
if prepare_dp_attn_flag:
|
531
|
+
batch_, result = self._prepare_idle_batch_and_run(
|
532
|
+
None, delay_process=True
|
533
|
+
)
|
534
|
+
if batch_:
|
535
|
+
result_queue.append((batch_.copy(), result))
|
536
|
+
last_batch_in_queue = True
|
500
537
|
else:
|
538
|
+
if prepare_dp_attn_flag:
|
539
|
+
self.prepare_dp_attn_batch(batch)
|
501
540
|
result = self.run_batch(batch)
|
502
541
|
result_queue.append((batch.copy(), result))
|
542
|
+
last_batch_in_queue = True
|
543
|
+
elif prepare_dp_attn_flag:
|
544
|
+
batch, result = self._prepare_idle_batch_and_run(
|
545
|
+
None, delay_process=True
|
546
|
+
)
|
547
|
+
if batch:
|
548
|
+
result_queue.append((batch.copy(), result))
|
549
|
+
last_batch_in_queue = True
|
503
550
|
|
504
551
|
# Process the results of the previous batch but skip if the last batch is extend
|
505
|
-
if self.last_batch and
|
552
|
+
if self.last_batch and self.last_batch_in_queue:
|
506
553
|
tmp_batch, tmp_result = result_queue.popleft()
|
507
554
|
self.process_batch_result(tmp_batch, tmp_result)
|
508
555
|
|
@@ -516,7 +563,7 @@ class SchedulerDisaggregationDecodeMixin:
|
|
516
563
|
self.new_token_ratio = self.init_new_token_ratio
|
517
564
|
|
518
565
|
self.last_batch = batch
|
519
|
-
self.
|
566
|
+
self.last_batch_in_queue = last_batch_in_queue
|
520
567
|
|
521
568
|
def get_next_disagg_decode_batch_to_run(
|
522
569
|
self: Scheduler,
|
@@ -600,8 +647,15 @@ class SchedulerDisaggregationDecodeMixin:
|
|
600
647
|
|
601
648
|
def process_decode_queue(self: Scheduler):
|
602
649
|
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
650
|
+
|
651
|
+
def _num_pre_alloc(req):
|
652
|
+
return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
|
653
|
+
|
654
|
+
self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
|
603
655
|
self.disagg_decode_transfer_queue.extend(req_conns)
|
604
656
|
alloc_reqs = (
|
605
657
|
self.disagg_decode_transfer_queue.pop_transferred()
|
606
658
|
) # the requests which kv has arrived
|
607
|
-
self.
|
659
|
+
self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
|
660
|
+
|
661
|
+
self.waiting_queue.extend([req.req for req in alloc_reqs])
|
@@ -0,0 +1 @@
|
|
1
|
+
from .conn import FakeKVReceiver, FakeKVSender
|
@@ -0,0 +1,88 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Dict, List, Optional, Tuple, Union
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import numpy.typing as npt
|
6
|
+
|
7
|
+
from sglang.srt.disaggregation.base.conn import (
|
8
|
+
BaseKVManager,
|
9
|
+
BaseKVReceiver,
|
10
|
+
BaseKVSender,
|
11
|
+
KVArgs,
|
12
|
+
KVPoll,
|
13
|
+
)
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
|
19
|
+
class FakeKVSender(BaseKVSender):
|
20
|
+
def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
|
21
|
+
self.has_sent = False
|
22
|
+
|
23
|
+
def poll(self) -> KVPoll:
|
24
|
+
if self.has_sent is False:
|
25
|
+
# Assume handshake completed instantly
|
26
|
+
return KVPoll.WaitingForInput
|
27
|
+
else:
|
28
|
+
# Assume transfer completed instantly
|
29
|
+
logger.info("FakeKVSender poll success")
|
30
|
+
return KVPoll.Success
|
31
|
+
|
32
|
+
def init(
|
33
|
+
self,
|
34
|
+
kv_indices: list[int],
|
35
|
+
aux_index: Optional[int] = None,
|
36
|
+
dest_ranks: Optional[list[int]] = None,
|
37
|
+
):
|
38
|
+
logger.info(
|
39
|
+
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
|
40
|
+
)
|
41
|
+
pass
|
42
|
+
|
43
|
+
def send(
|
44
|
+
self,
|
45
|
+
kv_indices: npt.NDArray[np.int64],
|
46
|
+
index_slice: slice,
|
47
|
+
is_last: bool,
|
48
|
+
):
|
49
|
+
logger.info(
|
50
|
+
f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
|
51
|
+
)
|
52
|
+
if is_last:
|
53
|
+
self.has_sent = True
|
54
|
+
logger.info(f"FakeKVSender send success")
|
55
|
+
else:
|
56
|
+
self.has_sent = False
|
57
|
+
logger.info(f"FakeKVSender send fake transfering")
|
58
|
+
|
59
|
+
def failure_exception(self):
|
60
|
+
raise Exception("Fake KVSender Exception")
|
61
|
+
|
62
|
+
|
63
|
+
class FakeKVReceiver(BaseKVReceiver):
|
64
|
+
def __init__(
|
65
|
+
self,
|
66
|
+
mgr: BaseKVManager,
|
67
|
+
bootstrap_addr: str,
|
68
|
+
bootstrap_room: Optional[int] = None,
|
69
|
+
):
|
70
|
+
self.has_init = False
|
71
|
+
|
72
|
+
def poll(self) -> KVPoll:
|
73
|
+
if self.has_init is False:
|
74
|
+
# Assume handshake completed instantly
|
75
|
+
return KVPoll.WaitingForInput
|
76
|
+
else:
|
77
|
+
# Assume transfer completed instantly
|
78
|
+
logger.info("FakeKVReceiver poll success")
|
79
|
+
return KVPoll.Success
|
80
|
+
|
81
|
+
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
|
82
|
+
self.has_init = True
|
83
|
+
logger.info(
|
84
|
+
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
85
|
+
)
|
86
|
+
|
87
|
+
def failure_exception(self):
|
88
|
+
raise Exception("Fake KVReceiver Exception")
|
@@ -6,6 +6,7 @@ import asyncio
|
|
6
6
|
import random
|
7
7
|
import urllib
|
8
8
|
from itertools import chain
|
9
|
+
from typing import List
|
9
10
|
|
10
11
|
import aiohttp
|
11
12
|
import orjson
|
@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
|
|
14
15
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
15
16
|
|
16
17
|
|
18
|
+
class PrefillConfig:
|
19
|
+
def __init__(self, url: str, bootstrap_port: int):
|
20
|
+
self.url = url
|
21
|
+
self.bootstrap_port = bootstrap_port
|
22
|
+
|
23
|
+
|
17
24
|
class MiniLoadBalancer:
|
18
|
-
def __init__(self,
|
19
|
-
self.
|
25
|
+
def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
|
26
|
+
self.prefill_configs = prefill_configs
|
27
|
+
self.prefill_servers = [p.url for p in prefill_configs]
|
20
28
|
self.decode_servers = decode_servers
|
21
29
|
|
22
30
|
def select_pair(self):
|
23
|
-
|
31
|
+
prefill_config = random.choice(self.prefill_configs)
|
32
|
+
decode_server = random.choice(self.decode_servers)
|
33
|
+
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
24
34
|
|
25
35
|
async def generate(
|
26
36
|
self, modified_request, prefill_server, decode_server, endpoint
|
@@ -160,7 +170,7 @@ async def get_model_info():
|
|
160
170
|
|
161
171
|
@app.post("/generate")
|
162
172
|
async def handle_generate_request(request_data: dict):
|
163
|
-
prefill_server, decode_server = load_balancer.select_pair()
|
173
|
+
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
164
174
|
|
165
175
|
# Parse and transform prefill_server for bootstrap data
|
166
176
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
|
|
172
182
|
modified_request.update(
|
173
183
|
{
|
174
184
|
"bootstrap_host": [hostname] * batch_size,
|
185
|
+
"bootstrap_port": [bootstrap_port] * batch_size,
|
175
186
|
"bootstrap_room": [
|
176
187
|
_generate_bootstrap_room() for _ in range(batch_size)
|
177
188
|
],
|
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
|
|
181
192
|
modified_request.update(
|
182
193
|
{
|
183
194
|
"bootstrap_host": hostname,
|
195
|
+
"bootstrap_port": bootstrap_port,
|
184
196
|
"bootstrap_room": _generate_bootstrap_room(),
|
185
197
|
}
|
186
198
|
)
|
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
|
|
197
209
|
|
198
210
|
@app.post("/v1/chat/completions")
|
199
211
|
async def handle_completion_request(request_data: dict):
|
200
|
-
prefill_server, decode_server = load_balancer.select_pair()
|
212
|
+
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
201
213
|
|
202
214
|
# Parse and transform prefill_server for bootstrap data
|
203
215
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
|
|
206
218
|
modified_request.update(
|
207
219
|
{
|
208
220
|
"bootstrap_host": hostname,
|
221
|
+
"bootstrap_port": bootstrap_port,
|
209
222
|
"bootstrap_room": random.randint(0, 2**63 - 1),
|
210
223
|
}
|
211
224
|
)
|
@@ -255,9 +268,9 @@ async def get_models():
|
|
255
268
|
raise HTTPException(status_code=500, detail=str(e))
|
256
269
|
|
257
270
|
|
258
|
-
def run(
|
271
|
+
def run(prefill_configs, decode_addrs, host, port):
|
259
272
|
global load_balancer
|
260
|
-
load_balancer = MiniLoadBalancer(
|
273
|
+
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
|
261
274
|
uvicorn.run(app, host=host, port=port)
|
262
275
|
|
263
276
|
|
@@ -268,6 +281,11 @@ if __name__ == "__main__":
|
|
268
281
|
parser.add_argument(
|
269
282
|
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
|
270
283
|
)
|
284
|
+
parser.add_argument(
|
285
|
+
"--prefill-bootstrap-ports",
|
286
|
+
help="Comma-separated bootstrap ports for prefill servers",
|
287
|
+
default="8998",
|
288
|
+
)
|
271
289
|
parser.add_argument(
|
272
290
|
"--decode", required=True, help="Comma-separated URLs for decode servers"
|
273
291
|
)
|
@@ -278,4 +296,23 @@ if __name__ == "__main__":
|
|
278
296
|
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
|
279
297
|
)
|
280
298
|
args = parser.parse_args()
|
281
|
-
|
299
|
+
|
300
|
+
prefill_urls = args.prefill.split(",")
|
301
|
+
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
|
302
|
+
|
303
|
+
if len(bootstrap_ports) == 1:
|
304
|
+
bootstrap_ports = bootstrap_ports * len(prefill_urls)
|
305
|
+
else:
|
306
|
+
if len(bootstrap_ports) != len(prefill_urls):
|
307
|
+
raise ValueError(
|
308
|
+
"Number of prefill URLs must match number of bootstrap ports"
|
309
|
+
)
|
310
|
+
exit(1)
|
311
|
+
|
312
|
+
prefill_configs = []
|
313
|
+
for url, port in zip(prefill_urls, bootstrap_ports):
|
314
|
+
prefill_configs.append(PrefillConfig(url, port))
|
315
|
+
|
316
|
+
decode_addrs = args.decode.split(",")
|
317
|
+
|
318
|
+
run(prefill_configs, decode_addrs, args.host, args.port)
|