sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -290,6 +290,9 @@ class DictOutput(object):
     def __getitem__(self, item):
         return self.__dict__[item]
 
+    def __contains__(self, key):
+        return key in self.__dict__
+
     def __setitem__(self, key, value):
         self.__dict__[key] = value
 
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
+from typing import Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
+from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+
+
+class KimiVLConfig(PretrainedConfig):
+    model_type = "kimi_vl"
+
+    def __init__(
+        self,
+        vision_config: Optional[Union[dict, MoonViTConfig]] = None,
+        text_config: Optional[Union[dict, DeepseekV2Config]] = None,
+        ignore_index: int = -100,
+        media_placeholder_token_id: int = 163605,
+        pad_token_id: int = 0,
+        **kwargs
+    ):
+        if vision_config is None:
+            vision_config = MoonViTConfig()
+        elif isinstance(vision_config, dict):
+            vision_config = MoonViTConfig(**vision_config)
+        self.vision_config = vision_config
+
+        if text_config is None:
+            text_config = DeepseekV2Config()
+        elif isinstance(text_config, dict):
+            text_config = DeepseekV2Config(**text_config)
+        self.text_config = text_config
+
+        self.ignore_index = ignore_index
+        self.media_placeholder_token_id = media_placeholder_token_id
+
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
+from transformers.configuration_utils import PretrainedConfig
+
+
+class MoonViTConfig(PretrainedConfig):
+    model_type = "moonvit"
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        init_pos_emb_height: int = 64,
+        init_pos_emb_width: int = 64,
+        num_attention_heads: int = 16,
+        num_hidden_layers: int = 27,
+        hidden_size: int = 1152,
+        intermediate_size: int = 4304,
+        merge_kernel_size: tuple[int, int] = (2, 2),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.patch_size = patch_size
+        # Positional embedding config
+        self.init_pos_emb_height = init_pos_emb_height
+        self.init_pos_emb_width = init_pos_emb_width
+        # Transformer config
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        # Patch merger config
+        self.merge_kernel_size = merge_kernel_size
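
A minimal usage sketch (illustrative, not part of the diff) of how the two new config classes above compose; dict inputs are promoted to typed configs inside KimiVLConfig.__init__:

    from sglang.srt.configs.kimi_vl import KimiVLConfig
    from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig

    # vision_config may be a dict or a MoonViTConfig; None falls back to the defaults above.
    cfg = KimiVLConfig(vision_config={"patch_size": 14, "hidden_size": 1152})
    assert isinstance(cfg.vision_config, MoonViTConfig)
    assert cfg.media_placeholder_token_id == 163605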
@@ -24,6 +24,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
+        is_draft_model: bool = False,
     ) -> None:
 
         self.model_path = model_path
@@ -85,6 +87,12 @@ class ModelConfig:
         else:
             enable_multimodal = True
 
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -169,6 +177,13 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_text_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
+        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
+            self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_text_config.v_head_dim
+            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
         else:
             self.attention_arch = AttentionArch.MHA
 
@@ -196,6 +211,21 @@ class ModelConfig:
         self.hf_eos_token_id = self.get_hf_eos_token_id()
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)
 
+    @staticmethod
+    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
+        return ModelConfig(
+            model_path=model_path or server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+            **kwargs,
+        )
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -523,6 +553,8 @@ multimodal_model_archs = [
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
     "CLIPModel",
+    "KimiVLForConditionalGeneration",
+    "InternVLChatModel",
 ]
 
 
@@ -18,6 +18,7 @@ import logging
 from typing import List, Optional, Tuple, Union
 
 import torch
+import xgrammar
 from xgrammar import (
     CompiledGrammar,
     GrammarCompiler,
@@ -58,17 +59,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
 
-        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
-        # class init site to avoid re-initializing CUDA in forked subprocess.
-        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
-
-        self.use_token_bitmask_triton = get_bool_env_var(
-            "SGLANG_TOKEN_BITMASK_TRITON", "false"
-        )
-        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
-            "cuda", None
+        from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
+            apply_token_bitmask_inplace_cpu,
         )
-        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
+
+        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -113,15 +108,12 @@ class XGrammarGrammar(BaseGrammarObject):
         return vocab_mask.to(device, non_blocking=True)
 
     def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if (
-            not self.use_token_bitmask_triton
-            and logits.device.type == "cuda"
-            and self.apply_vocab_mask_cuda
-        ):
-            return self.apply_vocab_mask_cuda(logits, vocab_mask)
-        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
-            return self.apply_vocab_mask_cpu(logits, vocab_mask)
-        apply_token_bitmask_inplace_triton(logits, vocab_mask)
+        if logits.device.type == "cuda":
+            apply_token_bitmask_inplace_triton(logits, vocab_mask)
+        elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            self.apply_vocab_mask_cpu(logits, vocab_mask)
+        else:
+            raise RuntimeError(f"Unsupported device: {logits.device.type}")
 
     def copy(self):
         matcher = GrammarMatcher(
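
For context, the vocab bitmask consumed here packs one bit per token into int32 words (xgrammar convention: a set bit means the token is allowed). A rough pure-PyTorch reference of what the Triton/CPU kernels do, for illustration only:

    import torch

    def apply_token_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
        """Set logits of disallowed tokens to -inf, in place."""
        vocab_size = logits.shape[-1]
        bits = torch.arange(32, device=logits.device, dtype=torch.int32)
        # Bit j of int32 word i covers token i * 32 + j.
        unpacked = (bitmask.unsqueeze(-1) >> bits) & 1
        allowed = unpacked.reshape(bitmask.shape[0], -1)[:, :vocab_size].bool()
        logits.masked_fill_(~allowed, float("-inf"))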
@@ -17,7 +17,7 @@
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
 from enum import IntEnum, auto
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 from sglang.srt.openai_api.protocol import ChatCompletionRequest
 
@@ -48,6 +48,7 @@ class SeparatorStyle(IntEnum):
     DeepSeekVL2 = auto()
     QWEN2_VL_EMBED = auto()
     GEMMA3 = auto()
+    MPT = auto()
 
 
 @dataclasses.dataclass
@@ -327,6 +328,16 @@ class Conversation:
                     ret += role
             return ret
 
+        elif self.sep_style == SeparatorStyle.MPT:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
@@ -407,6 +418,7 @@ class Conversation:
 
 # A global registry for all conversation templates
 chat_templates: Dict[str, Conversation] = {}
+matching_function_registry: List[Callable] = []
 
 
 def register_conv_template(template: Conversation, override: bool = False):
@@ -419,6 +431,18 @@ def register_conv_template(template: Conversation, override: bool = False):
     chat_templates[template.name] = template
 
 
+def register_conv_template_matching_function(func):
+    matching_function_registry.append(func)
+
+
+def get_conv_template_by_model_path(model_path):
+    for matching_func in matching_function_registry:
+        conv_name = matching_func(model_path)
+        if conv_name is not None:
+            return conv_name
+    return None
+
+
 def chat_template_exists(template_name: str) -> bool:
     return template_name in chat_templates
 
@@ -557,8 +581,11 @@ def generate_chat_conv(
                         real_content += "\n"  # for video
                     real_content += content.text
                 elif content.type == "image_url":
-                    # NOTE: Only works for llava
-                    real_content += image_token
+                    # NOTE: works for llava and intervl2_5
+                    if conv.name == "internvl-2-5":
+                        real_content = image_token + real_content
+                    else:
+                        real_content += image_token
                     conv.append_image(content.image_url.url)
                 elif content.type == "audio_url":
                     real_content += audio_token
@@ -690,6 +717,19 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="internvl-2-5",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
+        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+        sep_style=SeparatorStyle.MPT,
+        sep="<|im_end|>\n",
+        stop_str=["<|im_end|>", "<|action_end|>"],
+        image_token="<image>",
+    )
+)
+
 # Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
 register_conv_template(
     Conversation(
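
Illustrative only: the prompt string that the new SeparatorStyle.MPT branch produces for this internvl-2-5 template (one user turn, assistant turn left open):

    sep = "<|im_end|>\n"
    prompt = (
        "<|im_start|>system\n" + "<system message>" + sep
        + "<|im_start|>user\n" + "<image>\nDescribe this image." + sep
        + "<|im_start|>assistant\n"
    )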
@@ -792,3 +832,111 @@ register_conv_template(
         audio_token="(<audio>./</audio>)",
     )
 )
+
+# Reference: https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/chat_template.jinja
+register_conv_template(
+    Conversation(
+        name="kimi-vl",
+        system_message="You are a helpful assistant",
+        system_template="<|im_system|>system<|im_middle|>{system_message}",
+        roles=(
+            "<|im_user|>user<|im_middle|>",
+            "<|im_assistant|>assistant<|im_middle|>",
+        ),
+        messages=[],
+        sep="<|im_end|>",
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        stop_str="<|im_end|>",
+        image_token="<|media_start|>image<|media_content|><|media_pad|><|media_end|>",
+    )
+)
+
+
+@register_conv_template_matching_function
+def match_llama_3_vision(model_path: str):
+    if (
+        "llama" in model_path.lower()
+        and "3.2" in model_path.lower()
+        and "vision" in model_path.lower()
+    ):
+        return "llama_3_vision"
+
+
+@register_conv_template_matching_function
+def match_deepseek_janus_pro(model_path: str):
+    if "janus" in model_path.lower():
+        return "janus-pro"
+
+
+@register_conv_template_matching_function
+def match_vicuna(model_path: str):
+    if "vicuna" in model_path.lower():
+        return "vicuna_v1.1"
+    if "llava-v1.5" in model_path.lower():
+        return "vicuna_v1.1"
+    if "llava-next-video-7b" in model_path.lower():
+        return "vicuna_v1.1"
+
+
+@register_conv_template_matching_function
+def match_llama2_chat(model_path: str):
+    model_path = model_path.lower()
+    if "llama-2" in model_path and "chat" in model_path:
+        return "llama-2"
+    if (
+        "mistral" in model_path or "mixtral" in model_path
+    ) and "instruct" in model_path:
+        return "llama-2"
+    if "codellama" in model_path and "instruct" in model_path:
+        return "llama-2"
+
+
+@register_conv_template_matching_function
+def match_deepseek_vl(model_path: str):
+    model_path = model_path.lower()
+    if "deepseek" in model_path and "vl2" in model_path:
+        return "deepseek-vl2"
+
+
+@register_conv_template_matching_function
+def match_chat_ml(model_path: str):
+    # import pdb;pdb.set_trace()
+    model_path = model_path.lower()
+    # Now the suffix for qwen2 chat model is "instruct"
+    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
+        return "gme-qwen2-vl"
+    if "qwen" in model_path and "vl" in model_path:
+        return "qwen2-vl"
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+        or "llava-onevision-qwen2" in model_path
+    ):
+        return "chatml-llava"
+
+
+@register_conv_template_matching_function
+def match_gemma_it(model_path: str):
+    model_path = model_path.lower()
+    if "gemma" in model_path and "it" in model_path:
+        return "gemma-it"
+    if "gemma-3" in model_path and "1b" not in model_path:
+        # gemma-3-1b-it is completion model
+        return "gemma-it"
+
+
+@register_conv_template_matching_function
+def match_openbmb_minicpm(model_path: str):
+    model_path = model_path.lower()
+    if "minicpm-v" in model_path:
+        return "minicpmv"
+    elif "minicpm-o" in model_path:
+        return "minicpmo"
+
+
+@register_conv_template_matching_function
+def match_moonshot_kimivl(model_path: str):
+    model_path = model_path.lower()
+    if "kimi" in model_path and "vl" in model_path:
+        return "kimi-vl"
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
+import os
 from collections import deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
@@ -97,7 +98,9 @@ class DecodePreallocQueue:
         self.tp_size = tp_size
         self.bootstrap_port = bootstrap_port
 
-        self.num_reserved_decode_tokens = 512
+        self.num_reserved_decode_tokens = int(
+            os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
+        )
 
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
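
The reserved decode-token budget is now tunable through the environment; a minimal example (the default stays 512 when the variable is unset), set before the decode server constructs DecodePreallocQueue:

    import os

    os.environ["SGLANG_NUM_RESERVED_DECODE_TOKENS"] = "1024"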
@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing.
 """
 
 import asyncio
+import dataclasses
+import logging
 import random
 import urllib
 from itertools import chain
-from typing import List
+from typing import List, Optional
 
 import aiohttp
 import orjson
@@ -14,11 +16,32 @@ import uvicorn
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
+from sglang.srt.disaggregation.utils import PDRegistryRequest
 
+
+def setup_logger():
+    logger = logging.getLogger("pdlb")
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter(
+        "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    return logger
+
+
+logger = setup_logger()
+
+
+@dataclasses.dataclass
 class PrefillConfig:
-    def __init__(self, url: str, bootstrap_port: int):
-        self.url = url
-        self.bootstrap_port = bootstrap_port
+    url: str
+    bootstrap_port: Optional[int] = None
 
 
 class MiniLoadBalancer:
@@ -28,6 +51,10 @@ class MiniLoadBalancer:
         self.decode_servers = decode_servers
 
     def select_pair(self):
+        # TODO: return some message instead of panic
+        assert len(self.prefill_configs) > 0, "No prefill servers available"
+        assert len(self.decode_servers) > 0, "No decode servers available"
+
         prefill_config = random.choice(self.prefill_configs)
        decode_server = random.choice(self.decode_servers)
        return prefill_config.url, prefill_config.bootstrap_port, decode_server
@@ -47,7 +74,7 @@ class MiniLoadBalancer:
             session.post(f"{decode_server}/{endpoint}", json=modified_request),
         ]
         # Wait for both responses to complete. Prefill should end first.
-        prefill_response, decode_response = await asyncio.gather(*tasks)
+        _, decode_response = await asyncio.gather(*tasks)
 
         return ORJSONResponse(
             content=await decode_response.json(),
@@ -268,6 +295,32 @@ async def get_models():
         raise HTTPException(status_code=500, detail=str(e))
 
 
+@app.post("/register")
+async def register(obj: PDRegistryRequest):
+    if obj.mode == "prefill":
+        load_balancer.prefill_configs.append(
+            PrefillConfig(obj.registry_url, obj.bootstrap_port)
+        )
+        logger.info(
+            f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
+        )
+    elif obj.mode == "decode":
+        load_balancer.decode_servers.append(obj.registry_url)
+        logger.info(f"Registered decode server: {obj.registry_url}")
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid mode. Must be either PREFILL or DECODE.",
+        )
+
+    logger.info(
+        f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
+        f"#Decode servers: {len(load_balancer.decode_servers)}"
+    )
+
+    return Response(status_code=200)
+
+
 def run(prefill_configs, decode_addrs, host, port):
     global load_balancer
     load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
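
A sketch of how a server could announce itself to the mini load balancer through the new /register route; the field names follow PDRegistryRequest as used above, the URLs are placeholders, and bootstrap_port is assumed optional for decode registrations:

    import requests

    lb = "http://127.0.0.1:8000"
    # Prefill server registers together with its bootstrap port.
    requests.post(f"{lb}/register", json={
        "mode": "prefill",
        "registry_url": "http://10.0.0.1:30000",
        "bootstrap_port": 8998,
    })
    # Decode server registers with just its URL.
    requests.post(f"{lb}/register", json={"mode": "decode", "registry_url": "http://10.0.0.2:30000"})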
@@ -279,15 +332,16 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
     parser.add_argument(
-        "--prefill", required=True, help="Comma-separated URLs for prefill servers"
+        "--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
     )
     parser.add_argument(
-        "--prefill-bootstrap-ports",
-        help="Comma-separated bootstrap ports for prefill servers",
-        default="8998",
+        "--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
     )
     parser.add_argument(
-        "--decode", required=True, help="Comma-separated URLs for decode servers"
+        "--prefill-bootstrap-ports",
+        type=int,
+        nargs="+",
+        help="Bootstrap ports for prefill servers",
     )
     parser.add_argument(
         "--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
@@ -297,22 +351,19 @@
     )
     args = parser.parse_args()
 
-    prefill_urls = args.prefill.split(",")
-    bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
-
-    if len(bootstrap_ports) == 1:
-        bootstrap_ports = bootstrap_ports * len(prefill_urls)
+    bootstrap_ports = args.prefill_bootstrap_ports
+    if bootstrap_ports is None:
+        bootstrap_ports = [None] * len(args.prefill)
+    elif len(bootstrap_ports) == 1:
+        bootstrap_ports = bootstrap_ports * len(args.prefill)
     else:
-        if len(bootstrap_ports) != len(prefill_urls):
+        if len(bootstrap_ports) != len(args.prefill):
             raise ValueError(
                 "Number of prefill URLs must match number of bootstrap ports"
             )
-            exit(1)
-
-    prefill_configs = []
-    for url, port in zip(prefill_urls, bootstrap_ports):
-        prefill_configs.append(PrefillConfig(url, port))
 
-    decode_addrs = args.decode.split(",")
+    prefill_configs = [
+        PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
+    ]
 
-    run(prefill_configs, decode_addrs, args.host, args.port)
+    run(prefill_configs, args.decode, args.host, args.port)
@@ -37,25 +37,16 @@ logger = logging.getLogger(__name__)
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    src_groups = []
-    dst_groups = []
-    current_src = [src_indices[0]]
-    current_dst = [dst_indices[0]]
-
-    for i in range(1, len(src_indices)):
-        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
-        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
-        if src_contiguous and dst_contiguous:
-            current_src.append(src_indices[i])
-            current_dst.append(dst_indices[i])
-        else:
-            src_groups.append(current_src)
-            dst_groups.append(current_dst)
-            current_src = [src_indices[i]]
-            current_dst = [dst_indices[i]]
+    """Vectorised NumPy implementation."""
+    if src_indices.size == 0:
+        return [], []
+
+    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
+    src_groups = np.split(src_indices, brk)
+    dst_groups = np.split(dst_indices, brk)
 
-    src_groups.append(current_src)
-    dst_groups.append(current_dst)
+    src_groups = [g.tolist() for g in src_groups]
+    dst_groups = [g.tolist() for g in dst_groups]
 
     return src_groups, dst_groups
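
A quick worked example (not from the diff) of the vectorised grouping: a break point is placed wherever either index sequence stops being contiguous:

    import numpy as np

    src = np.array([1, 2, 3, 10, 11], dtype=np.int64)
    dst = np.array([5, 6, 7, 20, 21], dtype=np.int64)

    brk = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1  # -> array([3])
    print([g.tolist() for g in np.split(src, brk)])  # [[1, 2, 3], [10, 11]]
    print([g.tolist() for g in np.split(dst, brk)])  # [[5, 6, 7], [20, 21]]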