sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/conversation.py

@@ -16,6 +16,7 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
+import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
@@ -48,6 +49,7 @@ class SeparatorStyle(IntEnum):
     DeepSeekVL2 = auto()
     QWEN2_VL_EMBED = auto()
     GEMMA3 = auto()
+    MPT = auto()
 
 
 @dataclasses.dataclass
@@ -327,6 +329,16 @@ class Conversation:
                 ret += role
             return ret
 
+        elif self.sep_style == SeparatorStyle.MPT:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
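Note: a minimal standalone sketch of the MPT rendering branch added above (the real logic lives in Conversation.get_prompt(); the roles, separator, and messages below are illustrative, not sglang defaults):

    # Sketch of the MPT separator style: system prompt, then role + message + sep
    # per turn, with an unterminated role for the turn the model should complete.
    def render_mpt(system_prompt, sep, messages):
        ret = system_prompt + sep
        for role, message in messages:
            if message:
                if type(message) is tuple:  # tuple messages carry extra fields, dropped here
                    message, _, _ = message
                ret += role + message + sep
            else:
                ret += role  # open the next turn without closing it
        return ret

    print(render_mpt(
        "<|im_start|>system\nYou are a helpful assistant.",
        "<|im_end|>\n",
        [("<|im_start|>user\n", "Hi"), ("<|im_start|>assistant\n", None)],
    ))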
@@ -570,8 +582,11 @@ def generate_chat_conv(
                         real_content += "\n"  # for video
                     real_content += content.text
                 elif content.type == "image_url":
-                    # NOTE: Only works for llava
-                    real_content += image_token
+                    # NOTE: works for llava and intervl2_5
+                    if conv.name == "internvl-2-5":
+                        real_content = image_token + real_content
+                    else:
+                        real_content += image_token
                     conv.append_image(content.image_url.url)
                 elif content.type == "audio_url":
                     real_content += audio_token
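Note: a toy illustration of the placement change, sketched outside generate_chat_conv; the internvl-2-5 template prepends the image token while llava-style templates keep appending it (the conversation name and token below are examples):

    # Image-token placement per conversation template.
    def place_image_token(conv_name, real_content, image_token="<image>"):
        if conv_name == "internvl-2-5":
            return image_token + real_content  # InternVL expects the image first
        return real_content + image_token      # llava-style: append after the text

    print(place_image_token("internvl-2-5", "Describe this picture."))
    print(place_image_token("chatml-llava", "Describe this picture."))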
@@ -619,6 +634,20 @@ register_conv_template(
     )
 )
 
+# reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
+register_conv_template(
+    Conversation(
+        name="mistral",
+        system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
+        roles=("[INST]", "[/INST]"),
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+        stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
+        image_token="[IMG]",
+    )
+)
+
 # reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
 register_conv_template(
     Conversation(
@@ -703,6 +732,19 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="internvl-2-5",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
+        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+        sep_style=SeparatorStyle.MPT,
+        sep="<|im_end|>\n",
+        stop_str=["<|im_end|>", "<|action_end|>"],
+        image_token="<image>",
+    )
+)
+
 # Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
 register_conv_template(
     Conversation(
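Note: the Chinese system_message above roughly reads "You are InternVL (书生·万象), a multimodal large language model jointly developed by Shanghai AI Laboratory, Tsinghua University, and several partner institutions." Combined with the MPT separator style, a single user turn with one image should render roughly as follows (assuming an empty assistant turn is appended to prompt generation, and writing {system_message} for the Chinese text):

    <|im_start|>system
    {system_message}<|im_end|>
    <|im_start|>user
    <image>What is in this picture?<|im_end|>
    <|im_start|>assistant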
@@ -826,90 +868,80 @@ register_conv_template(
 
 
 @register_conv_template_matching_function
-def match_deepseek_janus_pro(model_path: str):
-    if (
-        "llama" in model_path.lower()
-        and "3.2" in model_path.lower()
-        and "vision" in model_path.lower()
-    ):
+def match_internvl(model_path: str):
+    if re.search(r"internvl2_5", model_path, re.IGNORECASE):
+        return "internvl-2-5"
+
+
+@register_conv_template_matching_function
+def match_llama_3_vision(model_path: str):
+    if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
         return "llama_3_vision"
 
 
 @register_conv_template_matching_function
 def match_deepseek_janus_pro(model_path: str):
-    if "janus" in model_path.lower():
+    if re.search(r"janus", model_path, re.IGNORECASE):
         return "janus-pro"
 
 
 @register_conv_template_matching_function
 def match_vicuna(model_path: str):
-    if "vicuna" in model_path.lower():
-        return "vicuna_v1.1"
-    if "llava-v1.5" in model_path.lower():
-        return "vicuna_v1.1"
-    if "llava-next-video-7b" in model_path.lower():
+    if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
         return "vicuna_v1.1"
 
 
 @register_conv_template_matching_function
 def match_llama2_chat(model_path: str):
-    model_path = model_path.lower()
-    if "llama-2" in model_path and "chat" in model_path:
-        return "llama-2"
-    if (
-        "mistral" in model_path or "mixtral" in model_path
-    ) and "instruct" in model_path:
-        return "llama-2"
-    if "codellama" in model_path and "instruct" in model_path:
+    if re.search(
+        r"llama-2.*chat|codellama.*instruct",
+        model_path,
+        re.IGNORECASE,
+    ):
         return "llama-2"
 
 
+@register_conv_template_matching_function
+def match_mistral(model_path: str):
+    if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
+        return "mistral"
+
+
 @register_conv_template_matching_function
 def match_deepseek_vl(model_path: str):
-    model_path = model_path.lower()
-    if "deepseek" in model_path and "vl2" in model_path:
+    if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
        return "deepseek-vl2"
 
 
 @register_conv_template_matching_function
-def match_chat_ml(model_path: str):
-    # import pdb;pdb.set_trace()
-    model_path = model_path.lower()
-    # Now the suffix for qwen2 chat model is "instruct"
-    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
+def match_qwen_chat_ml(model_path: str):
+    if re.search(r"gme.*qwen.*vl", model_path, re.IGNORECASE):
         return "gme-qwen2-vl"
-    if "qwen" in model_path and "vl" in model_path:
+    if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
         return "qwen2-vl"
-    if (
-        "llava-v1.6-34b" in model_path
-        or "llava-v1.6-yi-34b" in model_path
-        or "llava-next-video-34b" in model_path
-        or "llava-onevision-qwen2" in model_path
+    if re.search(
+        r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
+        model_path,
+        re.IGNORECASE,
     ):
         return "chatml-llava"
 
 
 @register_conv_template_matching_function
-def match_gemma_it(model_path: str):
-    model_path = model_path.lower()
-    if "gemma" in model_path and "it" in model_path:
-        return "gemma-it"
-    if "gemma-3" in model_path and "1b" not in model_path:
-        # gemma-3-1b-it is completion model
+def match_gemma3_instruct(model_path: str):
+    if re.search(r"gemma-3.*it", model_path, re.IGNORECASE):
         return "gemma-it"
 
 
 @register_conv_template_matching_function
 def match_openbmb_minicpm(model_path: str):
-    model_path = model_path.lower()
-    if "minicpm-v" in model_path:
+    if re.search(r"minicpm-v", model_path, re.IGNORECASE):
         return "minicpmv"
-    elif "minicpm-o" in model_path:
+    elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
         return "minicpmo"
 
 
 @register_conv_template_matching_function
 def match_moonshot_kimivl(model_path: str):
-    model_path = model_path.lower()
-    if "kimi" in model_path and "vl" in model_path:
+    if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
         return "kimi-vl"
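Note: all matching functions now use re.search with re.IGNORECASE instead of chained substring checks, which also drops the repeated model_path.lower() calls. A quick standalone sanity check against representative model paths (the paths below are illustrative examples, not an exhaustive list):

    import re

    # Illustrative checks mirroring the re.search(..., re.IGNORECASE) matchers above.
    samples = {
        "OpenGVLab/InternVL2_5-8B": r"internvl2_5",
        "meta-llama/Llama-3.2-11B-Vision-Instruct": r"llama.*3\.2.*vision",
        "mistralai/Mistral-Small-3.1-24B-Instruct-2503": r"pixtral|(mistral|mixtral).*instruct",
        "google/gemma-3-27b-it": r"gemma-3.*it",
    }
    for path, pattern in samples.items():
        assert re.search(pattern, path, re.IGNORECASE), path
    print("all sample paths match")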
sglang/srt/disaggregation/base/conn.py

@@ -37,6 +37,7 @@ class BaseKVManager(ABC):
         args: KVArgs,
         disaggregation_mode: DisaggregationMode,
         server_args: ServerArgs,
+        is_mla_backend: Optional[bool] = False,
     ): ...
 
 
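Note: concrete KV managers are now constructed with a fourth argument (see the decode.py hunks below, which pass self.is_mla_backend through). A standalone sketch of the widened interface, with placeholder types standing in for KVArgs, DisaggregationMode, and ServerArgs:

    from abc import ABC
    from typing import Optional

    # Sketch only: the real base class and backends live in sglang.srt.disaggregation.
    class BaseKVManagerSketch(ABC):
        def __init__(
            self,
            args,                 # KVArgs
            disaggregation_mode,  # DisaggregationMode
            server_args,          # ServerArgs
            is_mla_backend: Optional[bool] = False,
        ): ...

    class SomeKVManager(BaseKVManagerSketch):
        def __init__(self, args, disaggregation_mode, server_args, is_mla_backend=False):
            self.is_mla_backend = is_mla_backend  # lets a backend branch on MLA-style KV caches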
sglang/srt/disaggregation/decode.py

@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
+import os
 from collections import deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
@@ -37,6 +38,7 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    is_mla_backend,
     kv_to_page_indices,
     poll_and_all_reduce,
 )
@@ -86,6 +88,7 @@ class DecodePreallocQueue:
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
+        self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
         self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -97,7 +100,9 @@ class DecodePreallocQueue:
         self.tp_size = tp_size
         self.bootstrap_port = bootstrap_port
 
-        self.num_reserved_decode_tokens = 512
+        self.num_reserved_decode_tokens = int(
+            os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
+        )
 
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
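Note: the reserved decode-token budget is now tunable via SGLANG_NUM_RESERVED_DECODE_TOKENS without a code change; in practice you would set the variable in the decode server's environment before launch. A minimal sketch of the lookup (1024 is only an example value):

    import os

    # Same pattern as the constructor above: read the override, fall back to 512.
    os.environ["SGLANG_NUM_RESERVED_DECODE_TOKENS"] = "1024"  # example override
    num_reserved_decode_tokens = int(os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512"))
    print(num_reserved_decode_tokens)  # -> 1024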
@@ -128,7 +133,10 @@
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
-            kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
+            kv_args,
+            DisaggregationMode.DECODE,
+            self.scheduler.server_args,
+            self.is_mla_backend,
         )
         return kv_manager
 
@@ -506,7 +514,7 @@ class SchedulerDisaggregationDecodeMixin:
     def event_loop_overlap_disagg_decode(self: Scheduler):
         result_queue = deque()
         self.last_batch: Optional[ScheduleBatch] = None
-        self.last_batch_in_queue = False  # last batch is modifed in-place, so we need another variable to track if it's extend
+        self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend
 
         while True:
             recv_reqs = self.recv_requests()
sglang/srt/disaggregation/fake/conn.py

@@ -54,7 +54,7 @@ class FakeKVSender(BaseKVSender):
             logger.info(f"FakeKVSender send success")
         else:
             self.has_sent = False
-            logger.info(f"FakeKVSender send fake transfering")
+            logger.info(f"FakeKVSender send fake transferring")
 
     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
sglang/srt/disaggregation/mini_lb.py

@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing.
 """
 
 import asyncio
+import dataclasses
+import logging
 import random
 import urllib
 from itertools import chain
-from typing import List
+from typing import List, Optional
 
 import aiohttp
 import orjson
@@ -14,11 +16,32 @@ import uvicorn
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
+from sglang.srt.disaggregation.utils import PDRegistryRequest
 
+
+def setup_logger():
+    logger = logging.getLogger("pdlb")
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter(
+        "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    return logger
+
+
+logger = setup_logger()
+
+
+@dataclasses.dataclass
 class PrefillConfig:
-    def __init__(self, url: str, bootstrap_port: int):
-        self.url = url
-        self.bootstrap_port = bootstrap_port
+    url: str
+    bootstrap_port: Optional[int] = None
 
 
 class MiniLoadBalancer:
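Note: converting PrefillConfig to a dataclass keeps existing call sites unchanged while making bootstrap_port optional. A standalone copy to show the defaulted field (URLs and ports below are examples):

    import dataclasses
    from typing import Optional

    # Standalone copy of the dataclass above, to show the defaulted field.
    @dataclasses.dataclass
    class PrefillConfig:
        url: str
        bootstrap_port: Optional[int] = None

    print(PrefillConfig("http://127.0.0.1:30000"))        # bootstrap_port defaults to None
    print(PrefillConfig("http://127.0.0.1:30000", 8998))  # explicit bootstrap port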
@@ -28,6 +51,10 @@ class MiniLoadBalancer:
         self.decode_servers = decode_servers
 
     def select_pair(self):
+        # TODO: return some message instead of panic
+        assert len(self.prefill_configs) > 0, "No prefill servers available"
+        assert len(self.decode_servers) > 0, "No decode servers available"
+
         prefill_config = random.choice(self.prefill_configs)
         decode_server = random.choice(self.decode_servers)
         return prefill_config.url, prefill_config.bootstrap_port, decode_server
@@ -47,7 +74,7 @@
                 session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
-            prefill_response, decode_response = await asyncio.gather(*tasks)
+            _, decode_response = await asyncio.gather(*tasks)
 
             return ORJSONResponse(
                 content=await decode_response.json(),
@@ -268,6 +295,32 @@ async def get_models():
         raise HTTPException(status_code=500, detail=str(e))
 
 
+@app.post("/register")
+async def register(obj: PDRegistryRequest):
+    if obj.mode == "prefill":
+        load_balancer.prefill_configs.append(
+            PrefillConfig(obj.registry_url, obj.bootstrap_port)
+        )
+        logger.info(
+            f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
+        )
+    elif obj.mode == "decode":
+        load_balancer.decode_servers.append(obj.registry_url)
+        logger.info(f"Registered decode server: {obj.registry_url}")
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid mode. Must be either PREFILL or DECODE.",
+        )
+
+    logger.info(
+        f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
+        f"#Decode servers: {len(load_balancer.decode_servers)}"
+    )
+
+    return Response(status_code=200)
+
+
 def run(prefill_configs, decode_addrs, host, port):
     global load_balancer
     load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
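Note: the new /register route lets prefill and decode servers register themselves with the running load balancer. A hypothetical client call (the field names mirror the handler's attribute accesses obj.mode, obj.registry_url, obj.bootstrap_port; the actual PDRegistryRequest schema lives in sglang.srt.disaggregation.utils, and the URLs and ports are examples):

    import requests  # example client; any HTTP client works

    # Register a prefill server with a mini load balancer assumed to listen on :8000.
    resp = requests.post(
        "http://127.0.0.1:8000/register",
        json={
            "mode": "prefill",
            "registry_url": "http://127.0.0.1:30000",
            "bootstrap_port": 8998,
        },
    )
    print(resp.status_code)  # 200 on success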
@@ -279,15 +332,16 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
     parser.add_argument(
-        "--prefill", required=True, help="Comma-separated URLs for prefill servers"
+        "--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
     )
     parser.add_argument(
-        "--prefill-bootstrap-ports",
-        help="Comma-separated bootstrap ports for prefill servers",
-        default="8998",
+        "--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
     )
     parser.add_argument(
-        "--decode", required=True, help="Comma-separated URLs for decode servers"
+        "--prefill-bootstrap-ports",
+        type=int,
+        nargs="+",
+        help="Bootstrap ports for prefill servers",
     )
     parser.add_argument(
         "--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
@@ -297,22 +351,19 @@
     )
     args = parser.parse_args()
 
-    prefill_urls = args.prefill.split(",")
-    bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
-
-    if len(bootstrap_ports) == 1:
-        bootstrap_ports = bootstrap_ports * len(prefill_urls)
+    bootstrap_ports = args.prefill_bootstrap_ports
+    if bootstrap_ports is None:
+        bootstrap_ports = [None] * len(args.prefill)
+    elif len(bootstrap_ports) == 1:
+        bootstrap_ports = bootstrap_ports * len(args.prefill)
     else:
-        if len(bootstrap_ports) != len(prefill_urls):
+        if len(bootstrap_ports) != len(args.prefill):
             raise ValueError(
                 "Number of prefill URLs must match number of bootstrap ports"
             )
-            exit(1)
-
-    prefill_configs = []
-    for url, port in zip(prefill_urls, bootstrap_ports):
-        prefill_configs.append(PrefillConfig(url, port))
 
-    decode_addrs = args.decode.split(",")
+    prefill_configs = [
+        PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
+    ]
 
-    run(prefill_configs, decode_addrs, args.host, args.port)
+    run(prefill_configs, args.decode, args.host, args.port)
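Note: with nargs="+", server lists are now passed space-separated rather than comma-separated, and --prefill-bootstrap-ports may be omitted, given once (applied to all prefill servers), or given per server. An illustrative launch, assuming the module is invoked directly (module path, URLs, and ports are examples):

    python -m sglang.srt.disaggregation.mini_lb \
        --prefill http://127.0.0.1:30000 http://127.0.0.1:30001 \
        --prefill-bootstrap-ports 8998 8999 \
        --decode http://127.0.0.1:31000 http://127.0.0.1:31001 \
        --host 0.0.0.0 --port 8000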