sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -61,6 +61,7 @@ from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.distributed.parallel_state import destroy_distributed_environment
  from sglang.srt.entrypoints.engine import _set_envs_and_config
  from sglang.srt.hf_transformers_utils import get_tokenizer
+ from sglang.srt.layers.moe import initialize_moe_config
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
  from sglang.srt.managers.scheduler import Scheduler
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -509,6 +510,8 @@ def latency_test(
  bench_args,
  tp_rank,
  ):
+ initialize_moe_config(server_args)
+
  # Set CPU affinity
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
  set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank)
sglang/srt/configs/__init__.py CHANGED
@@ -5,6 +5,7 @@ from sglang.srt.configs.exaone import ExaoneConfig
  from sglang.srt.configs.janus_pro import MultiModalityConfig
  from sglang.srt.configs.kimi_vl import KimiVLConfig
  from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+ from sglang.srt.configs.longcat_flash import LongcatFlashConfig
  from sglang.srt.configs.step3_vl import (
  Step3TextConfig,
  Step3VisionEncoderConfig,
@@ -16,6 +17,7 @@ __all__ = [
  "ChatGLMConfig",
  "DbrxConfig",
  "DeepseekVL2Config",
+ "LongcatFlashConfig",
  "MultiModalityConfig",
  "KimiVLConfig",
  "MoonViTConfig",
sglang/srt/configs/longcat_flash.py ADDED
@@ -0,0 +1,104 @@
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+ class LongcatFlashConfig(PretrainedConfig):
+ model_type = "longcat_flash"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=131072,
+ hidden_size=6144,
+ intermediate_size=None,
+ ffn_hidden_size=12288,
+ expert_ffn_hidden_size=2048,
+ num_layers=28,
+ num_hidden_layers=None,
+ num_attention_heads=64,
+ ep_size=1,
+ kv_lora_rank=512,
+ q_lora_rank=1536,
+ qk_rope_head_dim=128,
+ qk_nope_head_dim=128,
+ v_head_dim=128,
+ n_routed_experts=512,
+ moe_topk=12,
+ norm_topk_prob=False,
+ max_position_embeddings=131072,
+ rms_norm_eps=1e-05,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mla_scale_q_lora=True,
+ mla_scale_kv_lora=True,
+ torch_dtype="bfloat16",
+ params_dtype="bfloat16",
+ rounter_params_dtype="float32",
+ router_bias=False,
+ topk_method=None,
+ routed_scaling_factor=6.0,
+ zero_expert_num=256,
+ zero_expert_type="identity",
+ nextn_use_scmoe=False,
+ num_nextn_predict_layers=1,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ torch_dtype=torch_dtype,
+ params_dtype=params_dtype,
+ rounter_params_dtype=rounter_params_dtype,
+ topk_method=topk_method,
+ router_bias=router_bias,
+ nextn_use_scmoe=nextn_use_scmoe,
+ num_nextn_predict_layers=num_nextn_predict_layers,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = (
+ num_hidden_layers if num_hidden_layers is not None else num_layers
+ )
+ self.intermediate_size = (
+ intermediate_size if intermediate_size is not None else ffn_hidden_size
+ )
+ self.moe_intermediate_size = expert_ffn_hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.ep_size = ep_size
+ self.kv_lora_rank = kv_lora_rank
+ self.q_lora_rank = q_lora_rank
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.n_routed_experts = n_routed_experts
+ self.moe_topk = moe_topk
+ self.norm_topk_prob = norm_topk_prob
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mla_scale_q_lora = mla_scale_q_lora
+ self.mla_scale_kv_lora = mla_scale_kv_lora
+ self.zero_expert_num = zero_expert_num
+ self.zero_expert_type = zero_expert_type
+ self.routed_scaling_factor = routed_scaling_factor
+ self.hidden_act = "silu"
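The new config maps several LongCat-specific field names onto the HF-style attributes the rest of sglang reads. A minimal sketch of that fallback behavior, assuming the wheel and transformers are installed so the module above is importable; the asserted values simply restate the defaults visible in the diff:

from sglang.srt.configs.longcat_flash import LongcatFlashConfig

# Instantiate with the defaults shown above.
cfg = LongcatFlashConfig()

# num_hidden_layers falls back to num_layers, intermediate_size to ffn_hidden_size,
# and moe_intermediate_size is copied from expert_ffn_hidden_size.
assert cfg.num_hidden_layers == 28
assert cfg.intermediate_size == 12288
assert cfg.moe_intermediate_size == 2048
assert cfg.hidden_act == "silu"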
sglang/srt/configs/model_config.py CHANGED
@@ -132,6 +132,13 @@ class ModelConfig:
  if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
  self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"

+ if (
+ is_draft_model
+ and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
+ ):
+ self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
+ self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers
+
  if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
  self.hf_config.architectures[0] = "MiMoMTP"
  if (
@@ -199,6 +206,8 @@ class ModelConfig:
  "DeepseekV2ForCausalLM" in self.hf_config.architectures
  or "DeepseekV3ForCausalLM" in self.hf_config.architectures
  or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
+ or "LongcatFlashForCausalLM" in self.hf_config.architectures
+ or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
  ):
  self.head_dim = 256
  self.attention_arch = AttentionArch.MLA
@@ -270,6 +279,9 @@ class ModelConfig:
  self.num_key_value_heads = self.num_attention_heads
  self.hidden_size = self.hf_text_config.hidden_size
  self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+ self.num_attention_layers = self.num_hidden_layers
+ if "LongcatFlashForCausalLM" in self.hf_config.architectures:
+ self.num_attention_layers = self.num_hidden_layers * 2
  self.num_nextn_predict_layers = getattr(
  self.hf_text_config, "num_nextn_predict_layers", None
  )
@@ -393,9 +405,10 @@ class ModelConfig:
  # compressed-tensors uses a "compression_config" key
  quant_cfg = getattr(self.hf_config, "compression_config", None)
  if quant_cfg is None:
- # check if is modelopt model -- modelopt doesn't have corresponding field
+ # check if is modelopt or mixed-precision model -- Both of them don't have corresponding field
  # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
  # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+ # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
  is_local = os.path.exists(self.model_path)
  modelopt_quant_config = {"quant_method": "modelopt"}
  if not is_local:
sglang/srt/connector/__init__.py CHANGED
@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
  KV = "KV"


- def create_remote_connector(url, device="cpu") -> BaseConnector:
+ def create_remote_connector(url, **kwargs) -> BaseConnector:
  connector_type = parse_connector_type(url)
  if connector_type == "redis":
  return RedisConnector(url)
sglang/srt/connector/base_connector.py CHANGED
@@ -20,9 +20,8 @@ class BaseConnector(ABC):
  <connector_type://<host>:<port>/<model_name>/files/<filename>
  """

- def __init__(self, url: str, device: torch.device = "cpu"):
+ def __init__(self, url: str):
  self.url = url
- self.device = device
  self.closed = False
  self.local_dir = tempfile.mkdtemp()
  for sig in (signal.SIGINT, signal.SIGTERM):
sglang/srt/connector/redis.py CHANGED
@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)

  class RedisConnector(BaseKVConnector):

- def __init__(self, url: str, device: torch.device = "cpu"):
+ def __init__(self, url: str):
  import redis

- super().__init__(url, device)
+ super().__init__(url)
  parsed_url = urlparse(url)
  self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
  self.model_name = parsed_url.path.lstrip("/")
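Both connector changes drop the device parameter: create_remote_connector now forwards extra options via **kwargs, and RedisConnector takes only the URL. A minimal sketch of the updated call, assuming the redis package is installed; the URL is a made-up example following the format documented in base_connector.py:

from sglang.srt.connector import create_remote_connector

# <connector_type>://<host>:<port>/<model_name>/files/<filename>
connector = create_remote_connector("redis://localhost:6379/my-model")  # no device argument anymore
print(connector.model_name)  # "my-model", parsed from the URL path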
sglang/srt/connector/serde/__init__.py CHANGED
@@ -15,7 +15,7 @@ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:

  if serde_type == "safe":
  s = SafeSerializer()
- d = SafeDeserializer(torch.uint8)
+ d = SafeDeserializer()
  else:
  raise ValueError(f"Unknown serde type: {serde_type}")

sglang/srt/connector/serde/safe_serde.py CHANGED
@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):

  class SafeDeserializer(Deserializer):

- def __init__(self, dtype):
- super().__init__(dtype)
+ def __init__(self):
+ # TODO: dtype options
+ super().__init__(torch.float32)

  def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
- return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
+ return load(bytes(b))["tensor_bytes"]

  def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
  return self.from_bytes_normal(b)
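For context on why SafeDeserializer no longer takes a dtype: safetensors records each tensor's dtype in the serialized header, so load() already returns tensors in their stored dtype and the deserializer no longer forces a cast. A standalone sketch of that round trip using safetensors.torch directly (not sglang's serde classes):

import torch
from safetensors.torch import load, save

original = torch.randn(4, 4, dtype=torch.bfloat16)
blob = save({"tensor_bytes": original})   # bytes; dtype is stored in the safetensors header
restored = load(blob)["tensor_bytes"]     # comes back as bfloat16 without any explicit cast
assert restored.dtype == original.dtype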
sglang/srt/disaggregation/ascend/conn.py CHANGED
@@ -1,6 +1,12 @@
+ import concurrent.futures
  import logging
+ from typing import List, Tuple
+
+ import numpy as np
+ import numpy.typing as npt

  from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
+ from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
  from sglang.srt.disaggregation.mooncake.conn import (
  MooncakeKVBootstrapServer,
  MooncakeKVManager,
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
  )

+ def send_kvcache(
+ self,
+ mooncake_session_id: str,
+ prefill_kv_indices: npt.NDArray[np.int32],
+ dst_kv_ptrs: list[int],
+ dst_kv_indices: npt.NDArray[np.int32],
+ executor: concurrent.futures.ThreadPoolExecutor,
+ ):
+ # Group by indices
+ prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+ prefill_kv_indices, dst_kv_indices
+ )
+
+ num_layers = len(self.kv_args.kv_data_ptrs)
+ layers_params = [
+ (
+ self.kv_args.kv_data_ptrs[layer_id],
+ dst_kv_ptrs[layer_id],
+ self.kv_args.kv_item_lens[layer_id],
+ )
+ for layer_id in range(num_layers)
+ ]
+
+ def set_transfer_blocks(
+ src_ptr: int, dst_ptr: int, item_len: int
+ ) -> List[Tuple[int, int, int]]:
+ transfer_blocks = []
+ for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+ src_addr = src_ptr + int(prefill_index[0]) * item_len
+ dst_addr = dst_ptr + int(decode_index[0]) * item_len
+ length = item_len * len(prefill_index)
+ transfer_blocks.append((src_addr, dst_addr, length))
+ return transfer_blocks
+
+ # Worker function for processing a single layer
+ def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+ transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
+ return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+ # Worker function for processing all layers in a batch
+ def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+ transfer_blocks = []
+ for src_ptr, dst_ptr, item_len in layers_params:
+ transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+ return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+ if self.enable_custom_mem_pool:
+ futures = [
+ executor.submit(
+ process_layer,
+ src_ptr,
+ dst_ptr,
+ item_len,
+ )
+ for (src_ptr, dst_ptr, item_len) in layers_params
+ ]
+ for future in concurrent.futures.as_completed(futures):
+ status = future.result()
+ if status != 0:
+ for f in futures:
+ f.cancel()
+ return status
+ else:
+ # Combining all layers' params in one batch transfer is more efficient
+ # compared to using multiple threads
+ return process_layers(layers_params)
+
+ return 0
+

  class AscendKVSender(MooncakeKVSender):
  pass
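The send_kvcache override above builds one (src_addr, dst_addr, length) block per contiguous run of KV indices before handing them to the transfer engine. A toy, self-contained illustration of that address math; group_runs is a hypothetical stand-in for sglang's group_concurrent_contiguous (which additionally pairs prefill and decode indices), not its real implementation:

from typing import List, Tuple

def group_runs(indices: List[int]) -> List[List[int]]:
    # Hypothetical helper: split a sorted index list into contiguous runs.
    runs: List[List[int]] = []
    for idx in indices:
        if runs and idx == runs[-1][-1] + 1:
            runs[-1].append(idx)
        else:
            runs.append([idx])
    return runs

def to_transfer_blocks(
    runs: List[List[int]], src_ptr: int, dst_ptr: int, item_len: int
) -> List[Tuple[int, int, int]]:
    # Same math as set_transfer_blocks above, simplified to use the same run
    # offsets on both sides: one block per contiguous run.
    return [
        (src_ptr + run[0] * item_len, dst_ptr + run[0] * item_len, item_len * len(run))
        for run in runs
    ]

runs = group_runs([3, 4, 5, 9, 10, 42])
print(runs)  # [[3, 4, 5], [9, 10], [42]]
print(to_transfer_blocks(runs, 0x1000, 0x8000, 256))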
sglang/srt/disaggregation/launch_lb.py CHANGED
@@ -6,7 +6,6 @@ from sglang.srt.disaggregation.mini_lb import PrefillConfig, run

  @dataclasses.dataclass
  class LBArgs:
- rust_lb: bool = False
  host: str = "0.0.0.0"
  port: int = 8000
  policy: str = "random"
@@ -17,11 +16,6 @@ class LBArgs:

  @staticmethod
  def add_cli_args(parser: argparse.ArgumentParser):
- parser.add_argument(
- "--rust-lb",
- action="store_true",
- help="Deprecated, please use SGLang Router instead, this argument will have no effect.",
- )
  parser.add_argument(
  "--host",
  type=str,
@@ -92,7 +86,6 @@ class LBArgs:
  ]

  return cls(
- rust_lb=args.rust_lb,
  host=args.host,
  port=args.port,
  policy=args.policy,
@@ -102,12 +95,6 @@ class LBArgs:
  timeout=args.timeout,
  )

- def __post_init__(self):
- if not self.rust_lb:
- assert (
- self.policy == "random"
- ), "Only random policy is supported for Python load balancer"
-

  def main():
  parser = argparse.ArgumentParser(
sglang/srt/disaggregation/mini_lb.py CHANGED
@@ -7,6 +7,7 @@ import dataclasses
  import logging
  import random
  import urllib
+ from http import HTTPStatus
  from itertools import chain
  from typing import List, Optional

@@ -262,14 +263,38 @@ async def get_server_info():

  @app.get("/get_model_info")
  async def get_model_info():
- # Dummy model information
- model_info = {
- "model_path": "/path/to/dummy/model",
- "tokenizer_path": "/path/to/dummy/tokenizer",
- "is_generation": True,
- "preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
- }
- return ORJSONResponse(content=model_info)
+ global load_balancer
+
+ if not load_balancer or not load_balancer.prefill_servers:
+ raise HTTPException(
+ status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+ detail="There is no server registered",
+ )
+
+ target_server_url = load_balancer.prefill_servers[0]
+ endpoint_url = f"{target_server_url}/get_model_info"
+
+ async with aiohttp.ClientSession() as session:
+ try:
+ async with session.get(endpoint_url) as response:
+ if response.status != 200:
+ error_text = await response.text()
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_GATEWAY,
+ detail=(
+ f"Failed to get model info from {target_server_url}"
+ f"Status: {response.status}, Response: {error_text}"
+ ),
+ )
+
+ model_info_json = await response.json()
+ return ORJSONResponse(content=model_info_json)
+
+ except aiohttp.ClientError as e:
+ raise HTTPException(
+ status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+ detail=f"Failed to get model info from backend",
+ )


  @app.post("/generate")
sglang/srt/disaggregation/prefill.py CHANGED
@@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin:
  # Move the chunked request out of the batch so that we can merge
  # only finished requests to running_batch.
  self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
- self.tree_cache.cache_unfinished_req(self.chunked_req)
+ self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
  if self.enable_overlap:
  # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
  self.chunked_req.tmp_end_idx = min(
sglang/srt/distributed/parallel_state.py CHANGED
@@ -43,6 +43,7 @@ from sglang.srt.utils import (
  direct_register_custom_op,
  get_bool_env_var,
  get_int_env_var,
+ is_cpu,
  is_cuda_alike,
  is_hip,
  is_npu,
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
  )

  _is_npu = is_npu()
+ _is_cpu = is_cpu()
+
+ IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")


  @dataclass
@@ -223,10 +227,12 @@ class GroupCoordinator:
  use_message_queue_broadcaster: bool = False,
  group_name: Optional[str] = None,
  ):
+ # Set group info
  group_name = group_name or "anonymous"
  self.unique_name = _get_unique_name(group_name)
  _register_group(self)

+ # Set rank info
  self.rank = torch.distributed.get_rank()
  self.local_rank = local_rank
  self.device_group = None
@@ -250,14 +256,16 @@ class GroupCoordinator:
  assert self.cpu_group is not None
  assert self.device_group is not None

+ device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
  if is_cuda_alike():
- self.device = torch.device(f"cuda:{local_rank}")
+ self.device = torch.device(f"cuda:{device_id}")
  elif _is_npu:
- self.device = torch.device(f"npu:{local_rank}")
+ self.device = torch.device(f"npu:{device_id}")
  else:
  self.device = torch.device("cpu")
  self.device_module = torch.get_device_module(self.device)

+ # Import communicators
  self.use_pynccl = use_pynccl
  self.use_pymscclpp = use_pymscclpp
  self.use_custom_allreduce = use_custom_allreduce
@@ -270,6 +278,9 @@ class GroupCoordinator:
  from sglang.srt.distributed.device_communicators.custom_all_reduce import (
  CustomAllreduce,
  )
+ from sglang.srt.distributed.device_communicators.pymscclpp import (
+ PyMscclppCommunicator,
+ )
  from sglang.srt.distributed.device_communicators.pynccl import (
  PyNcclCommunicator,
  )
@@ -287,10 +298,6 @@ class GroupCoordinator:
  device=self.device,
  )

- from sglang.srt.distributed.device_communicators.pymscclpp import (
- PyMscclppCommunicator,
- )
-
  self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
  if use_pymscclpp and self.world_size > 1:
  self.pymscclpp_comm = PyMscclppCommunicator(
@@ -325,30 +332,30 @@ class GroupCoordinator:
  except Exception as e:
  logger.warning(f"Failed to initialize QuickAllReduce: {e}")

+ # Create communicator for other hardware backends
  from sglang.srt.distributed.device_communicators.hpu_communicator import (
  HpuCommunicator,
  )
+ from sglang.srt.distributed.device_communicators.npu_communicator import (
+ NpuCommunicator,
+ )
+ from sglang.srt.distributed.device_communicators.xpu_communicator import (
+ XpuCommunicator,
+ )

  self.hpu_communicator: Optional[HpuCommunicator] = None
  if use_hpu_communicator and self.world_size > 1:
  self.hpu_communicator = HpuCommunicator(group=self.device_group)

- from sglang.srt.distributed.device_communicators.xpu_communicator import (
- XpuCommunicator,
- )
-
  self.xpu_communicator: Optional[XpuCommunicator] = None
  if use_xpu_communicator and self.world_size > 1:
  self.xpu_communicator = XpuCommunicator(group=self.device_group)

- from sglang.srt.distributed.device_communicators.npu_communicator import (
- NpuCommunicator,
- )
-
  self.npu_communicator: Optional[NpuCommunicator] = None
  if use_npu_communicator and self.world_size > 1:
  self.npu_communicator = NpuCommunicator(group=self.device_group)

+ # Create message queue
  from sglang.srt.distributed.device_communicators.shm_broadcast import (
  MessageQueue,
  )
@@ -848,6 +855,11 @@ class GroupCoordinator:
  )
  return obj_list

+ def all_gather_object(self, obj: Any) -> List[Any]:
+ objs = [None] * self.world_size
+ torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
+ return objs
+
  def send_object(self, obj: Any, dst: int) -> None:
  """Send the input object list to the destination rank."""
  """NOTE: `dst` is the local rank of the destination rank."""
@@ -1633,7 +1645,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):

  ray.shutdown()
  gc.collect()
- if not current_platform.is_cpu():
+ if not _is_cpu:
  if hasattr(torch, "cuda") and torch.cuda.is_available():
  torch.cuda.empty_cache()
  if hasattr(torch._C, "_host_emptyCache"):
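The new GroupCoordinator.all_gather_object helper is a thin wrapper over torch.distributed.all_gather_object against the group's CPU process group. A standalone sketch of that underlying pattern, run under e.g. torchrun --nproc-per-node=2 (the payload is illustrative):

import torch.distributed as dist

def main() -> None:
    # Gather one picklable object per rank, in rank order; the new helper does the
    # same thing against self.cpu_group.
    dist.init_process_group(backend="gloo")
    payload = {"rank": dist.get_rank(), "note": "per-rank metadata"}
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, payload)
    if dist.get_rank() == 0:
        print(gathered)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()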
sglang/srt/entrypoints/engine.py CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
  UpdateWeightsFromDistributedReqInput,
  UpdateWeightsFromTensorReqInput,
  )
+ from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
  from sglang.srt.managers.scheduler import run_scheduler_process
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -672,7 +673,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  if server_args.attention_backend == "flashinfer":
  assert_pkg_version(
  "flashinfer_python",
- "0.2.14.post1",
+ "0.3.0",
  "Please uninstall the old version and "
  "reinstall the latest version by following the instructions "
  "at https://docs.flashinfer.ai/installation.html.",
@@ -680,7 +681,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
  assert_pkg_version(
  "sgl-kernel",
- "0.3.7",
+ "0.3.8",
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
  )

@@ -814,18 +815,24 @@ def _launch_subprocesses(
  ),
  )
  detoken_proc.start()
+ if server_args.tokenizer_worker_num > 1:
+ # Launch multi-tokenizer router
+ tokenizer_manager = MultiTokenizerRouter(server_args, port_args)

- # Launch tokenizer process
- tokenizer_manager = TokenizerManager(server_args, port_args)
+ # Initialize templates
+ template_manager = None
+ else:
+ # Launch tokenizer process
+ tokenizer_manager = TokenizerManager(server_args, port_args)

- # Initialize templates
- template_manager = TemplateManager()
- template_manager.initialize_templates(
- tokenizer_manager=tokenizer_manager,
- model_path=server_args.model_path,
- chat_template=server_args.chat_template,
- completion_template=server_args.completion_template,
- )
+ # Initialize templates
+ template_manager = TemplateManager()
+ template_manager.initialize_templates(
+ tokenizer_manager=tokenizer_manager,
+ model_path=server_args.model_path,
+ chat_template=server_args.chat_template,
+ completion_template=server_args.completion_template,
+ )

  # Wait for the model to finish loading
  scheduler_infos = []