sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package, as published to a supported public registry, and is provided for informational purposes only.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
The hunks shown below come from `sglang/srt/disaggregation/decode.py`, followed by `sglang/srt/disaggregation/mini_lb.py`.

```diff
--- a/sglang/srt/disaggregation/decode.py
+++ b/sglang/srt/disaggregation/decode.py
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
+from collections import deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
```
```diff
@@ -136,7 +137,7 @@ class DecodePreallocQueue:
         kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
         kv_receiver = kv_receiver_class(
             mgr=self.kv_manager,
-            bootstrap_addr=f"{req.bootstrap_host}:{
+            bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
         )
         self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
```
```diff
@@ -306,7 +307,7 @@ class DecodeTransferQueue:
     def extend(self, req_conns) -> None:
         self.queue.extend(req_conns)
 
-    def pop_transferred(self) -> List[
+    def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
             return []
 
```
```diff
@@ -329,7 +330,7 @@ class DecodeTransferQueue:
                 assert len(decode_req.req.output_ids) == 0
                 assert decode_req.req.transferred_output_id is None
                 decode_req.req.transferred_output_id = output_id
-                transferred_reqs.append(decode_req
+                transferred_reqs.append(decode_req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
```
```diff
@@ -443,8 +444,17 @@ class ScheduleBatchDisaggregationDecodeMixin:
 
 class SchedulerDisaggregationDecodeMixin:
 
+    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
+        batch, _ = self.prepare_dp_attn_batch(batch)
+        result = None
+        if batch:
+            result = self.run_batch(batch)
+            if not delay_process:
+                self.process_batch_result(batch, result)
+        return batch, result
+
     @torch.no_grad()
-    def event_loop_normal_disagg_decode(self):
+    def event_loop_normal_disagg_decode(self: Scheduler):
         """A normal scheduler loop for decode worker in disaggregation mode."""
 
         while True:
```
```diff
@@ -455,14 +465,25 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch
 
+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )
+
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
                     self.stream_output(batch.reqs, False)
+                    if prepare_dp_attn_flag:
+                        self._prepare_idle_batch_and_run(None)
                 else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
+            elif prepare_dp_attn_flag:
+                batch, _ = self._prepare_idle_batch_and_run(None)
 
             if batch is None and (
                 len(self.disagg_decode_transfer_queue.queue)
```
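The pattern in this hunk is worth spelling out: when data-parallel attention (or SP layernorm) is enabled, every rank must join the same collective forward pass, so a rank with nothing to decode still submits an idle batch rather than skipping the step. A self-contained toy sketch of that invariant (illustrative only, not sglang code):

```python
# Toy sketch (not sglang code): under data-parallel attention every rank must
# participate in each forward step, so a rank with no requests substitutes an
# idle batch instead of skipping the step and stalling its peers.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyBatch:
    reqs: List[str] = field(default_factory=list)


def prepare_dp_attn_batch(local: Optional[ToyBatch]) -> ToyBatch:
    # In sglang this coordinates with the other ranks; here we only model the
    # local substitution of an empty ("idle") batch.
    return local if local is not None else ToyBatch()


def step(local: Optional[ToyBatch]) -> str:
    batch = prepare_dp_attn_batch(local)
    # All ranks reach this point and issue the same collective ops.
    return "idle forward" if not batch.reqs else f"forward({batch.reqs})"


print(step(ToyBatch(["req-1"])))  # forward(['req-1'])
print(step(None))                 # idle forward
```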
```diff
@@ -475,6 +496,69 @@ class SchedulerDisaggregationDecodeMixin:
 
         self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_decode(self: Scheduler):
+        result_queue = deque()
+        self.last_batch: Optional[ScheduleBatch] = None
+        self.last_batch_in_queue = False  # last batch is modifed in-place, so we need another variable to track if it's extend
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+            last_batch_in_queue = False
+
+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(batch.reqs, False)
+                    if prepare_dp_attn_flag:
+                        batch_, result = self._prepare_idle_batch_and_run(
+                            None, delay_process=True
+                        )
+                        if batch_:
+                            result_queue.append((batch_.copy(), result))
+                            last_batch_in_queue = True
+                else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
+                    result = self.run_batch(batch)
+                    result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True
+            elif prepare_dp_attn_flag:
+                batch, result = self._prepare_idle_batch_and_run(
+                    None, delay_process=True
+                )
+                if batch:
+                    result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True
+
+            # Process the results of the previous batch but skip if the last batch is extend
+            if self.last_batch and self.last_batch_in_queue:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            self.last_batch_in_queue = last_batch_in_queue
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]:
```
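The overlap loop above launches batch *i+1* before consuming batch *i*'s results, hiding CPU post-processing behind GPU execution; the `result_queue` plus the `last_batch_in_queue` flag keeps the one-iteration delay consistent even when an iteration (fake-extend output) enqueues nothing. A minimal, runnable sketch of the same pattern (toy code, not sglang):

```python
# Toy sketch (not sglang code): process batch i's result one iteration late,
# while batch i+1 is (conceptually) running, mirroring the deque-based loop above.
from collections import deque

result_queue = deque()
last_batch_in_queue = False  # did the *previous* iteration enqueue a result?

for batch in ["b0", "b1", "b2", None]:  # None models an iteration with no work
    queued = False
    if batch is not None:
        result = f"result({batch})"     # stands in for run_batch()
        result_queue.append((batch, result))
        queued = True
    # Consume the previous iteration's result, if there was one.
    if last_batch_in_queue:
        prev_batch, prev_result = result_queue.popleft()
        print("processing", prev_batch, "->", prev_result)
    last_batch_in_queue = queued
```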
```diff
@@ -557,8 +641,15 @@ class SchedulerDisaggregationDecodeMixin:
 
     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
+
+        def _num_pre_alloc(req):
+            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
+
+        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
-        self.
+        self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
+
+        self.waiting_queue.extend([req.req for req in alloc_reqs])
```
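Reading the accounting above (our interpretation): a request entering the transfer queue holds pre-allocated KV slots for its full prompt plus all generated tokens except the newest one, and the counter is released once the KV has arrived and the request moves to the waiting queue. A worked example of the formula:

```python
# Worked example of the _num_pre_alloc formula from the hunk above.
def num_pre_alloc(origin_input_ids, output_ids):
    return len(origin_input_ids) + max(len(output_ids) - 1, 0)

print(num_pre_alloc([101, 102, 103, 104, 105], []))         # 5: prompt only
print(num_pre_alloc([101, 102, 103, 104, 105], [7, 8, 9]))  # 7: 5 + (3 - 1)
```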
```diff
--- a/sglang/srt/disaggregation/mini_lb.py
+++ b/sglang/srt/disaggregation/mini_lb.py
@@ -6,6 +6,7 @@ import asyncio
 import random
 import urllib
 from itertools import chain
+from typing import List
 
 import aiohttp
 import orjson
```
```diff
@@ -14,17 +15,27 @@ from fastapi import FastAPI, HTTPException
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
 
+class PrefillConfig:
+    def __init__(self, url: str, bootstrap_port: int):
+        self.url = url
+        self.bootstrap_port = bootstrap_port
+
+
 class MiniLoadBalancer:
-    def __init__(self,
-        self.
+    def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
+        self.prefill_configs = prefill_configs
+        self.prefill_servers = [p.url for p in prefill_configs]
         self.decode_servers = decode_servers
 
     def select_pair(self):
-
+        prefill_config = random.choice(self.prefill_configs)
+        decode_server = random.choice(self.decode_servers)
+        return prefill_config.url, prefill_config.bootstrap_port, decode_server
 
     async def generate(
-        self, modified_request, prefill_server, decode_server
+        self, modified_request, prefill_server, decode_server, endpoint
     ) -> ORJSONResponse:
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
 
         async with aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(
```
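As a quick orientation, here is a hypothetical wiring of the new constructor (URLs and ports invented; `PrefillConfig` and `MiniLoadBalancer` as defined in the hunk above):

```python
# Hypothetical wiring of the new PrefillConfig-based constructor.
prefill_configs = [
    PrefillConfig("http://prefill-0:30000", bootstrap_port=8998),
    PrefillConfig("http://prefill-1:30000", bootstrap_port=8999),
]
lb = MiniLoadBalancer(prefill_configs, ["http://decode-0:30000"])

# select_pair() now returns a (url, bootstrap_port, decode_server) triple:
prefill_url, bootstrap_port, decode_url = lb.select_pair()
```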
```diff
@@ -32,8 +43,8 @@ class MiniLoadBalancer:
             )  # Add timeout for request reliability
         ) as session:
             tasks = [
-                session.post(f"{prefill_server}/
-                session.post(f"{decode_server}/
+                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
+                session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
             prefill_response, decode_response = await asyncio.gather(*tasks)
```
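For readers skimming, the fan-out in `generate` boils down to posting one payload to two servers and gathering both responses. A stripped-down, illustrative version (not the sglang implementation; names are ours):

```python
# Toy fan-out (illustrative only): mirror one payload to two servers.
import asyncio

import aiohttp


async def mirror_post(payload: dict, prefill_url: str, decode_url: str, endpoint: str):
    assert not endpoint.startswith("/")  # joined with "/" below, as in the diff
    async with aiohttp.ClientSession() as session:
        tasks = [
            session.post(f"{prefill_url}/{endpoint}", json=payload),
            session.post(f"{decode_url}/{endpoint}", json=payload),
        ]
        # Prefill is expected to finish first; the decode response carries the output.
        prefill_resp, decode_resp = await asyncio.gather(*tasks)
        return await decode_resp.json()
```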
```diff
@@ -43,7 +54,11 @@ class MiniLoadBalancer:
                 status_code=decode_response.status,
             )
 
-    async def generate_stream(
+    async def generate_stream(
+        self, modified_request, prefill_server, decode_server, endpoint="generate"
+    ):
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
+
         async def stream_results():
             async with aiohttp.ClientSession(
                 timeout=aiohttp.ClientTimeout(
```
```diff
@@ -54,10 +69,10 @@ class MiniLoadBalancer:
                 # Create the tasks for both prefill and decode requests
                 tasks = [
                     session.post(
-                        f"{prefill_server}/
+                        f"{prefill_server}/{endpoint}", json=modified_request
                     ),
                     session.post(
-                        f"{decode_server}/
+                        f"{decode_server}/{endpoint}", json=modified_request
                     ),
                 ]
                 # Wait for both responses to complete. Since this is streaming, they return immediately.
```
```diff
@@ -155,7 +170,46 @@ async def get_model_info():
 
 @app.post("/generate")
 async def handle_generate_request(request_data: dict):
-    prefill_server, decode_server = load_balancer.select_pair()
+    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
+
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = parsed_url.hostname
+    modified_request = request_data.copy()
+
+    batch_size = _get_request_batch_size(modified_request)
+    if batch_size is not None:
+        modified_request.update(
+            {
+                "bootstrap_host": [hostname] * batch_size,
+                "bootstrap_port": [bootstrap_port] * batch_size,
+                "bootstrap_room": [
+                    _generate_bootstrap_room() for _ in range(batch_size)
+                ],
+            }
+        )
+    else:
+        modified_request.update(
+            {
+                "bootstrap_host": hostname,
+                "bootstrap_port": bootstrap_port,
+                "bootstrap_room": _generate_bootstrap_room(),
+            }
+        )
+
+    if request_data.get("stream", False):
+        return await load_balancer.generate_stream(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+    else:
+        return await load_balancer.generate(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+
+
+@app.post("/v1/chat/completions")
+async def handle_completion_request(request_data: dict):
+    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
 
     # Parse and transform prefill_server for bootstrap data
     parsed_url = urllib.parse.urlparse(prefill_server)
```
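To make the two branches above concrete, these are the payload shapes the balancer forwards (all values fabricated):

```python
# Fabricated payloads showing the two shapes of bootstrap injection above.
# Single request: "text" is a str, so scalar fields are injected.
single = {
    "text": "Hello",
    "bootstrap_host": "prefill-0",
    "bootstrap_port": 8998,
    "bootstrap_room": 1234567890,
}
# Batched request: "text" is a list of 2, so each field becomes a list of 2.
batched = {
    "text": ["Hello", "World"],
    "bootstrap_host": ["prefill-0", "prefill-0"],
    "bootstrap_port": [8998, 8998],
    "bootstrap_room": [1111, 2222],
}
```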
```diff
@@ -164,20 +218,40 @@ async def handle_generate_request(request_data: dict):
     modified_request.update(
         {
             "bootstrap_host": hostname,
+            "bootstrap_port": bootstrap_port,
             "bootstrap_room": random.randint(0, 2**63 - 1),
         }
     )
 
     if request_data.get("stream", False):
         return await load_balancer.generate_stream(
-            modified_request,
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
         )
     else:
         return await load_balancer.generate(
-            modified_request,
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
         )
 
 
+def _generate_bootstrap_room():
+    return random.randint(0, 2**63 - 1)
+
+
+# We may utilize `GenerateReqInput`'s logic later
+def _get_request_batch_size(request):
+    if (text := request.get("text")) is not None:
+        return None if isinstance(text, str) else len(text)
+    if (input_ids := request.get("input_ids")) is not None:
+        return None if isinstance(input_ids[0], int) else len(input_ids)
+    return None
+
+
 @app.get("/v1/models")
 async def get_models():
     prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
```
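The helper's contract, as we read it from the code above: return `None` for a single request and the batch size for a batched one. Spelled out (assuming the function as defined above):

```python
# Expected behavior of _get_request_batch_size as defined above.
assert _get_request_batch_size({"text": "hi"}) is None             # one prompt
assert _get_request_batch_size({"text": ["a", "b"]}) == 2          # batch of two
assert _get_request_batch_size({"input_ids": [1, 2, 3]}) is None   # one token list
assert _get_request_batch_size({"input_ids": [[1], [2, 3]]}) == 2  # two token lists
assert _get_request_batch_size({}) is None                         # neither field
```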
```diff
@@ -194,9 +268,9 @@ async def get_models():
         raise HTTPException(status_code=500, detail=str(e))
 
 
-def run(
+def run(prefill_configs, decode_addrs, host, port):
     global load_balancer
-    load_balancer = MiniLoadBalancer(
+    load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
     uvicorn.run(app, host=host, port=port)
 
 
```
```diff
@@ -207,6 +281,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--prefill", required=True, help="Comma-separated URLs for prefill servers"
     )
+    parser.add_argument(
+        "--prefill-bootstrap-ports",
+        help="Comma-separated bootstrap ports for prefill servers",
+        default="8998",
+    )
     parser.add_argument(
         "--decode", required=True, help="Comma-separated URLs for decode servers"
     )
```
```diff
@@ -217,4 +296,23 @@ if __name__ == "__main__":
         "--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
     )
     args = parser.parse_args()
-
+
+    prefill_urls = args.prefill.split(",")
+    bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
+
+    if len(bootstrap_ports) == 1:
+        bootstrap_ports = bootstrap_ports * len(prefill_urls)
+    else:
+        if len(bootstrap_ports) != len(prefill_urls):
+            raise ValueError(
+                "Number of prefill URLs must match number of bootstrap ports"
+            )
+            exit(1)
+
+    prefill_configs = []
+    for url, port in zip(prefill_urls, bootstrap_ports):
+        prefill_configs.append(PrefillConfig(url, port))
+
+    decode_addrs = args.decode.split(",")
+
+    run(prefill_configs, decode_addrs, args.host, args.port)
```
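Two behaviors worth noting in the argument handling above (our reading of the code, with invented hosts and ports): a single `--prefill-bootstrap-ports` value is broadcast to every `--prefill` URL, and a mismatched count raises `ValueError` (the `exit(1)` after the `raise` is unreachable as released). A hypothetical launch: `python mini_lb.py --prefill http://prefill-0:30000,http://prefill-1:30000 --prefill-bootstrap-ports 8998,8999 --decode http://decode-0:30000 --port 8000`.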