sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/conversation.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
# Adapted from
|
17
17
|
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
18
18
|
import dataclasses
|
19
|
+
import re
|
19
20
|
from enum import IntEnum, auto
|
20
21
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
21
22
|
|
@@ -48,6 +49,7 @@ class SeparatorStyle(IntEnum):
|
|
48
49
|
DeepSeekVL2 = auto()
|
49
50
|
QWEN2_VL_EMBED = auto()
|
50
51
|
GEMMA3 = auto()
|
52
|
+
MPT = auto()
|
51
53
|
|
52
54
|
|
53
55
|
@dataclasses.dataclass
|
@@ -327,6 +329,16 @@ class Conversation:
|
|
327
329
|
ret += role
|
328
330
|
return ret
|
329
331
|
|
332
|
+
elif self.sep_style == SeparatorStyle.MPT:
|
333
|
+
ret = system_prompt + self.sep
|
334
|
+
for role, message in self.messages:
|
335
|
+
if message:
|
336
|
+
if type(message) is tuple:
|
337
|
+
message, _, _ = message
|
338
|
+
ret += role + message + self.sep
|
339
|
+
else:
|
340
|
+
ret += role
|
341
|
+
return ret
|
330
342
|
else:
|
331
343
|
raise ValueError(f"Invalid style: {self.sep_style}")
|
332
344
|
|
@@ -570,8 +582,11 @@ def generate_chat_conv(
|
|
570
582
|
real_content += "\n" # for video
|
571
583
|
real_content += content.text
|
572
584
|
elif content.type == "image_url":
|
573
|
-
# NOTE:
|
574
|
-
|
585
|
+
# NOTE: works for llava and intervl2_5
|
586
|
+
if conv.name == "internvl-2-5":
|
587
|
+
real_content = image_token + real_content
|
588
|
+
else:
|
589
|
+
real_content += image_token
|
575
590
|
conv.append_image(content.image_url.url)
|
576
591
|
elif content.type == "audio_url":
|
577
592
|
real_content += audio_token
|
@@ -619,6 +634,20 @@ register_conv_template(
|
|
619
634
|
)
|
620
635
|
)
|
621
636
|
|
637
|
+
# reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
|
638
|
+
register_conv_template(
|
639
|
+
Conversation(
|
640
|
+
name="mistral",
|
641
|
+
system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
|
642
|
+
roles=("[INST]", "[/INST]"),
|
643
|
+
sep_style=SeparatorStyle.LLAMA2,
|
644
|
+
sep=" ",
|
645
|
+
sep2=" </s><s>",
|
646
|
+
stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
|
647
|
+
image_token="[IMG]",
|
648
|
+
)
|
649
|
+
)
|
650
|
+
|
622
651
|
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
623
652
|
register_conv_template(
|
624
653
|
Conversation(
|
@@ -703,6 +732,19 @@ register_conv_template(
|
|
703
732
|
)
|
704
733
|
)
|
705
734
|
|
735
|
+
register_conv_template(
|
736
|
+
Conversation(
|
737
|
+
name="internvl-2-5",
|
738
|
+
system_template="<|im_start|>system\n{system_message}",
|
739
|
+
system_message="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
|
740
|
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
741
|
+
sep_style=SeparatorStyle.MPT,
|
742
|
+
sep="<|im_end|>\n",
|
743
|
+
stop_str=["<|im_end|>", "<|action_end|>"],
|
744
|
+
image_token="<image>",
|
745
|
+
)
|
746
|
+
)
|
747
|
+
|
706
748
|
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
|
707
749
|
register_conv_template(
|
708
750
|
Conversation(
|
@@ -826,90 +868,80 @@ register_conv_template(
|
|
826
868
|
|
827
869
|
|
828
870
|
@register_conv_template_matching_function
|
829
|
-
def
|
830
|
-
if (
|
831
|
-
"
|
832
|
-
|
833
|
-
|
834
|
-
|
871
|
+
def match_internvl(model_path: str):
|
872
|
+
if re.search(r"internvl2_5", model_path, re.IGNORECASE):
|
873
|
+
return "internvl-2-5"
|
874
|
+
|
875
|
+
|
876
|
+
@register_conv_template_matching_function
|
877
|
+
def match_llama_3_vision(model_path: str):
|
878
|
+
if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
|
835
879
|
return "llama_3_vision"
|
836
880
|
|
837
881
|
|
838
882
|
@register_conv_template_matching_function
|
839
883
|
def match_deepseek_janus_pro(model_path: str):
|
840
|
-
if "janus"
|
884
|
+
if re.search(r"janus", model_path, re.IGNORECASE):
|
841
885
|
return "janus-pro"
|
842
886
|
|
843
887
|
|
844
888
|
@register_conv_template_matching_function
|
845
889
|
def match_vicuna(model_path: str):
|
846
|
-
if "vicuna"
|
847
|
-
return "vicuna_v1.1"
|
848
|
-
if "llava-v1.5" in model_path.lower():
|
849
|
-
return "vicuna_v1.1"
|
850
|
-
if "llava-next-video-7b" in model_path.lower():
|
890
|
+
if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
|
851
891
|
return "vicuna_v1.1"
|
852
892
|
|
853
893
|
|
854
894
|
@register_conv_template_matching_function
|
855
895
|
def match_llama2_chat(model_path: str):
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
) and "instruct" in model_path:
|
862
|
-
return "llama-2"
|
863
|
-
if "codellama" in model_path and "instruct" in model_path:
|
896
|
+
if re.search(
|
897
|
+
r"llama-2.*chat|codellama.*instruct",
|
898
|
+
model_path,
|
899
|
+
re.IGNORECASE,
|
900
|
+
):
|
864
901
|
return "llama-2"
|
865
902
|
|
866
903
|
|
904
|
+
@register_conv_template_matching_function
|
905
|
+
def match_mistral(model_path: str):
|
906
|
+
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
|
907
|
+
return "mistral"
|
908
|
+
|
909
|
+
|
867
910
|
@register_conv_template_matching_function
|
868
911
|
def match_deepseek_vl(model_path: str):
|
869
|
-
|
870
|
-
if "deepseek" in model_path and "vl2" in model_path:
|
912
|
+
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
|
871
913
|
return "deepseek-vl2"
|
872
914
|
|
873
915
|
|
874
916
|
@register_conv_template_matching_function
|
875
|
-
def
|
876
|
-
|
877
|
-
model_path = model_path.lower()
|
878
|
-
# Now the suffix for qwen2 chat model is "instruct"
|
879
|
-
if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
|
917
|
+
def match_qwen_chat_ml(model_path: str):
|
918
|
+
if re.search(r"gme.*qwen.*vl", model_path, re.IGNORECASE):
|
880
919
|
return "gme-qwen2-vl"
|
881
|
-
if "qwen"
|
920
|
+
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
|
882
921
|
return "qwen2-vl"
|
883
|
-
if (
|
884
|
-
"llava-v1
|
885
|
-
|
886
|
-
|
887
|
-
or "llava-onevision-qwen2" in model_path
|
922
|
+
if re.search(
|
923
|
+
r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
|
924
|
+
model_path,
|
925
|
+
re.IGNORECASE,
|
888
926
|
):
|
889
927
|
return "chatml-llava"
|
890
928
|
|
891
929
|
|
892
930
|
@register_conv_template_matching_function
|
893
|
-
def
|
894
|
-
|
895
|
-
if "gemma" in model_path and "it" in model_path:
|
896
|
-
return "gemma-it"
|
897
|
-
if "gemma-3" in model_path and "1b" not in model_path:
|
898
|
-
# gemma-3-1b-it is completion model
|
931
|
+
def match_gemma3_instruct(model_path: str):
|
932
|
+
if re.search(r"gemma-3.*it", model_path, re.IGNORECASE):
|
899
933
|
return "gemma-it"
|
900
934
|
|
901
935
|
|
902
936
|
@register_conv_template_matching_function
|
903
937
|
def match_openbmb_minicpm(model_path: str):
|
904
|
-
|
905
|
-
if "minicpm-v" in model_path:
|
938
|
+
if re.search(r"minicpm-v", model_path, re.IGNORECASE):
|
906
939
|
return "minicpmv"
|
907
|
-
elif "minicpm-o"
|
940
|
+
elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
|
908
941
|
return "minicpmo"
|
909
942
|
|
910
943
|
|
911
944
|
@register_conv_template_matching_function
|
912
945
|
def match_moonshot_kimivl(model_path: str):
|
913
|
-
|
914
|
-
if "kimi" in model_path and "vl" in model_path:
|
946
|
+
if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
|
915
947
|
return "kimi-vl"
|
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
|
|
21
21
|
from __future__ import annotations
|
22
22
|
|
23
23
|
import logging
|
24
|
+
import os
|
24
25
|
from collections import deque
|
25
26
|
from dataclasses import dataclass
|
26
27
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
@@ -37,6 +38,7 @@ from sglang.srt.disaggregation.utils import (
|
|
37
38
|
ReqToMetadataIdxAllocator,
|
38
39
|
TransferBackend,
|
39
40
|
get_kv_class,
|
41
|
+
is_mla_backend,
|
40
42
|
kv_to_page_indices,
|
41
43
|
poll_and_all_reduce,
|
42
44
|
)
|
@@ -86,6 +88,7 @@ class DecodePreallocQueue:
|
|
86
88
|
self.req_to_token_pool = req_to_token_pool
|
87
89
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
88
90
|
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
|
91
|
+
self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
|
89
92
|
self.aux_dtype = aux_dtype
|
90
93
|
self.metadata_buffers = metadata_buffers
|
91
94
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
@@ -97,7 +100,9 @@ class DecodePreallocQueue:
|
|
97
100
|
self.tp_size = tp_size
|
98
101
|
self.bootstrap_port = bootstrap_port
|
99
102
|
|
100
|
-
self.num_reserved_decode_tokens =
|
103
|
+
self.num_reserved_decode_tokens = int(
|
104
|
+
os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
|
105
|
+
)
|
101
106
|
|
102
107
|
# Queue for requests pending pre-allocation
|
103
108
|
self.queue: List[DecodeRequest] = []
|
@@ -128,7 +133,10 @@ class DecodePreallocQueue:
|
|
128
133
|
kv_args.gpu_id = self.scheduler.gpu_id
|
129
134
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
130
135
|
kv_manager = kv_manager_class(
|
131
|
-
kv_args,
|
136
|
+
kv_args,
|
137
|
+
DisaggregationMode.DECODE,
|
138
|
+
self.scheduler.server_args,
|
139
|
+
self.is_mla_backend,
|
132
140
|
)
|
133
141
|
return kv_manager
|
134
142
|
|
@@ -506,7 +514,7 @@ class SchedulerDisaggregationDecodeMixin:
|
|
506
514
|
def event_loop_overlap_disagg_decode(self: Scheduler):
|
507
515
|
result_queue = deque()
|
508
516
|
self.last_batch: Optional[ScheduleBatch] = None
|
509
|
-
self.last_batch_in_queue = False # last batch is
|
517
|
+
self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
|
510
518
|
|
511
519
|
while True:
|
512
520
|
recv_reqs = self.recv_requests()
|
@@ -54,7 +54,7 @@ class FakeKVSender(BaseKVSender):
|
|
54
54
|
logger.info(f"FakeKVSender send success")
|
55
55
|
else:
|
56
56
|
self.has_sent = False
|
57
|
-
logger.info(f"FakeKVSender send fake
|
57
|
+
logger.info(f"FakeKVSender send fake transferring")
|
58
58
|
|
59
59
|
def failure_exception(self):
|
60
60
|
raise Exception("Fake KVSender Exception")
|
@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing.
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
import asyncio
|
6
|
+
import dataclasses
|
7
|
+
import logging
|
6
8
|
import random
|
7
9
|
import urllib
|
8
10
|
from itertools import chain
|
9
|
-
from typing import List
|
11
|
+
from typing import List, Optional
|
10
12
|
|
11
13
|
import aiohttp
|
12
14
|
import orjson
|
@@ -14,11 +16,32 @@ import uvicorn
|
|
14
16
|
from fastapi import FastAPI, HTTPException
|
15
17
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
16
18
|
|
19
|
+
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
17
20
|
|
21
|
+
|
22
|
+
def setup_logger():
|
23
|
+
logger = logging.getLogger("pdlb")
|
24
|
+
logger.setLevel(logging.INFO)
|
25
|
+
|
26
|
+
formatter = logging.Formatter(
|
27
|
+
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
28
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
29
|
+
)
|
30
|
+
|
31
|
+
handler = logging.StreamHandler()
|
32
|
+
handler.setFormatter(formatter)
|
33
|
+
logger.addHandler(handler)
|
34
|
+
|
35
|
+
return logger
|
36
|
+
|
37
|
+
|
38
|
+
logger = setup_logger()
|
39
|
+
|
40
|
+
|
41
|
+
@dataclasses.dataclass
|
18
42
|
class PrefillConfig:
|
19
|
-
|
20
|
-
|
21
|
-
self.bootstrap_port = bootstrap_port
|
43
|
+
url: str
|
44
|
+
bootstrap_port: Optional[int] = None
|
22
45
|
|
23
46
|
|
24
47
|
class MiniLoadBalancer:
|
@@ -28,6 +51,10 @@ class MiniLoadBalancer:
|
|
28
51
|
self.decode_servers = decode_servers
|
29
52
|
|
30
53
|
def select_pair(self):
|
54
|
+
# TODO: return some message instead of panic
|
55
|
+
assert len(self.prefill_configs) > 0, "No prefill servers available"
|
56
|
+
assert len(self.decode_servers) > 0, "No decode servers available"
|
57
|
+
|
31
58
|
prefill_config = random.choice(self.prefill_configs)
|
32
59
|
decode_server = random.choice(self.decode_servers)
|
33
60
|
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
@@ -47,7 +74,7 @@ class MiniLoadBalancer:
|
|
47
74
|
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
48
75
|
]
|
49
76
|
# Wait for both responses to complete. Prefill should end first.
|
50
|
-
|
77
|
+
_, decode_response = await asyncio.gather(*tasks)
|
51
78
|
|
52
79
|
return ORJSONResponse(
|
53
80
|
content=await decode_response.json(),
|
@@ -268,6 +295,32 @@ async def get_models():
|
|
268
295
|
raise HTTPException(status_code=500, detail=str(e))
|
269
296
|
|
270
297
|
|
298
|
+
@app.post("/register")
|
299
|
+
async def register(obj: PDRegistryRequest):
|
300
|
+
if obj.mode == "prefill":
|
301
|
+
load_balancer.prefill_configs.append(
|
302
|
+
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
303
|
+
)
|
304
|
+
logger.info(
|
305
|
+
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
306
|
+
)
|
307
|
+
elif obj.mode == "decode":
|
308
|
+
load_balancer.decode_servers.append(obj.registry_url)
|
309
|
+
logger.info(f"Registered decode server: {obj.registry_url}")
|
310
|
+
else:
|
311
|
+
raise HTTPException(
|
312
|
+
status_code=400,
|
313
|
+
detail="Invalid mode. Must be either PREFILL or DECODE.",
|
314
|
+
)
|
315
|
+
|
316
|
+
logger.info(
|
317
|
+
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
|
318
|
+
f"#Decode servers: {len(load_balancer.decode_servers)}"
|
319
|
+
)
|
320
|
+
|
321
|
+
return Response(status_code=200)
|
322
|
+
|
323
|
+
|
271
324
|
def run(prefill_configs, decode_addrs, host, port):
|
272
325
|
global load_balancer
|
273
326
|
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
|
@@ -279,15 +332,16 @@ if __name__ == "__main__":
|
|
279
332
|
|
280
333
|
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
|
281
334
|
parser.add_argument(
|
282
|
-
"--prefill",
|
335
|
+
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
|
283
336
|
)
|
284
337
|
parser.add_argument(
|
285
|
-
"--
|
286
|
-
help="Comma-separated bootstrap ports for prefill servers",
|
287
|
-
default="8998",
|
338
|
+
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
|
288
339
|
)
|
289
340
|
parser.add_argument(
|
290
|
-
"--
|
341
|
+
"--prefill-bootstrap-ports",
|
342
|
+
type=int,
|
343
|
+
nargs="+",
|
344
|
+
help="Bootstrap ports for prefill servers",
|
291
345
|
)
|
292
346
|
parser.add_argument(
|
293
347
|
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
|
@@ -297,22 +351,19 @@ if __name__ == "__main__":
|
|
297
351
|
)
|
298
352
|
args = parser.parse_args()
|
299
353
|
|
300
|
-
|
301
|
-
bootstrap_ports
|
302
|
-
|
303
|
-
|
304
|
-
bootstrap_ports = bootstrap_ports * len(
|
354
|
+
bootstrap_ports = args.prefill_bootstrap_ports
|
355
|
+
if bootstrap_ports is None:
|
356
|
+
bootstrap_ports = [None] * len(args.prefill)
|
357
|
+
elif len(bootstrap_ports) == 1:
|
358
|
+
bootstrap_ports = bootstrap_ports * len(args.prefill)
|
305
359
|
else:
|
306
|
-
if len(bootstrap_ports) != len(
|
360
|
+
if len(bootstrap_ports) != len(args.prefill):
|
307
361
|
raise ValueError(
|
308
362
|
"Number of prefill URLs must match number of bootstrap ports"
|
309
363
|
)
|
310
|
-
exit(1)
|
311
|
-
|
312
|
-
prefill_configs = []
|
313
|
-
for url, port in zip(prefill_urls, bootstrap_ports):
|
314
|
-
prefill_configs.append(PrefillConfig(url, port))
|
315
364
|
|
316
|
-
|
365
|
+
prefill_configs = [
|
366
|
+
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
|
367
|
+
]
|
317
368
|
|
318
|
-
run(prefill_configs,
|
369
|
+
run(prefill_configs, args.decode, args.host, args.port)
|