sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +52 -21
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +34 -15
- sglang/srt/managers/scheduler.py +273 -67
- sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
- sglang/srt/managers/tp_worker.py +52 -17
- sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +123 -58
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +49 -8
- sglang/srt/openai_api/protocol.py +13 -1
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +83 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +91 -9
- sglang/test/runners.py +4 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +67 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
sglang/check_env.py
CHANGED
@@ -20,7 +20,7 @@ def is_cuda_v2():
 PACKAGE_LIST = [
     "sglang",
     "sgl_kernel",
-    "
+    "flashinfer_python",
     "triton",
     "transformers",
     "torchao",
@@ -36,8 +36,8 @@ PACKAGE_LIST = [
     "packaging",
     "psutil",
     "pydantic",
-    "multipart",
-    "
+    "python-multipart",
+    "pyzmq",
     "torchao",
     "uvicorn",
     "uvloop",
sglang/srt/configs/__init__.py
CHANGED
@@ -3,6 +3,8 @@ from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
+from sglang.srt.configs.kimi_vl import KimiVLConfig
+from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
 
 __all__ = [
     "ExaoneConfig",
@@ -10,4 +12,6 @@ __all__ = [
     "DbrxConfig",
     "DeepseekVL2Config",
     "MultiModalityConfig",
+    "KimiVLConfig",
+    "MoonViTConfig",
 ]
sglang/srt/configs/kimi_vl.py
ADDED
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
+from typing import Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
+from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+
+
+class KimiVLConfig(PretrainedConfig):
+    model_type = "kimi_vl"
+
+    def __init__(
+        self,
+        vision_config: Optional[Union[dict, MoonViTConfig]] = None,
+        text_config: Optional[Union[dict, DeepseekV2Config]] = None,
+        ignore_index: int = -100,
+        media_placeholder_token_id: int = 163605,
+        pad_token_id: int = 0,
+        **kwargs
+    ):
+        if vision_config is None:
+            vision_config = MoonViTConfig()
+        elif isinstance(vision_config, dict):
+            vision_config = MoonViTConfig(**vision_config)
+        self.vision_config = vision_config
+
+        if text_config is None:
+            text_config = DeepseekV2Config()
+        elif isinstance(text_config, dict):
+            text_config = DeepseekV2Config(**text_config)
+        self.text_config = text_config
+
+        self.ignore_index = ignore_index
+        self.media_placeholder_token_id = media_placeholder_token_id
+
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
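A brief orientation note on the new file above: KimiVLConfig simply nests a MoonViTConfig for the vision tower and a DeepseekV2Config for the language model, and promotes plain dicts to those classes. The snippet below is a minimal sketch of how it could be instantiated; the field values are illustrative, not taken from a released checkpoint.

from sglang.srt.configs import KimiVLConfig, MoonViTConfig

# Plain dicts are promoted to the corresponding config classes.
cfg = KimiVLConfig(
    vision_config={"patch_size": 14, "hidden_size": 1152},
    text_config={},  # falls back to DeepseekV2Config defaults
)
assert isinstance(cfg.vision_config, MoonViTConfig)
print(cfg.media_placeholder_token_id)  # 163605 by default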
sglang/srt/configs/kimi_vl_moonvit.py
ADDED
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
+from transformers.configuration_utils import PretrainedConfig
+
+
+class MoonViTConfig(PretrainedConfig):
+    model_type = "moonvit"
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        init_pos_emb_height: int = 64,
+        init_pos_emb_width: int = 64,
+        num_attention_heads: int = 16,
+        num_hidden_layers: int = 27,
+        hidden_size: int = 1152,
+        intermediate_size: int = 4304,
+        merge_kernel_size: tuple[int, int] = (2, 2),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.patch_size = patch_size
+        # Positional embedding config
+        self.init_pos_emb_height = init_pos_emb_height
+        self.init_pos_emb_width = init_pos_emb_width
+        # Transformer config
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        # Patch merger config
+        self.merge_kernel_size = merge_kernel_size
sglang/srt/configs/model_config.py
CHANGED
@@ -47,6 +47,7 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
+        is_draft_model: bool = False,
     ) -> None:
 
         self.model_path = model_path
@@ -85,6 +86,12 @@ class ModelConfig:
             else:
                 enable_multimodal = True
 
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -169,6 +176,13 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_text_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
+        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
+            self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_text_config.v_head_dim
+            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
         else:
             self.attention_arch = AttentionArch.MHA
 
@@ -523,6 +537,7 @@ multimodal_model_archs = [
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "CLIPModel",
+    "KimiVLForConditionalGeneration",
 ]
 
 
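As a side note on the first hunk above: passing is_draft_model=True for a DeepseekV3ForCausalLM checkpoint rewrites architectures[0] to "DeepseekV3ForCausalLMNextN" before the model-type checks run, so the NextN draft module is loaded in its place. The sketch below is an illustrative, hedged usage example; the model path is only an example and resolving it needs network access to the Hub.

from sglang.srt.configs.model_config import ModelConfig

# Hypothetical draft-model config; any DeepSeek-V3 checkpoint would behave the same way.
draft_cfg = ModelConfig(
    "deepseek-ai/DeepSeek-V3",
    is_draft_model=True,
)
print(draft_cfg.hf_config.architectures[0])  # "DeepseekV3ForCausalLMNextN"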
sglang/srt/conversation.py
CHANGED
@@ -17,7 +17,7 @@
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
 from enum import IntEnum, auto
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 from sglang.srt.openai_api.protocol import ChatCompletionRequest
 
@@ -407,6 +407,7 @@ class Conversation:
 
 # A global registry for all conversation templates
 chat_templates: Dict[str, Conversation] = {}
+matching_function_registry: List[Callable] = []
 
 
 def register_conv_template(template: Conversation, override: bool = False):
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
     chat_templates[template.name] = template
 
 
+def register_conv_template_matching_function(func):
+    matching_function_registry.append(func)
+
+
+def get_conv_template_by_model_path(model_path):
+    for matching_func in matching_function_registry:
+        conv_name = matching_func(model_path)
+        if conv_name is not None:
+            return conv_name
+    return None
+
+
 def chat_template_exists(template_name: str) -> bool:
     return template_name in chat_templates
 
@@ -792,3 +805,111 @@ register_conv_template(
         audio_token="(<audio>./</audio>)",
     )
 )
+
+# Reference: https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/chat_template.jinja
+register_conv_template(
+    Conversation(
+        name="kimi-vl",
+        system_message="You are a helpful assistant",
+        system_template="<|im_system|>system<|im_middle|>{system_message}",
+        roles=(
+            "<|im_user|>user<|im_middle|>",
+            "<|im_assistant|>assistant<|im_middle|>",
+        ),
+        messages=[],
+        sep="<|im_end|>",
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        stop_str="<|im_end|>",
+        image_token="<|media_start|>image<|media_content|><|media_pad|><|media_end|>",
+    )
+)
+
+
+@register_conv_template_matching_function
+def match_deepseek_janus_pro(model_path: str):
+    if (
+        "llama" in model_path.lower()
+        and "3.2" in model_path.lower()
+        and "vision" in model_path.lower()
+    ):
+        return "llama_3_vision"
+
+
+@register_conv_template_matching_function
+def match_deepseek_janus_pro(model_path: str):
+    if "janus" in model_path.lower():
+        return "janus-pro"
+
+
+@register_conv_template_matching_function
+def match_vicuna(model_path: str):
+    if "vicuna" in model_path.lower():
+        return "vicuna_v1.1"
+    if "llava-v1.5" in model_path.lower():
+        return "vicuna_v1.1"
+    if "llava-next-video-7b" in model_path.lower():
+        return "vicuna_v1.1"
+
+
+@register_conv_template_matching_function
+def match_llama2_chat(model_path: str):
+    model_path = model_path.lower()
+    if "llama-2" in model_path and "chat" in model_path:
+        return "llama-2"
+    if (
+        "mistral" in model_path or "mixtral" in model_path
+    ) and "instruct" in model_path:
+        return "llama-2"
+    if "codellama" in model_path and "instruct" in model_path:
+        return "llama-2"
+
+
+@register_conv_template_matching_function
+def match_deepseek_vl(model_path: str):
+    model_path = model_path.lower()
+    if "deepseek" in model_path and "vl2" in model_path:
+        return "deepseek-vl2"
+
+
+@register_conv_template_matching_function
+def match_chat_ml(model_path: str):
+    # import pdb;pdb.set_trace()
+    model_path = model_path.lower()
+    # Now the suffix for qwen2 chat model is "instruct"
+    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
+        return "gme-qwen2-vl"
+    if "qwen" in model_path and "vl" in model_path:
+        return "qwen2-vl"
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+        or "llava-onevision-qwen2" in model_path
+    ):
+        return "chatml-llava"
+
+
+@register_conv_template_matching_function
+def match_gemma_it(model_path: str):
+    model_path = model_path.lower()
+    if "gemma" in model_path and "it" in model_path:
+        return "gemma-it"
+    if "gemma-3" in model_path and "1b" not in model_path:
+        # gemma-3-1b-it is completion model
+        return "gemma-it"
+
+
+@register_conv_template_matching_function
+def match_openbmb_minicpm(model_path: str):
+    model_path = model_path.lower()
+    if "minicpm-v" in model_path:
+        return "minicpmv"
+    elif "minicpm-o" in model_path:
+        return "minicpmo"
+
+
+@register_conv_template_matching_function
+def match_moonshot_kimivl(model_path: str):
+    model_path = model_path.lower()
+    if "kimi" in model_path and "vl" in model_path:
+        return "kimi-vl"
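The hunk above introduces a small registry for chat-template guessing: functions decorated with register_conv_template_matching_function are appended to matching_function_registry, and get_conv_template_by_model_path walks them in registration order, returning the first non-None template name. Below is a minimal usage sketch; the model paths are only examples, and the expected return values follow from the matchers shown above.

from sglang.srt.conversation import get_conv_template_by_model_path

print(get_conv_template_by_model_path("moonshotai/Kimi-VL-A3B-Instruct"))  # "kimi-vl"
print(get_conv_template_by_model_path("lmsys/vicuna-7b-v1.5"))             # "vicuna_v1.1"
print(get_conv_template_by_model_path("meta-llama/Llama-3-8B"))            # None, no matcher fires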
sglang/srt/disaggregation/decode.py
CHANGED
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    FakeBootstrapHost,
     KVClassType,
     ReqToMetadataIdxAllocator,
     TransferBackend,
@@ -133,8 +134,13 @@ class DecodePreallocQueue:
 
     def add(self, req: Req) -> None:
         """Add a request to the pending queue."""
-
-
+        if req.bootstrap_host == FakeBootstrapHost:
+            # Fake transfer for warmup reqs
+            kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
+        else:
+            kv_receiver_class = get_kv_class(
+                self.transfer_backend, KVClassType.RECEIVER
+            )
         kv_receiver = kv_receiver_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
sglang/srt/disaggregation/fake/__init__.py
ADDED
@@ -0,0 +1 @@
+from .conn import FakeKVReceiver, FakeKVSender
sglang/srt/disaggregation/fake/conn.py
ADDED
@@ -0,0 +1,88 @@
+import logging
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import numpy.typing as npt
+
+from sglang.srt.disaggregation.base.conn import (
+    BaseKVManager,
+    BaseKVReceiver,
+    BaseKVSender,
+    KVArgs,
+    KVPoll,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
+class FakeKVSender(BaseKVSender):
+    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
+        self.has_sent = False
+
+    def poll(self) -> KVPoll:
+        if self.has_sent is False:
+            # Assume handshake completed instantly
+            return KVPoll.WaitingForInput
+        else:
+            # Assume transfer completed instantly
+            logger.info("FakeKVSender poll success")
+            return KVPoll.Success
+
+    def init(
+        self,
+        kv_indices: list[int],
+        aux_index: Optional[int] = None,
+        dest_ranks: Optional[list[int]] = None,
+    ):
+        logger.info(
+            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
+        )
+        pass
+
+    def send(
+        self,
+        kv_indices: npt.NDArray[np.int64],
+        index_slice: slice,
+        is_last: bool,
+    ):
+        logger.info(
+            f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
+        )
+        if is_last:
+            self.has_sent = True
+            logger.info(f"FakeKVSender send success")
+        else:
+            self.has_sent = False
+            logger.info(f"FakeKVSender send fake transfering")
+
+    def failure_exception(self):
+        raise Exception("Fake KVSender Exception")
+
+
+class FakeKVReceiver(BaseKVReceiver):
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: Optional[int] = None,
+    ):
+        self.has_init = False
+
+    def poll(self) -> KVPoll:
+        if self.has_init is False:
+            # Assume handshake completed instantly
+            return KVPoll.WaitingForInput
+        else:
+            # Assume transfer completed instantly
+            logger.info("FakeKVReceiver poll success")
+            return KVPoll.Success
+
+    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
+        self.has_init = True
+        logger.info(
+            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
+        )
+
+    def failure_exception(self):
+        raise Exception("Fake KVReceiver Exception")
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+import threading
 from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
@@ -28,6 +29,7 @@ import torch
 from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    FakeBootstrapHost,
     KVClassType,
     ReqToMetadataIdxAllocator,
     TransferBackend,
@@ -115,7 +117,11 @@ class PrefillBootstrapQueue:
         return kv_manager
 
     def add(self, req: Req) -> None:
-
+        if req.bootstrap_host == FakeBootstrapHost:
+            # Fake transfer for warmup reqs
+            kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
+        else:
+            kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
         req.disagg_kv_sender = kv_sender_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
@@ -256,7 +262,10 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_batch.batch_is_full = False
 
     def process_batch_result_disagg_prefill(
-        self: Scheduler,
+        self: Scheduler,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -280,7 +289,7 @@ class SchedulerDisaggregationPrefillMixin:
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.
+            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
 
sglang/srt/disaggregation/utils.py
CHANGED
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
     DECODE = "decode"
 
 
+FakeBootstrapHost = "2.2.2.2"
+
+
 def poll_and_all_reduce(pollers, gloo_group):
     polls = [int(poller.poll()) for poller in pollers]
     tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
@@ -59,6 +62,8 @@ class KVClassType(Enum):
 
 
 def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
+    from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
+
     if transfer_backend == TransferBackend.MOONCAKE:
         from sglang.srt.disaggregation.mooncake import (
             MooncakeKVBootstrapServer,
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         class_mapping = {
             KVClassType.MANAGER: MooncakeKVManager,
             KVClassType.SENDER: MooncakeKVSender,
-            KVClassType.RECEIVER: MooncakeKVReceiver,
+            KVClassType.RECEIVER: (MooncakeKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         class_mapping = {
             KVClassType.MANAGER: NixlKVManager,
             KVClassType.SENDER: NixlKVSender,
-            KVClassType.RECEIVER: NixlKVReceiver,
+            KVClassType.RECEIVER: (NixlKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
         }
         return class_mapping.get(class_type)
+    if transfer_backend == TransferBackend.FAKE:
+        from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
+
+        class_mapping = {
+            KVClassType.SENDER: FakeKVSender,
+            KVClassType.RECEIVER: (FakeKVReceiver),
+        }
+        return class_mapping.get(class_type)
+
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
 
 
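Taken together, the disaggregation changes above route warmup requests through a no-op KV transfer path: a request whose bootstrap_host equals FakeBootstrapHost ("2.2.2.2") is handed a FakeKVSender/FakeKVReceiver from get_kv_class instead of a real Mooncake or NIXL transfer object. The snippet below is a small standalone sketch of that lifecycle under stated assumptions: TransferBackend.FAKE exists as referenced by the hunks above, and the bootstrap address and room id are placeholders.

from sglang.srt.disaggregation.base.conn import KVPoll
from sglang.srt.disaggregation.utils import KVClassType, TransferBackend, get_kv_class

# The FAKE backend maps SENDER/RECEIVER to the fake classes added in this release.
sender_cls = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
sender = sender_cls(mgr=None, bootstrap_addr="2.2.2.2:8998", bootstrap_room=0)

assert sender.poll() == KVPoll.WaitingForInput  # nothing sent yet
sender.send(kv_indices=None, index_slice=slice(0, 0), is_last=True)
assert sender.poll() == KVPoll.Success          # the "transfer" completes instantly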
sglang/srt/entrypoints/engine.py
CHANGED
@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import
+from sglang.srt.openai_api.adapter import (
+    guess_chat_template_name_from_model_path,
+    load_chat_template_for_openai_api,
+)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -66,6 +69,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     get_zmq_socket,
+    is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
@@ -78,6 +82,8 @@ from sglang.version import __version__
 logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+_is_cuda = is_cuda()
+
 
 class Engine(EngineBase):
     """
@@ -120,7 +126,6 @@ class Engine(EngineBase):
             server_args=server_args,
             port_args=port_args,
         )
-
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info
@@ -295,7 +300,6 @@ class Engine(EngineBase):
         internal_states = loop.run_until_complete(
             self.tokenizer_manager.get_internal_state()
         )
-
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
@@ -447,11 +451,17 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.
+            "0.2.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
+    if _is_cuda:
+        assert_pkg_version(
+            "sgl-kernel",
+            "0.1.1",
+            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
+        )
 
     def sigchld_handler(signum, frame):
         pid, exitcode = os.waitpid(0, os.WNOHANG)
@@ -508,25 +518,44 @@ def _launch_subprocesses(
         )
 
         scheduler_pipe_readers = []
-
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        None,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
@@ -575,6 +604,8 @@ def _launch_subprocesses(
         load_chat_template_for_openai_api(
            tokenizer_manager, server_args.chat_template, server_args.model_path
        )
+    else:
+        guess_chat_template_name_from_model_path(server_args.model_path)
 
     if server_args.completion_template:
         load_completion_template_for_openai_api(server_args.completion_template)