sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/disaggregation/decode.py +8 -2
  9. sglang/srt/disaggregation/fake/__init__.py +1 -0
  10. sglang/srt/disaggregation/fake/conn.py +88 -0
  11. sglang/srt/disaggregation/prefill.py +12 -3
  12. sglang/srt/disaggregation/utils.py +16 -2
  13. sglang/srt/entrypoints/engine.py +52 -21
  14. sglang/srt/entrypoints/http_server.py +27 -2
  15. sglang/srt/function_call_parser.py +97 -0
  16. sglang/srt/hf_transformers_utils.py +2 -0
  17. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  18. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  19. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  20. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  21. sglang/srt/layers/attention/utils.py +1 -1
  22. sglang/srt/layers/dp_attention.py +5 -2
  23. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
  41. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  42. sglang/srt/layers/quantization/__init__.py +2 -2
  43. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  44. sglang/srt/layers/quantization/fp8.py +20 -22
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/utils.py +35 -0
  47. sglang/srt/lora/layers.py +35 -9
  48. sglang/srt/lora/lora_manager.py +84 -35
  49. sglang/srt/managers/data_parallel_controller.py +52 -34
  50. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  51. sglang/srt/managers/schedule_batch.py +34 -15
  52. sglang/srt/managers/scheduler.py +273 -67
  53. sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
  54. sglang/srt/managers/tp_worker.py +52 -17
  55. sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
  56. sglang/srt/mem_cache/memory_pool.py +70 -36
  57. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  58. sglang/srt/model_executor/forward_batch_info.py +31 -1
  59. sglang/srt/model_executor/model_runner.py +123 -58
  60. sglang/srt/models/deepseek_nextn.py +1 -257
  61. sglang/srt/models/deepseek_v2.py +78 -18
  62. sglang/srt/models/kimi_vl.py +308 -0
  63. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  64. sglang/srt/models/llama.py +92 -30
  65. sglang/srt/models/llama4.py +2 -1
  66. sglang/srt/models/llama_eagle.py +4 -1
  67. sglang/srt/models/llama_eagle3.py +4 -1
  68. sglang/srt/models/qwen2_moe.py +8 -3
  69. sglang/srt/models/qwen2_vl.py +0 -12
  70. sglang/srt/models/qwen3_moe.py +8 -3
  71. sglang/srt/openai_api/adapter.py +49 -8
  72. sglang/srt/openai_api/protocol.py +13 -1
  73. sglang/srt/reasoning_parser.py +25 -1
  74. sglang/srt/server_args.py +83 -24
  75. sglang/srt/speculative/eagle_worker.py +3 -2
  76. sglang/srt/utils.py +91 -9
  77. sglang/test/runners.py +4 -0
  78. sglang/test/send_one.py +84 -28
  79. sglang/test/test_utils.py +67 -0
  80. sglang/version.py +1 -1
  81. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  82. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
  83. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  84. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  85. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -154,6 +154,8 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        pp_rank=0,
+        pp_size=1,
         nccl_port=port_args.nccl_port,
         server_args=server_args,
     )
sglang/check_env.py CHANGED
@@ -20,7 +20,7 @@ def is_cuda_v2():
 PACKAGE_LIST = [
     "sglang",
     "sgl_kernel",
-    "flashinfer",
+    "flashinfer_python",
     "triton",
     "transformers",
     "torchao",
@@ -36,8 +36,8 @@ PACKAGE_LIST = [
     "packaging",
     "psutil",
     "pydantic",
-    "multipart",
-    "zmq",
+    "python-multipart",
+    "pyzmq",
     "torchao",
     "uvicorn",
     "uvloop",
sglang/srt/configs/__init__.py CHANGED
@@ -3,6 +3,8 @@ from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
+from sglang.srt.configs.kimi_vl import KimiVLConfig
+from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig

 __all__ = [
     "ExaoneConfig",
@@ -10,4 +12,6 @@ __all__ = [
     "DbrxConfig",
     "DeepseekVL2Config",
     "MultiModalityConfig",
+    "KimiVLConfig",
+    "MoonViTConfig",
 ]
sglang/srt/configs/kimi_vl.py ADDED
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
+from typing import Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
+from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+
+
+class KimiVLConfig(PretrainedConfig):
+    model_type = "kimi_vl"
+
+    def __init__(
+        self,
+        vision_config: Optional[Union[dict, MoonViTConfig]] = None,
+        text_config: Optional[Union[dict, DeepseekV2Config]] = None,
+        ignore_index: int = -100,
+        media_placeholder_token_id: int = 163605,
+        pad_token_id: int = 0,
+        **kwargs
+    ):
+        if vision_config is None:
+            vision_config = MoonViTConfig()
+        elif isinstance(vision_config, dict):
+            vision_config = MoonViTConfig(**vision_config)
+        self.vision_config = vision_config
+
+        if text_config is None:
+            text_config = DeepseekV2Config()
+        elif isinstance(text_config, dict):
+            text_config = DeepseekV2Config(**text_config)
+        self.text_config = text_config
+
+        self.ignore_index = ignore_index
+        self.media_placeholder_token_id = media_placeholder_token_id
+
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
sglang/srt/configs/kimi_vl_moonvit.py ADDED
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
+from transformers.configuration_utils import PretrainedConfig
+
+
+class MoonViTConfig(PretrainedConfig):
+    model_type = "moonvit"
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        init_pos_emb_height: int = 64,
+        init_pos_emb_width: int = 64,
+        num_attention_heads: int = 16,
+        num_hidden_layers: int = 27,
+        hidden_size: int = 1152,
+        intermediate_size: int = 4304,
+        merge_kernel_size: tuple[int, int] = (2, 2),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.patch_size = patch_size
+        # Positional embedding config
+        self.init_pos_emb_height = init_pos_emb_height
+        self.init_pos_emb_width = init_pos_emb_width
+        # Transformer config
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        # Patch merger config
+        self.merge_kernel_size = merge_kernel_size
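
The two new config classes nest a MoonViT vision-tower config inside the Kimi-VL wrapper. A minimal sketch of how they compose, assuming DeepseekV2Config also default-constructs (values below are illustrative):

    from sglang.srt.configs.kimi_vl import KimiVLConfig
    from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig

    # Plain dicts are coerced into typed sub-configs by KimiVLConfig.__init__;
    # omitted sub-configs fall back to MoonViTConfig()/DeepseekV2Config().
    cfg = KimiVLConfig(vision_config={"patch_size": 14, "hidden_size": 1152})
    assert isinstance(cfg.vision_config, MoonViTConfig)
    assert cfg.media_placeholder_token_id == 163605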
sglang/srt/configs/model_config.py CHANGED
@@ -47,6 +47,7 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
+        is_draft_model: bool = False,
     ) -> None:

         self.model_path = model_path
@@ -85,6 +86,12 @@ class ModelConfig:
         else:
             enable_multimodal = True

+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -169,6 +176,13 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_text_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
+        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
+            self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_text_config.v_head_dim
+            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
         else:
             self.attention_arch = AttentionArch.MHA

@@ -523,6 +537,7 @@ multimodal_model_archs = [
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
     "CLIPModel",
+    "KimiVLForConditionalGeneration",
 ]

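The new is_draft_model flag lets a DeepSeek-V3 checkpoint double as its NextN (MTP) draft head: the reported architecture is rewritten before any downstream model-type checks run. A hedged sketch of the observable effect (constructing ModelConfig reads the HF config for the path, and the positional model_path argument is assumed here, so this is illustrative only):

    from sglang.srt.configs.model_config import ModelConfig

    cfg = ModelConfig("deepseek-ai/DeepSeek-V3", is_draft_model=True)
    assert cfg.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN"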
sglang/srt/conversation.py CHANGED
@@ -17,7 +17,7 @@
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
 from enum import IntEnum, auto
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 from sglang.srt.openai_api.protocol import ChatCompletionRequest

@@ -407,6 +407,7 @@ class Conversation:

 # A global registry for all conversation templates
 chat_templates: Dict[str, Conversation] = {}
+matching_function_registry: List[Callable] = []


 def register_conv_template(template: Conversation, override: bool = False):
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
     chat_templates[template.name] = template


+def register_conv_template_matching_function(func):
+    matching_function_registry.append(func)
+
+
+def get_conv_template_by_model_path(model_path):
+    for matching_func in matching_function_registry:
+        conv_name = matching_func(model_path)
+        if conv_name is not None:
+            return conv_name
+    return None
+
+
 def chat_template_exists(template_name: str) -> bool:
     return template_name in chat_templates

@@ -792,3 +805,111 @@ register_conv_template(
         audio_token="(<audio>./</audio>)",
     )
 )
+
+# Reference: https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/chat_template.jinja
+register_conv_template(
+    Conversation(
+        name="kimi-vl",
+        system_message="You are a helpful assistant",
+        system_template="<|im_system|>system<|im_middle|>{system_message}",
+        roles=(
+            "<|im_user|>user<|im_middle|>",
+            "<|im_assistant|>assistant<|im_middle|>",
+        ),
+        messages=[],
+        sep="<|im_end|>",
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        stop_str="<|im_end|>",
+        image_token="<|media_start|>image<|media_content|><|media_pad|><|media_end|>",
+    )
+)
+
+
+@register_conv_template_matching_function
+def match_llama_3_vision(model_path: str):
+    if (
+        "llama" in model_path.lower()
+        and "3.2" in model_path.lower()
+        and "vision" in model_path.lower()
+    ):
+        return "llama_3_vision"
+
+
+@register_conv_template_matching_function
+def match_deepseek_janus_pro(model_path: str):
+    if "janus" in model_path.lower():
+        return "janus-pro"
+
+
+@register_conv_template_matching_function
+def match_vicuna(model_path: str):
+    if "vicuna" in model_path.lower():
+        return "vicuna_v1.1"
+    if "llava-v1.5" in model_path.lower():
+        return "vicuna_v1.1"
+    if "llava-next-video-7b" in model_path.lower():
+        return "vicuna_v1.1"
+
+
+@register_conv_template_matching_function
+def match_llama2_chat(model_path: str):
+    model_path = model_path.lower()
+    if "llama-2" in model_path and "chat" in model_path:
+        return "llama-2"
+    if (
+        "mistral" in model_path or "mixtral" in model_path
+    ) and "instruct" in model_path:
+        return "llama-2"
+    if "codellama" in model_path and "instruct" in model_path:
+        return "llama-2"
+
+
+@register_conv_template_matching_function
+def match_deepseek_vl(model_path: str):
+    model_path = model_path.lower()
+    if "deepseek" in model_path and "vl2" in model_path:
+        return "deepseek-vl2"
+
+
+@register_conv_template_matching_function
+def match_chat_ml(model_path: str):
+    model_path = model_path.lower()
+    # Now the suffix for qwen2 chat model is "instruct"
+    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
+        return "gme-qwen2-vl"
+    if "qwen" in model_path and "vl" in model_path:
+        return "qwen2-vl"
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+        or "llava-onevision-qwen2" in model_path
+    ):
+        return "chatml-llava"
+
+
+@register_conv_template_matching_function
+def match_gemma_it(model_path: str):
+    model_path = model_path.lower()
+    if "gemma" in model_path and "it" in model_path:
+        return "gemma-it"
+    if "gemma-3" in model_path and "1b" not in model_path:
+        # gemma-3-1b-it is a completion model
+        return "gemma-it"
+
+
+@register_conv_template_matching_function
+def match_openbmb_minicpm(model_path: str):
+    model_path = model_path.lower()
+    if "minicpm-v" in model_path:
+        return "minicpmv"
+    elif "minicpm-o" in model_path:
+        return "minicpmo"
+
+
+@register_conv_template_matching_function
+def match_moonshot_kimivl(model_path: str):
+    model_path = model_path.lower()
+    if "kimi" in model_path and "vl" in model_path:
+        return "kimi-vl"
sglang/srt/disaggregation/decode.py CHANGED
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    FakeBootstrapHost,
     KVClassType,
     ReqToMetadataIdxAllocator,
     TransferBackend,
@@ -133,8 +134,13 @@ class DecodePreallocQueue:

     def add(self, req: Req) -> None:
         """Add a request to the pending queue."""
-
-        kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
+        if req.bootstrap_host == FakeBootstrapHost:
+            # Fake transfer for warmup reqs
+            kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
+        else:
+            kv_receiver_class = get_kv_class(
+                self.transfer_backend, KVClassType.RECEIVER
+            )
         kv_receiver = kv_receiver_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
sglang/srt/disaggregation/fake/__init__.py ADDED
@@ -0,0 +1 @@
+from .conn import FakeKVReceiver, FakeKVSender
sglang/srt/disaggregation/fake/conn.py ADDED
@@ -0,0 +1,88 @@
+import logging
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import numpy.typing as npt
+
+from sglang.srt.disaggregation.base.conn import (
+    BaseKVManager,
+    BaseKVReceiver,
+    BaseKVSender,
+    KVArgs,
+    KVPoll,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# For warmup reqs, we don't do KV transfer; we use the fake sender and receiver
+class FakeKVSender(BaseKVSender):
+    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
+        self.has_sent = False
+
+    def poll(self) -> KVPoll:
+        if self.has_sent is False:
+            # Assume handshake completed instantly
+            return KVPoll.WaitingForInput
+        else:
+            # Assume transfer completed instantly
+            logger.info("FakeKVSender poll success")
+            return KVPoll.Success
+
+    def init(
+        self,
+        kv_indices: list[int],
+        aux_index: Optional[int] = None,
+        dest_ranks: Optional[list[int]] = None,
+    ):
+        logger.info(
+            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
+        )
+
+    def send(
+        self,
+        kv_indices: npt.NDArray[np.int64],
+        index_slice: slice,
+        is_last: bool,
+    ):
+        logger.info(
+            f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
+        )
+        if is_last:
+            self.has_sent = True
+            logger.info("FakeKVSender send success")
+        else:
+            self.has_sent = False
+            logger.info("FakeKVSender send fake transferring")
+
+    def failure_exception(self):
+        raise Exception("Fake KVSender Exception")
+
+
+class FakeKVReceiver(BaseKVReceiver):
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: Optional[int] = None,
+    ):
+        self.has_init = False
+
+    def poll(self) -> KVPoll:
+        if self.has_init is False:
+            # Assume handshake completed instantly
+            return KVPoll.WaitingForInput
+        else:
+            # Assume transfer completed instantly
+            logger.info("FakeKVReceiver poll success")
+            return KVPoll.Success
+
+    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
+        self.has_init = True
+        logger.info(
+            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
+        )
+
+    def failure_exception(self):
+        raise Exception("Fake KVReceiver Exception")
sglang/srt/disaggregation/prefill.py CHANGED
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations

 import logging
+import threading
 from collections import deque
 from typing import TYPE_CHECKING, List, Optional

@@ -28,6 +29,7 @@ import torch
 from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    FakeBootstrapHost,
     KVClassType,
     ReqToMetadataIdxAllocator,
     TransferBackend,
@@ -115,7 +117,11 @@ class PrefillBootstrapQueue:
         return kv_manager

     def add(self, req: Req) -> None:
-        kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
+        if req.bootstrap_host == FakeBootstrapHost:
+            # Fake transfer for warmup reqs
+            kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
+        else:
+            kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
         req.disagg_kv_sender = kv_sender_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
@@ -256,7 +262,10 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_batch.batch_is_full = False

     def process_batch_result_disagg_prefill(
-        self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
+        self: Scheduler,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -280,7 +289,7 @@ class SchedulerDisaggregationPrefillMixin:
         # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()

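The new launch_done parameter threads a threading.Event from the overlap worker into result processing, so the scheduler blocks until the batch launch has finished rather than polling by batch id. A generic sketch of that handoff (not the sglang implementation):

    import threading

    launch_done = threading.Event()

    def overlap_worker():
        # ... launch the forward batch on the worker thread ...
        launch_done.set()

    threading.Thread(target=overlap_worker).start()
    launch_done.wait()  # resolve_last_batch_result(launch_done) blocks like this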
sglang/srt/disaggregation/utils.py CHANGED
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
     DECODE = "decode"


+FakeBootstrapHost = "2.2.2.2"
+
+
 def poll_and_all_reduce(pollers, gloo_group):
     polls = [int(poller.poll()) for poller in pollers]
     tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
@@ -59,6 +62,8 @@ class KVClassType(Enum):


 def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
+    from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
+
     if transfer_backend == TransferBackend.MOONCAKE:
         from sglang.srt.disaggregation.mooncake import (
             MooncakeKVBootstrapServer,
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         class_mapping = {
             KVClassType.MANAGER: MooncakeKVManager,
             KVClassType.SENDER: MooncakeKVSender,
-            KVClassType.RECEIVER: MooncakeKVReceiver,
+            KVClassType.RECEIVER: (MooncakeKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         class_mapping = {
             KVClassType.MANAGER: NixlKVManager,
             KVClassType.SENDER: NixlKVSender,
-            KVClassType.RECEIVER: NixlKVReceiver,
+            KVClassType.RECEIVER: (NixlKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
         }
         return class_mapping.get(class_type)
+    if transfer_backend == TransferBackend.FAKE:
+        from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
+
+        class_mapping = {
+            KVClassType.SENDER: FakeKVSender,
+            KVClassType.RECEIVER: (FakeKVReceiver),
+        }
+        return class_mapping.get(class_type)
+
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")

sglang/srt/entrypoints/engine.py CHANGED
@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
+from sglang.srt.openai_api.adapter import (
+    guess_chat_template_name_from_model_path,
+    load_chat_template_for_openai_api,
+)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -66,6 +69,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     get_zmq_socket,
+    is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
@@ -78,6 +82,8 @@ from sglang.version import __version__
 logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+_is_cuda = is_cuda()
+

 class Engine(EngineBase):
     """
@@ -120,7 +126,6 @@ class Engine(EngineBase):
             server_args=server_args,
             port_args=port_args,
         )
-
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info
@@ -295,7 +300,6 @@ class Engine(EngineBase):
         internal_states = loop.run_until_complete(
             self.tokenizer_manager.get_internal_state()
         )
-
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
@@ -447,11 +451,17 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.3",
+            "0.2.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
+    if _is_cuda:
+        assert_pkg_version(
+            "sgl-kernel",
+            "0.1.1",
+            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
+        )

     def sigchld_handler(signum, frame):
         pid, exitcode = os.waitpid(0, os.WNOHANG)
@@ -508,25 +518,44 @@ def _launch_subprocesses(
         )

         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        None,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
@@ -575,6 +604,8 @@ def _launch_subprocesses(
         load_chat_template_for_openai_api(
             tokenizer_manager, server_args.chat_template, server_args.model_path
         )
+    else:
+        guess_chat_template_name_from_model_path(server_args.model_path)

     if server_args.completion_template:
         load_completion_template_for_openai_api(server_args.completion_template)
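
The rewritten launch loop assigns one GPU per (pp_rank, tp_rank) pair. A worked example of the gpu_id arithmetic for a single node with pp_size=2, tp_size=2, base_gpu_id=0, gpu_id_step=1 (values illustrative):

    base_gpu_id, gpu_id_step = 0, 1
    pp_size_per_node, tp_size_per_node = 2, 2

    for pp_rank in range(pp_size_per_node):
        for tp_rank in range(tp_size_per_node):
            gpu_id = (
                base_gpu_id
                + (pp_rank % pp_size_per_node) * tp_size_per_node
                + (tp_rank % tp_size_per_node) * gpu_id_step
            )
            print(f"pp={pp_rank} tp={tp_rank} -> GPU {gpu_id}")
    # pp=0 tp=0 -> GPU 0; pp=0 tp=1 -> GPU 1; pp=1 tp=0 -> GPU 2; pp=1 tp=1 -> GPU 3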