sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +12 -0
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +24 -14
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  29. sglang/srt/layers/moe/ep_moe/layer.py +12 -6
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/topk.py +35 -12
  32. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  33. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  34. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  35. sglang/srt/layers/quantization/mxfp4.py +9 -4
  36. sglang/srt/layers/quantization/utils.py +13 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  38. sglang/srt/layers/rotary_embedding.py +28 -1
  39. sglang/srt/layers/sampler.py +29 -5
  40. sglang/srt/managers/cache_controller.py +62 -96
  41. sglang/srt/managers/detokenizer_manager.py +43 -2
  42. sglang/srt/managers/io_struct.py +27 -0
  43. sglang/srt/managers/mm_utils.py +5 -1
  44. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  45. sglang/srt/managers/scheduler.py +36 -2
  46. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  47. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  48. sglang/srt/managers/tokenizer_manager.py +86 -39
  49. sglang/srt/mem_cache/chunk_cache.py +1 -1
  50. sglang/srt/mem_cache/hicache_storage.py +20 -3
  51. sglang/srt/mem_cache/hiradix_cache.py +75 -68
  52. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  53. sglang/srt/mem_cache/memory_pool.py +4 -0
  54. sglang/srt/mem_cache/memory_pool_host.py +2 -4
  55. sglang/srt/mem_cache/radix_cache.py +5 -4
  56. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  57. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +33 -7
  58. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  59. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  60. sglang/srt/model_executor/model_runner.py +5 -4
  61. sglang/srt/model_loader/loader.py +15 -24
  62. sglang/srt/model_loader/utils.py +12 -0
  63. sglang/srt/models/deepseek_v2.py +26 -10
  64. sglang/srt/models/gpt_oss.py +0 -14
  65. sglang/srt/models/llama_eagle3.py +4 -0
  66. sglang/srt/models/longcat_flash.py +1015 -0
  67. sglang/srt/models/longcat_flash_nextn.py +691 -0
  68. sglang/srt/models/qwen2.py +26 -3
  69. sglang/srt/models/qwen2_5_vl.py +65 -41
  70. sglang/srt/models/qwen2_moe.py +22 -2
  71. sglang/srt/models/transformers.py +1 -1
  72. sglang/srt/multimodal/processors/base_processor.py +4 -2
  73. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  74. sglang/srt/server_args.py +112 -55
  75. sglang/srt/speculative/eagle_worker.py +28 -8
  76. sglang/srt/utils.py +14 -0
  77. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  78. sglang/version.py +1 -1
  79. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +5 -5
  80. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +83 -78
  81. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  82. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -61,6 +61,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -509,6 +510,8 @@ def latency_test(
     bench_args,
     tp_rank,
 ):
+    initialize_moe_config(server_args)
+
     # Set CPU affinity
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank)
sglang/srt/configs/__init__.py CHANGED
@@ -5,6 +5,7 @@ from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+from sglang.srt.configs.longcat_flash import LongcatFlashConfig
 from sglang.srt.configs.step3_vl import (
     Step3TextConfig,
     Step3VisionEncoderConfig,
@@ -16,6 +17,7 @@ __all__ = [
     "ChatGLMConfig",
     "DbrxConfig",
     "DeepseekVL2Config",
+    "LongcatFlashConfig",
     "MultiModalityConfig",
     "KimiVLConfig",
     "MoonViTConfig",
sglang/srt/configs/longcat_flash.py ADDED
@@ -0,0 +1,104 @@
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+class LongcatFlashConfig(PretrainedConfig):
+    model_type = "longcat_flash"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=131072,
+        hidden_size=6144,
+        intermediate_size=None,
+        ffn_hidden_size=12288,
+        expert_ffn_hidden_size=2048,
+        num_layers=28,
+        num_hidden_layers=None,
+        num_attention_heads=64,
+        ep_size=1,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=128,
+        qk_nope_head_dim=128,
+        v_head_dim=128,
+        n_routed_experts=512,
+        moe_topk=12,
+        norm_topk_prob=False,
+        max_position_embeddings=131072,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        mla_scale_q_lora=True,
+        mla_scale_kv_lora=True,
+        torch_dtype="bfloat16",
+        params_dtype="bfloat16",
+        rounter_params_dtype="float32",
+        router_bias=False,
+        topk_method=None,
+        routed_scaling_factor=6.0,
+        zero_expert_num=256,
+        zero_expert_type="identity",
+        nextn_use_scmoe=False,
+        num_nextn_predict_layers=1,
+        **kwargs,
+    ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            torch_dtype=torch_dtype,
+            params_dtype=params_dtype,
+            rounter_params_dtype=rounter_params_dtype,
+            topk_method=topk_method,
+            router_bias=router_bias,
+            nextn_use_scmoe=nextn_use_scmoe,
+            num_nextn_predict_layers=num_nextn_predict_layers,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = (
+            num_hidden_layers if num_hidden_layers is not None else num_layers
+        )
+        self.intermediate_size = (
+            intermediate_size if intermediate_size is not None else ffn_hidden_size
+        )
+        self.moe_intermediate_size = expert_ffn_hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.ep_size = ep_size
+        self.kv_lora_rank = kv_lora_rank
+        self.q_lora_rank = q_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.n_routed_experts = n_routed_experts
+        self.moe_topk = moe_topk
+        self.norm_topk_prob = norm_topk_prob
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.mla_scale_q_lora = mla_scale_q_lora
+        self.mla_scale_kv_lora = mla_scale_kv_lora
+        self.zero_expert_num = zero_expert_num
+        self.zero_expert_type = zero_expert_type
+        self.routed_scaling_factor = routed_scaling_factor
+        self.hidden_act = "silu"
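Aside (not part of the diff): a minimal sketch of how the new config resolves its aliased fields; the values shown are simply the defaults from the class above, so this is illustrative rather than a recommended configuration.

from sglang.srt.configs import LongcatFlashConfig

# num_layers feeds num_hidden_layers, and ffn_hidden_size feeds intermediate_size,
# whenever the explicit fields are left as None (see __init__ above).
config = LongcatFlashConfig(num_layers=28, moe_topk=12)
print(config.num_hidden_layers)      # 28
print(config.intermediate_size)      # 12288
print(config.moe_intermediate_size)  # 2048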
sglang/srt/configs/model_config.py CHANGED
@@ -132,6 +132,13 @@ class ModelConfig:
         if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
             self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
 
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
+            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers
+
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTP"
         if (
@@ -199,6 +206,8 @@ class ModelConfig:
             "DeepseekV2ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
+            or "LongcatFlashForCausalLM" in self.hf_config.architectures
+            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
         ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
@@ -270,6 +279,9 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_text_config.hidden_size
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.num_attention_layers = self.num_hidden_layers
+        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
+            self.num_attention_layers = self.num_hidden_layers * 2
         self.num_nextn_predict_layers = getattr(
             self.hf_text_config, "num_nextn_predict_layers", None
         )
sglang/srt/connector/__init__.py CHANGED
@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
     KV = "KV"
 
 
-def create_remote_connector(url, device="cpu") -> BaseConnector:
+def create_remote_connector(url, **kwargs) -> BaseConnector:
     connector_type = parse_connector_type(url)
     if connector_type == "redis":
         return RedisConnector(url)
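Aside (not part of the diff): after this change a connector is built from the URL alone; extra keyword arguments are accepted through **kwargs, but the redis path shown above only uses the URL. A minimal sketch, with a placeholder endpoint following the <connector_type>://<host>:<port>/<model_name> shape documented in base_connector.py just below:

from sglang.srt.connector import create_remote_connector

# Placeholder Redis endpoint; only the URL is consulted on this path.
connector = create_remote_connector("redis://localhost:6379/my-model")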
sglang/srt/connector/base_connector.py CHANGED
@@ -20,9 +20,8 @@ class BaseConnector(ABC):
     <connector_type://<host>:<port>/<model_name>/files/<filename>
     """
 
-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         self.url = url
-        self.device = device
         self.closed = False
         self.local_dir = tempfile.mkdtemp()
         for sig in (signal.SIGINT, signal.SIGTERM):
sglang/srt/connector/redis.py CHANGED
@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)
 
 class RedisConnector(BaseKVConnector):
 
-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         import redis
 
-        super().__init__(url, device)
+        super().__init__(url)
         parsed_url = urlparse(url)
         self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
         self.model_name = parsed_url.path.lstrip("/")
sglang/srt/connector/serde/__init__.py CHANGED
@@ -15,7 +15,7 @@ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
 
     if serde_type == "safe":
         s = SafeSerializer()
-        d = SafeDeserializer(torch.uint8)
+        d = SafeDeserializer()
     else:
         raise ValueError(f"Unknown serde type: {serde_type}")
 
sglang/srt/connector/serde/safe_serde.py CHANGED
@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):
 
 class SafeDeserializer(Deserializer):
 
-    def __init__(self, dtype):
-        super().__init__(dtype)
+    def __init__(self):
+        # TODO: dtype options
+        super().__init__(torch.float32)
 
     def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
-        return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
+        return load(bytes(b))["tensor_bytes"]
 
     def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
         return self.from_bytes_normal(b)
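Aside (not part of the diff): a rough sketch of what the deserializer change means for callers, assuming the matching serializer stores the tensor under the "tensor_bytes" key with safetensors (that pairing is an assumption; only the deserializer side appears in this hunk).

import torch
from safetensors.torch import load, save

# Hypothetical round trip; the "tensor_bytes" key comes from the hunk above.
blob = save({"tensor_bytes": torch.ones(4, dtype=torch.bfloat16)})
restored = load(bytes(blob))["tensor_bytes"]
# The tensor now keeps its stored dtype; previously SafeDeserializer(torch.uint8)
# forced a cast via .to(dtype=self.dtype).
assert restored.dtype == torch.bfloat16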
sglang/srt/disaggregation/ascend/conn.py CHANGED
@@ -1,6 +1,12 @@
+import concurrent.futures
 import logging
+from typing import List, Tuple
+
+import numpy as np
+import numpy.typing as npt
 
 from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
+from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
 from sglang.srt.disaggregation.mooncake.conn import (
     MooncakeKVBootstrapServer,
     MooncakeKVManager,
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         )
 
+    def send_kvcache(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        # Group by indices
+        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+            prefill_kv_indices, dst_kv_indices
+        )
+
+        num_layers = len(self.kv_args.kv_data_ptrs)
+        layers_params = [
+            (
+                self.kv_args.kv_data_ptrs[layer_id],
+                dst_kv_ptrs[layer_id],
+                self.kv_args.kv_item_lens[layer_id],
+            )
+            for layer_id in range(num_layers)
+        ]
+
+        def set_transfer_blocks(
+            src_ptr: int, dst_ptr: int, item_len: int
+        ) -> List[Tuple[int, int, int]]:
+            transfer_blocks = []
+            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+                src_addr = src_ptr + int(prefill_index[0]) * item_len
+                dst_addr = dst_ptr + int(decode_index[0]) * item_len
+                length = item_len * len(prefill_index)
+                transfer_blocks.append((src_addr, dst_addr, length))
+            return transfer_blocks
+
+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        # Worker function for processing all layers in a batch
+        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+            transfer_blocks = []
+            for src_ptr, dst_ptr, item_len in layers_params:
+                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        if self.enable_custom_mem_pool:
+            futures = [
+                executor.submit(
+                    process_layer,
+                    src_ptr,
+                    dst_ptr,
+                    item_len,
+                )
+                for (src_ptr, dst_ptr, item_len) in layers_params
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                status = future.result()
+                if status != 0:
+                    for f in futures:
+                        f.cancel()
+                    return status
+        else:
+            # Combining all layers' params in one batch transfer is more efficient
+            # compared to using multiple threads
+            return process_layers(layers_params)
+
+        return 0
+
 
 class AscendKVSender(MooncakeKVSender):
     pass
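Aside (not part of the diff): to make the addressing in send_kvcache easier to follow, here is a self-contained sketch of the kind of grouping group_concurrent_contiguous is used for: consecutive page indices are merged into runs so that each run becomes a single (src_addr, dst_addr, length) transfer block. The helper below is illustrative only, not the implementation in sglang.srt.disaggregation.common.utils.

import numpy as np

def group_contiguous(src: np.ndarray, dst: np.ndarray):
    """Split paired index arrays wherever either side stops being consecutive."""
    breaks = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1
    return np.split(src, breaks), np.split(dst, breaks)

src = np.array([4, 5, 6, 10, 11], dtype=np.int32)
dst = np.array([0, 1, 2, 7, 8], dtype=np.int32)
item_len = 4096  # bytes per KV page in this sketch

for s, d in zip(*group_contiguous(src, dst)):
    # One contiguous run -> one transfer block, as in set_transfer_blocks above.
    print((int(s[0]) * item_len, int(d[0]) * item_len, item_len * len(s)))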
sglang/srt/disaggregation/launch_lb.py CHANGED
@@ -6,7 +6,6 @@ from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
 
 @dataclasses.dataclass
 class LBArgs:
-    rust_lb: bool = False
     host: str = "0.0.0.0"
     port: int = 8000
     policy: str = "random"
@@ -17,11 +16,6 @@ class LBArgs:
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument(
-            "--rust-lb",
-            action="store_true",
-            help="Deprecated, please use SGLang Router instead, this argument will have no effect.",
-        )
         parser.add_argument(
             "--host",
             type=str,
@@ -92,7 +86,6 @@ class LBArgs:
         ]
 
         return cls(
-            rust_lb=args.rust_lb,
             host=args.host,
             port=args.port,
             policy=args.policy,
@@ -102,12 +95,6 @@ class LBArgs:
             timeout=args.timeout,
         )
 
-    def __post_init__(self):
-        if not self.rust_lb:
-            assert (
-                self.policy == "random"
-            ), "Only random policy is supported for Python load balancer"
-
 
 def main():
     parser = argparse.ArgumentParser(
sglang/srt/disaggregation/mini_lb.py CHANGED
@@ -7,6 +7,7 @@ import dataclasses
 import logging
 import random
 import urllib
+from http import HTTPStatus
 from itertools import chain
 from typing import List, Optional
 
@@ -262,14 +263,38 @@ async def get_server_info():
 
 @app.get("/get_model_info")
 async def get_model_info():
-    # Dummy model information
-    model_info = {
-        "model_path": "/path/to/dummy/model",
-        "tokenizer_path": "/path/to/dummy/tokenizer",
-        "is_generation": True,
-        "preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
-    }
-    return ORJSONResponse(content=model_info)
+    global load_balancer
+
+    if not load_balancer or not load_balancer.prefill_servers:
+        raise HTTPException(
+            status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+            detail="There is no server registered",
+        )
+
+    target_server_url = load_balancer.prefill_servers[0]
+    endpoint_url = f"{target_server_url}/get_model_info"
+
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.get(endpoint_url) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    raise HTTPException(
+                        status_code=HTTPStatus.BAD_GATEWAY,
+                        detail=(
+                            f"Failed to get model info from {target_server_url}"
+                            f"Status: {response.status}, Response: {error_text}"
+                        ),
+                    )
+
+                model_info_json = await response.json()
+                return ORJSONResponse(content=model_info_json)
+
+        except aiohttp.ClientError as e:
+            raise HTTPException(
+                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+                detail=f"Failed to get model info from backend",
+            )
 
 
 @app.post("/generate")
sglang/srt/disaggregation/prefill.py CHANGED
@@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin:
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             if self.enable_overlap:
                 # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                 self.chunked_req.tmp_end_idx = min(
sglang/srt/distributed/parallel_state.py CHANGED
@@ -52,6 +52,8 @@ from sglang.srt.utils import (
 
 _is_npu = is_npu()
 
+IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
+
 
 @dataclass
 class GraphCaptureContext:
@@ -223,10 +225,12 @@ class GroupCoordinator:
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
+        # Set group info
         group_name = group_name or "anonymous"
         self.unique_name = _get_unique_name(group_name)
         _register_group(self)
 
+        # Set rank info
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
         self.device_group = None
@@ -250,14 +254,16 @@ class GroupCoordinator:
         assert self.cpu_group is not None
         assert self.device_group is not None
 
+        device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
         if is_cuda_alike():
-            self.device = torch.device(f"cuda:{local_rank}")
+            self.device = torch.device(f"cuda:{device_id}")
         elif _is_npu:
-            self.device = torch.device(f"npu:{local_rank}")
+            self.device = torch.device(f"npu:{device_id}")
         else:
            self.device = torch.device("cpu")
         self.device_module = torch.get_device_module(self.device)
 
+        # Import communicators
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
@@ -270,6 +276,9 @@ class GroupCoordinator:
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
+        from sglang.srt.distributed.device_communicators.pymscclpp import (
+            PyMscclppCommunicator,
+        )
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
@@ -287,10 +296,6 @@ class GroupCoordinator:
                 device=self.device,
             )
 
-        from sglang.srt.distributed.device_communicators.pymscclpp import (
-            PyMscclppCommunicator,
-        )
-
         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
         if use_pymscclpp and self.world_size > 1:
             self.pymscclpp_comm = PyMscclppCommunicator(
@@ -325,30 +330,30 @@ class GroupCoordinator:
         except Exception as e:
             logger.warning(f"Failed to initialize QuickAllReduce: {e}")
 
+        # Create communicator for other hardware backends
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+        from sglang.srt.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator,
+        )
 
         self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)
 
-        from sglang.srt.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator,
-        )
-
         self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
 
-        from sglang.srt.distributed.device_communicators.npu_communicator import (
-            NpuCommunicator,
-        )
-
         self.npu_communicator: Optional[NpuCommunicator] = None
         if use_npu_communicator and self.world_size > 1:
             self.npu_communicator = NpuCommunicator(group=self.device_group)
 
+        # Create message queue
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
@@ -848,6 +853,11 @@ class GroupCoordinator:
         )
         return obj_list
 
+    def all_gather_object(self, obj: Any) -> List[Any]:
+        objs = [None] * self.world_size
+        torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
+        return objs
+
     def send_object(self, obj: Any, dst: int) -> None:
         """Send the input object list to the destination rank."""
         """NOTE: `dst` is the local rank of the destination rank."""
sglang/srt/entrypoints/engine.py CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -672,7 +673,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.14.post1",
+            "0.3.0",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -680,7 +681,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.7",
+            "0.3.7.post1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
@@ -814,18 +815,24 @@ def _launch_subprocesses(
         ),
     )
     detoken_proc.start()
+    if server_args.tokenizer_worker_num > 1:
+        # Launch multi-tokenizer router
+        tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
 
-    # Launch tokenizer process
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+        # Initialize templates
+        template_manager = None
+    else:
+        # Launch tokenizer process
+        tokenizer_manager = TokenizerManager(server_args, port_args)
 
-    # Initialize templates
-    template_manager = TemplateManager()
-    template_manager.initialize_templates(
-        tokenizer_manager=tokenizer_manager,
-        model_path=server_args.model_path,
-        chat_template=server_args.chat_template,
-        completion_template=server_args.completion_template,
-    )
+        # Initialize templates
+        template_manager = TemplateManager()
+        template_manager.initialize_templates(
+            tokenizer_manager=tokenizer_manager,
+            model_path=server_args.model_path,
+            chat_template=server_args.chat_template,
+            completion_template=server_args.completion_template,
+        )
 
     # Wait for the model to finish loading
     scheduler_infos = []