sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +67 -13
  10. sglang/srt/disaggregation/fake/__init__.py +1 -0
  11. sglang/srt/disaggregation/fake/conn.py +88 -0
  12. sglang/srt/disaggregation/mini_lb.py +45 -8
  13. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  14. sglang/srt/disaggregation/prefill.py +36 -12
  15. sglang/srt/disaggregation/utils.py +16 -2
  16. sglang/srt/entrypoints/engine.py +9 -0
  17. sglang/srt/entrypoints/http_server.py +35 -4
  18. sglang/srt/function_call_parser.py +77 -5
  19. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  20. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  21. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  22. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  23. sglang/srt/layers/attention/utils.py +1 -1
  24. sglang/srt/layers/attention/vision.py +2 -0
  25. sglang/srt/layers/layernorm.py +38 -16
  26. sglang/srt/layers/logits_processor.py +2 -2
  27. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
  43. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  44. sglang/srt/layers/pooler.py +6 -0
  45. sglang/srt/layers/quantization/awq.py +5 -1
  46. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  47. sglang/srt/layers/quantization/fp8.py +20 -22
  48. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  49. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +170 -126
  52. sglang/srt/managers/data_parallel_controller.py +10 -3
  53. sglang/srt/managers/io_struct.py +7 -0
  54. sglang/srt/managers/mm_utils.py +85 -28
  55. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  56. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  57. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  58. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  59. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  61. sglang/srt/managers/schedule_batch.py +38 -12
  62. sglang/srt/managers/scheduler.py +41 -28
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
  64. sglang/srt/managers/tokenizer_manager.py +5 -1
  65. sglang/srt/managers/tp_worker.py +3 -3
  66. sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
  67. sglang/srt/mem_cache/memory_pool.py +87 -0
  68. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +19 -25
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +144 -70
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpmo.py +5 -1
  78. sglang/srt/models/mllama4.py +2 -2
  79. sglang/srt/models/qwen2_5_vl.py +3 -6
  80. sglang/srt/models/qwen2_vl.py +3 -7
  81. sglang/srt/models/roberta.py +178 -0
  82. sglang/srt/openai_api/adapter.py +50 -11
  83. sglang/srt/openai_api/protocol.py +2 -0
  84. sglang/srt/reasoning_parser.py +25 -1
  85. sglang/srt/server_args.py +31 -24
  86. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  87. sglang/srt/torch_memory_saver_adapter.py +10 -1
  88. sglang/srt/utils.py +5 -1
  89. sglang/test/runners.py +6 -13
  90. sglang/test/send_one.py +84 -28
  91. sglang/test/test_utils.py +74 -18
  92. sglang/version.py +1 -1
  93. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
  94. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
  95. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
  96. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
  97. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
32
32
  from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
33
33
  from sglang.srt.disaggregation.utils import (
34
34
  DisaggregationMode,
35
+ FakeBootstrapHost,
35
36
  KVClassType,
36
37
  ReqToMetadataIdxAllocator,
37
38
  TransferBackend,
@@ -133,11 +134,16 @@ class DecodePreallocQueue:
133
134
 
134
135
  def add(self, req: Req) -> None:
135
136
  """Add a request to the pending queue."""
136
-
137
- kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
137
+ if req.bootstrap_host == FakeBootstrapHost:
138
+ # Fake transfer for warmup reqs
139
+ kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
140
+ else:
141
+ kv_receiver_class = get_kv_class(
142
+ self.transfer_backend, KVClassType.RECEIVER
143
+ )
138
144
  kv_receiver = kv_receiver_class(
139
145
  mgr=self.kv_manager,
140
- bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
146
+ bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
141
147
  bootstrap_room=req.bootstrap_room,
142
148
  )
143
149
  self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
@@ -307,7 +313,7 @@ class DecodeTransferQueue:
307
313
  def extend(self, req_conns) -> None:
308
314
  self.queue.extend(req_conns)
309
315
 
310
- def pop_transferred(self) -> List[Req]:
316
+ def pop_transferred(self) -> List[DecodeRequest]:
311
317
  if not self.queue:
312
318
  return []
313
319
 
@@ -330,7 +336,7 @@ class DecodeTransferQueue:
330
336
  assert len(decode_req.req.output_ids) == 0
331
337
  assert decode_req.req.transferred_output_id is None
332
338
  decode_req.req.transferred_output_id = output_id
333
- transferred_reqs.append(decode_req.req)
339
+ transferred_reqs.append(decode_req)
334
340
  indices_to_remove.add(i)
335
341
  elif poll in [
336
342
  KVPoll.Bootstrapping,
@@ -444,8 +450,17 @@ class ScheduleBatchDisaggregationDecodeMixin:
444
450
 
445
451
  class SchedulerDisaggregationDecodeMixin:
446
452
 
453
+ def _prepare_idle_batch_and_run(self, batch, delay_process=False):
454
+ batch, _ = self.prepare_dp_attn_batch(batch)
455
+ result = None
456
+ if batch:
457
+ result = self.run_batch(batch)
458
+ if not delay_process:
459
+ self.process_batch_result(batch, result)
460
+ return batch, result
461
+
447
462
  @torch.no_grad()
448
- def event_loop_normal_disagg_decode(self):
463
+ def event_loop_normal_disagg_decode(self: Scheduler):
449
464
  """A normal scheduler loop for decode worker in disaggregation mode."""
450
465
 
451
466
  while True:
@@ -456,14 +471,25 @@ class SchedulerDisaggregationDecodeMixin:
456
471
  batch = self.get_next_disagg_decode_batch_to_run()
457
472
  self.cur_batch = batch
458
473
 
474
+ prepare_dp_attn_flag = (
475
+ self.server_args.enable_dp_attention
476
+ or self.server_args.enable_sp_layernorm
477
+ )
478
+
459
479
  if batch:
460
480
  # Generate fake extend output.
461
481
  if batch.forward_mode.is_extend():
462
482
  # Note: Logprobs should be handled on the prefill engine.
463
483
  self.stream_output(batch.reqs, False)
484
+ if prepare_dp_attn_flag:
485
+ self._prepare_idle_batch_and_run(None)
464
486
  else:
487
+ if prepare_dp_attn_flag:
488
+ self.prepare_dp_attn_batch(batch)
465
489
  result = self.run_batch(batch)
466
490
  self.process_batch_result(batch, result)
491
+ elif prepare_dp_attn_flag:
492
+ batch, _ = self._prepare_idle_batch_and_run(None)
467
493
 
468
494
  if batch is None and (
469
495
  len(self.disagg_decode_transfer_queue.queue)
@@ -477,10 +503,10 @@ class SchedulerDisaggregationDecodeMixin:
477
503
  self.last_batch = batch
478
504
 
479
505
  @torch.no_grad()
480
- def event_loop_overlap_disagg_decode(self):
506
+ def event_loop_overlap_disagg_decode(self: Scheduler):
481
507
  result_queue = deque()
482
508
  self.last_batch: Optional[ScheduleBatch] = None
483
- self.last_batch_is_extend = False # last batch is modifed in-place, so we need another variable to track if it's extend
509
+ self.last_batch_in_queue = False # last batch is modifed in-place, so we need another variable to track if it's extend
484
510
 
485
511
  while True:
486
512
  recv_reqs = self.recv_requests()
@@ -489,20 +515,41 @@ class SchedulerDisaggregationDecodeMixin:
489
515
  self.process_decode_queue()
490
516
  batch = self.get_next_disagg_decode_batch_to_run()
491
517
  self.cur_batch = batch
492
- last_batch_is_extend = False
518
+ last_batch_in_queue = False
519
+
520
+ prepare_dp_attn_flag = (
521
+ self.server_args.enable_dp_attention
522
+ or self.server_args.enable_sp_layernorm
523
+ )
493
524
 
494
525
  if batch:
495
526
  # Generate fake extend output.
496
527
  if batch.forward_mode.is_extend():
497
528
  # Note: Logprobs should be handled on the prefill engine.
498
529
  self.stream_output(batch.reqs, False)
499
- last_batch_is_extend = True
530
+ if prepare_dp_attn_flag:
531
+ batch_, result = self._prepare_idle_batch_and_run(
532
+ None, delay_process=True
533
+ )
534
+ if batch_:
535
+ result_queue.append((batch_.copy(), result))
536
+ last_batch_in_queue = True
500
537
  else:
538
+ if prepare_dp_attn_flag:
539
+ self.prepare_dp_attn_batch(batch)
501
540
  result = self.run_batch(batch)
502
541
  result_queue.append((batch.copy(), result))
542
+ last_batch_in_queue = True
543
+ elif prepare_dp_attn_flag:
544
+ batch, result = self._prepare_idle_batch_and_run(
545
+ None, delay_process=True
546
+ )
547
+ if batch:
548
+ result_queue.append((batch.copy(), result))
549
+ last_batch_in_queue = True
503
550
 
504
551
  # Process the results of the previous batch but skip if the last batch is extend
505
- if self.last_batch and not self.last_batch_is_extend:
552
+ if self.last_batch and self.last_batch_in_queue:
506
553
  tmp_batch, tmp_result = result_queue.popleft()
507
554
  self.process_batch_result(tmp_batch, tmp_result)
508
555
 
@@ -516,7 +563,7 @@ class SchedulerDisaggregationDecodeMixin:
516
563
  self.new_token_ratio = self.init_new_token_ratio
517
564
 
518
565
  self.last_batch = batch
519
- self.last_batch_is_extend = last_batch_is_extend
566
+ self.last_batch_in_queue = last_batch_in_queue
520
567
 
521
568
  def get_next_disagg_decode_batch_to_run(
522
569
  self: Scheduler,
@@ -600,8 +647,15 @@ class SchedulerDisaggregationDecodeMixin:
600
647
 
601
648
  def process_decode_queue(self: Scheduler):
602
649
  req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
650
+
651
+ def _num_pre_alloc(req):
652
+ return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
653
+
654
+ self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
603
655
  self.disagg_decode_transfer_queue.extend(req_conns)
604
656
  alloc_reqs = (
605
657
  self.disagg_decode_transfer_queue.pop_transferred()
606
658
  ) # the requests which kv has arrived
607
- self.waiting_queue.extend(alloc_reqs)
659
+ self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
660
+
661
+ self.waiting_queue.extend([req.req for req in alloc_reqs])
@@ -0,0 +1 @@
1
+ from .conn import FakeKVReceiver, FakeKVSender
@@ -0,0 +1,88 @@
1
+ import logging
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+
7
+ from sglang.srt.disaggregation.base.conn import (
8
+ BaseKVManager,
9
+ BaseKVReceiver,
10
+ BaseKVSender,
11
+ KVArgs,
12
+ KVPoll,
13
+ )
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ # For warmup reqs, we don't kv transfer, we use the fake sender and receiver
19
+ class FakeKVSender(BaseKVSender):
20
+ def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
21
+ self.has_sent = False
22
+
23
+ def poll(self) -> KVPoll:
24
+ if self.has_sent is False:
25
+ # Assume handshake completed instantly
26
+ return KVPoll.WaitingForInput
27
+ else:
28
+ # Assume transfer completed instantly
29
+ logger.info("FakeKVSender poll success")
30
+ return KVPoll.Success
31
+
32
+ def init(
33
+ self,
34
+ kv_indices: list[int],
35
+ aux_index: Optional[int] = None,
36
+ dest_ranks: Optional[list[int]] = None,
37
+ ):
38
+ logger.info(
39
+ f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
40
+ )
41
+ pass
42
+
43
+ def send(
44
+ self,
45
+ kv_indices: npt.NDArray[np.int64],
46
+ index_slice: slice,
47
+ is_last: bool,
48
+ ):
49
+ logger.info(
50
+ f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
51
+ )
52
+ if is_last:
53
+ self.has_sent = True
54
+ logger.info(f"FakeKVSender send success")
55
+ else:
56
+ self.has_sent = False
57
+ logger.info(f"FakeKVSender send fake transfering")
58
+
59
+ def failure_exception(self):
60
+ raise Exception("Fake KVSender Exception")
61
+
62
+
63
+ class FakeKVReceiver(BaseKVReceiver):
64
+ def __init__(
65
+ self,
66
+ mgr: BaseKVManager,
67
+ bootstrap_addr: str,
68
+ bootstrap_room: Optional[int] = None,
69
+ ):
70
+ self.has_init = False
71
+
72
+ def poll(self) -> KVPoll:
73
+ if self.has_init is False:
74
+ # Assume handshake completed instantly
75
+ return KVPoll.WaitingForInput
76
+ else:
77
+ # Assume transfer completed instantly
78
+ logger.info("FakeKVReceiver poll success")
79
+ return KVPoll.Success
80
+
81
+ def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
82
+ self.has_init = True
83
+ logger.info(
84
+ f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
85
+ )
86
+
87
+ def failure_exception(self):
88
+ raise Exception("Fake KVReceiver Exception")
@@ -6,6 +6,7 @@ import asyncio
6
6
  import random
7
7
  import urllib
8
8
  from itertools import chain
9
+ from typing import List
9
10
 
10
11
  import aiohttp
11
12
  import orjson
@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
14
15
  from fastapi.responses import ORJSONResponse, Response, StreamingResponse
15
16
 
16
17
 
18
+ class PrefillConfig:
19
+ def __init__(self, url: str, bootstrap_port: int):
20
+ self.url = url
21
+ self.bootstrap_port = bootstrap_port
22
+
23
+
17
24
  class MiniLoadBalancer:
18
- def __init__(self, prefill_servers, decode_servers):
19
- self.prefill_servers = prefill_servers
25
+ def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
26
+ self.prefill_configs = prefill_configs
27
+ self.prefill_servers = [p.url for p in prefill_configs]
20
28
  self.decode_servers = decode_servers
21
29
 
22
30
  def select_pair(self):
23
- return random.choice(self.prefill_servers), random.choice(self.decode_servers)
31
+ prefill_config = random.choice(self.prefill_configs)
32
+ decode_server = random.choice(self.decode_servers)
33
+ return prefill_config.url, prefill_config.bootstrap_port, decode_server
24
34
 
25
35
  async def generate(
26
36
  self, modified_request, prefill_server, decode_server, endpoint
@@ -160,7 +170,7 @@ async def get_model_info():
160
170
 
161
171
  @app.post("/generate")
162
172
  async def handle_generate_request(request_data: dict):
163
- prefill_server, decode_server = load_balancer.select_pair()
173
+ prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
164
174
 
165
175
  # Parse and transform prefill_server for bootstrap data
166
176
  parsed_url = urllib.parse.urlparse(prefill_server)
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
172
182
  modified_request.update(
173
183
  {
174
184
  "bootstrap_host": [hostname] * batch_size,
185
+ "bootstrap_port": [bootstrap_port] * batch_size,
175
186
  "bootstrap_room": [
176
187
  _generate_bootstrap_room() for _ in range(batch_size)
177
188
  ],
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
181
192
  modified_request.update(
182
193
  {
183
194
  "bootstrap_host": hostname,
195
+ "bootstrap_port": bootstrap_port,
184
196
  "bootstrap_room": _generate_bootstrap_room(),
185
197
  }
186
198
  )
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
197
209
 
198
210
  @app.post("/v1/chat/completions")
199
211
  async def handle_completion_request(request_data: dict):
200
- prefill_server, decode_server = load_balancer.select_pair()
212
+ prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
201
213
 
202
214
  # Parse and transform prefill_server for bootstrap data
203
215
  parsed_url = urllib.parse.urlparse(prefill_server)
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
206
218
  modified_request.update(
207
219
  {
208
220
  "bootstrap_host": hostname,
221
+ "bootstrap_port": bootstrap_port,
209
222
  "bootstrap_room": random.randint(0, 2**63 - 1),
210
223
  }
211
224
  )
@@ -255,9 +268,9 @@ async def get_models():
255
268
  raise HTTPException(status_code=500, detail=str(e))
256
269
 
257
270
 
258
- def run(prefill_addrs, decode_addrs, host, port):
271
+ def run(prefill_configs, decode_addrs, host, port):
259
272
  global load_balancer
260
- load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
273
+ load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
261
274
  uvicorn.run(app, host=host, port=port)
262
275
 
263
276
 
@@ -268,6 +281,11 @@ if __name__ == "__main__":
268
281
  parser.add_argument(
269
282
  "--prefill", required=True, help="Comma-separated URLs for prefill servers"
270
283
  )
284
+ parser.add_argument(
285
+ "--prefill-bootstrap-ports",
286
+ help="Comma-separated bootstrap ports for prefill servers",
287
+ default="8998",
288
+ )
271
289
  parser.add_argument(
272
290
  "--decode", required=True, help="Comma-separated URLs for decode servers"
273
291
  )
@@ -278,4 +296,23 @@ if __name__ == "__main__":
278
296
  "--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
279
297
  )
280
298
  args = parser.parse_args()
281
- run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
299
+
300
+ prefill_urls = args.prefill.split(",")
301
+ bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
302
+
303
+ if len(bootstrap_ports) == 1:
304
+ bootstrap_ports = bootstrap_ports * len(prefill_urls)
305
+ else:
306
+ if len(bootstrap_ports) != len(prefill_urls):
307
+ raise ValueError(
308
+ "Number of prefill URLs must match number of bootstrap ports"
309
+ )
310
+ exit(1)
311
+
312
+ prefill_configs = []
313
+ for url, port in zip(prefill_urls, bootstrap_ports):
314
+ prefill_configs.append(PrefillConfig(url, port))
315
+
316
+ decode_addrs = args.decode.split(",")
317
+
318
+ run(prefill_configs, decode_addrs, args.host, args.port)