sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py CHANGED
@@ -418,6 +418,26 @@ if __name__ == "__main__":
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
     args = parser.parse_args()
+
+    # handling ModelScope model downloads
+    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
+        if os.path.exists(args.model_path):
+            print(f"Using local model path: {args.model_path}")
+        else:
+            try:
+                from modelscope import snapshot_download
+
+                print(f"Using ModelScope to download model: {args.model_path}")
+
+                # download the model and replace args.model_path
+                args.model_path = snapshot_download(
+                    args.model_path,
+                )
+                print(f"Model downloaded to: {args.model_path}")
+            except Exception as e:
+                print(f"ModelScope download failed: {str(e)}")
+                raise e
+
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)
 
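Note: the new block above is an env-gated ModelScope fallback. A minimal
standalone sketch of the same pattern (the helper name
`maybe_download_from_modelscope` is ours, not part of the diff):

    import os

    def maybe_download_from_modelscope(model_path: str) -> str:
        # Local paths are used as-is; otherwise treat model_path as a
        # ModelScope model id and resolve it to a local snapshot directory.
        if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() not in ("true", "1"):
            return model_path
        if os.path.exists(model_path):
            return model_path
        from modelscope import snapshot_download  # pip install modelscope
        return snapshot_download(model_path)
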
sglang/bench_one_batch.py CHANGED
@@ -138,6 +138,7 @@ class BenchArgs:
 def load_model(server_args, port_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+    moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
 
     model_config = ModelConfig.from_server_args(server_args)
     model_runner = ModelRunner(
@@ -146,6 +147,8 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        moe_ep_rank=moe_ep_rank,
+        moe_ep_size=server_args.ep_size,
         pp_rank=0,
         pp_size=1,
         nccl_port=port_args.nccl_port,
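Note: a worked example of the moe_ep_rank formula with assumed sizes
(tp_size=8, ep_size=2, so each EP rank spans tp_size // ep_size = 4 TP ranks):

    tp_size, ep_size = 8, 2  # example values, not defaults
    for tp_rank in range(tp_size):
        moe_ep_rank = tp_rank // (tp_size // ep_size)
        print(tp_rank, moe_ep_rank)  # tp ranks 0-3 -> ep rank 0, 4-7 -> ep rank 1
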
sglang/srt/configs/__init__.py CHANGED
@@ -5,6 +5,11 @@ from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+from sglang.srt.configs.step3_vl import (
+    Step3TextConfig,
+    Step3VisionEncoderConfig,
+    Step3VLConfig,
+)
 
 __all__ = [
     "ExaoneConfig",
@@ -14,4 +19,7 @@ __all__ = [
     "MultiModalityConfig",
     "KimiVLConfig",
     "MoonViTConfig",
+    "Step3VLConfig",
+    "Step3TextConfig",
+    "Step3VisionEncoderConfig",
 ]
sglang/srt/configs/model_config.py CHANGED
@@ -112,6 +112,7 @@ class ModelConfig:
         mm_disabled_models = [
             "Gemma3ForConditionalGeneration",
             "Llama4ForConditionalGeneration",
+            "Step3VLForConditionalGeneration",
         ]
         if self.hf_config.architectures[0] in mm_disabled_models:
             enable_multimodal = False
@@ -335,6 +336,8 @@ class ModelConfig:
             "num_key_value_heads",
             # For ChatGLM:
             "multi_query_group_num",
+            # For Step3
+            "num_attention_groups",
         ]
         for attr in attributes:
             num_kv_heads = getattr(self.hf_text_config, attr, None)
@@ -644,6 +647,7 @@ multimodal_model_archs = [
     "InternS1ForConditionalGeneration",
     "Phi4MMForCausalLM",
     "VILAForConditionalGeneration",
+    "Step3VLForConditionalGeneration",
 ]
 
 
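Note: the KV-head lookup now probes `num_attention_groups` as well. A minimal
sketch of the probe in isolation (assuming hf_text_config is a transformers
PretrainedConfig-like object; the helper name is ours):

    attributes = [
        "num_key_value_heads",    # Llama-style GQA
        "multi_query_group_num",  # ChatGLM
        "num_attention_groups",   # Step3
    ]

    def probe_num_kv_heads(hf_text_config):
        # Return the first KV-head attribute the config defines, else None.
        for attr in attributes:
            num_kv_heads = getattr(hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads
        return None
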
sglang/srt/configs/step3_vl.py ADDED
@@ -0,0 +1,172 @@
+from typing import Any, Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class Step3VisionEncoderConfig(PretrainedConfig):
+    model_type = "step3_vision_encoder"
+
+    def __init__(
+        self,
+        hidden_size=1792,
+        intermediate_size=3072,
+        output_hidden_size=4096,
+        num_hidden_layers=63,
+        num_attention_heads=16,
+        num_channels=3,
+        image_size=728,
+        patch_size=14,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-5,
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.output_hidden_size = output_hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+        super().__init__(**kwargs)
+
+
+class Step3TextConfig(PretrainedConfig):
+    model_type = "step3_text"
+    architectures = ["Step3TextForCausalLM"]
+
+    def __init__(
+        self,
+        hidden_size: int = 7168,
+        intermediate_size: int = 18432,
+        num_attention_heads: int = 64,
+        num_attention_groups: int = 1,
+        num_hidden_layers: int = 61,
+        max_seq_len: int = 65536,
+        vocab_size: int = 128815,
+        rms_norm_eps: float = 1e-5,
+        moe_intermediate_size: int = 5120,
+        moe_num_experts: int = 48,
+        moe_top_k: int = 3,
+        rope_theta: float = 500000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embedding: int = 65536,
+        share_expert_dim: int = 5120,
+        share_q_dim: int = 2048,
+        head_dim: int = 256,
+        norm_expert_weight: bool = False,
+        moe_layers_enum: tuple[int] = (
+            4,
+            5,
+            6,
+            7,
+            8,
+            9,
+            10,
+            11,
+            12,
+            13,
+            14,
+            15,
+            16,
+            17,
+            18,
+            19,
+            20,
+            21,
+            22,
+            23,
+            24,
+            25,
+            26,
+            27,
+            28,
+            29,
+            30,
+            31,
+            32,
+            33,
+            34,
+            35,
+            36,
+            37,
+            38,
+            39,
+            40,
+            41,
+            42,
+            43,
+            44,
+            45,
+            46,
+            47,
+            48,
+            49,
+            50,
+            51,
+            52,
+            53,
+            54,
+            55,
+            56,
+            57,
+            58,
+            59,
+        ),
+        **kwargs,
+    ) -> None:
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_attention_heads = num_attention_heads
+        self.num_attention_groups = num_attention_groups
+        self.num_hidden_layers = num_hidden_layers
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.rms_norm_eps = rms_norm_eps
+        self.moe_intermediate_size = moe_intermediate_size
+        self.moe_num_experts = moe_num_experts
+        self.moe_top_k = moe_top_k
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.max_position_embedding = max_position_embedding
+        self.share_expert_dim = share_expert_dim
+        self.share_q_dim = share_q_dim
+        self.head_dim = head_dim
+        self.norm_expert_weight = norm_expert_weight
+        self.moe_layers_enum = moe_layers_enum
+
+        super().__init__(**kwargs)
+
+
+class Step3VLConfig(PretrainedConfig):
+    model_type = "step3_vl"
+
+    def __init__(
+        self,
+        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
+        text_config: Optional[Union[dict, Step3TextConfig]] = None,
+        understand_projector_stride: int = 1,
+        projector_bias: bool = True,
+        image_token_id: int = 128001,
+        **kwargs,
+    ) -> None:
+        if vision_config is None:
+            vision_config = Step3VisionEncoderConfig()
+        elif isinstance(vision_config, dict):
+            vision_config = Step3VisionEncoderConfig(**vision_config)
+        self.vision_config = vision_config
+
+        if text_config is None:
+            text_config = Step3TextConfig()
+        elif isinstance(text_config, dict):
+            text_config = Step3TextConfig(**text_config)
+        self.text_config = text_config
+
+        self.understand_projector_stride = understand_projector_stride
+        self.projector_bias = projector_bias
+        self.hidden_size = text_config.hidden_size
+        self.image_token_id = image_token_id
+
+        super().__init__(**kwargs)
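Note: the moe_layers_enum default above lists layers 4 through 59 inclusive,
i.e. it equals tuple(range(4, 60)). Nested dicts are promoted to config
objects by the constructors, so a config can be built from plain dicts:

    from sglang.srt.configs.step3_vl import Step3VLConfig

    cfg = Step3VLConfig(
        vision_config={"hidden_size": 1792},
        text_config={"hidden_size": 7168, "moe_top_k": 3},
    )
    assert cfg.hidden_size == cfg.text_config.hidden_size  # mirrored for convenience
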
sglang/srt/conversation.py CHANGED
@@ -994,6 +994,23 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="step3-vl",
+        system_message="<|begin▁of▁sentence|>You are a helpful assistant",
+        system_template="{system_message}\n",
+        roles=(
+            "<|BOT|>user\n",
+            "<|BOT|>assistant\n<think>\n",
+        ),
+        sep="<|EOT|>",
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        stop_str="<|EOT|>",
+        image_token="<im_patch>",
+        # add_bos=True,
+    )
+)
+
 
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
@@ -1103,3 +1120,9 @@ def match_vila(model_path: str):
 def match_mimo_vl(model_path: str):
     if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
         return "mimo-vl"
+
+
+# @register_conv_template_matching_function
+# def match_step3(model_path: str):
+#     if re.search(r"step3", model_path, re.IGNORECASE):
+#         return "step3-vl"
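Note: assuming FastChat-style NO_COLON_SINGLE rendering (system prompt, then
each role string followed by the message and the separator), a one-turn
step3-vl prompt would look roughly like this (our reading of the template,
not output taken from the diff):

    <|begin▁of▁sentence|>You are a helpful assistant
    <|BOT|>user
    Describe this image. <im_patch><|EOT|><|BOT|>assistant
    <think>

The matching function is left commented out, so the template currently has to
be selected explicitly rather than inferred from the model path.
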
sglang/srt/disaggregation/decode.py CHANGED
@@ -694,10 +694,7 @@ class SchedulerDisaggregationDecodeMixin:
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()
 
             self.last_batch = batch
 
@@ -771,10 +768,7 @@ class SchedulerDisaggregationDecodeMixin:
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()
 
             self.last_batch = batch
             self.last_batch_in_queue = last_batch_in_queue
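Note: the same replacement appears in prefill.py below. Judging from the
removed lines, self_check_during_idle() consolidates the idle-time
housekeeping into one scheduler method, roughly:

    def self_check_during_idle(self):
        # When the server is idle, do self-check and re-init some states.
        self.check_memory()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()

The method itself lives in the refactored scheduler.py, which is not shown in
this excerpt.
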
sglang/srt/disaggregation/launch_lb.py CHANGED
@@ -1,6 +1,8 @@
 import argparse
 import dataclasses
 
+from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
+
 
 @dataclasses.dataclass
 class LBArgs:
@@ -18,7 +20,7 @@ class LBArgs:
         parser.add_argument(
             "--rust-lb",
             action="store_true",
-            help="Use Rust load balancer",
+            help="Deprecated, please use SGLang Router instead, this argument will have no effect.",
         )
         parser.add_argument(
             "--host",
@@ -115,25 +117,8 @@ def main():
     args = parser.parse_args()
     lb_args = LBArgs.from_cli_args(args)
 
-    if lb_args.rust_lb:
-        from sgl_pdlb._rust import LoadBalancer as RustLB
-
-        RustLB(
-            host=lb_args.host,
-            port=lb_args.port,
-            policy=lb_args.policy,
-            prefill_infos=lb_args.prefill_infos,
-            decode_infos=lb_args.decode_infos,
-            log_interval=lb_args.log_interval,
-            timeout=lb_args.timeout,
-        ).start()
-    else:
-        from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
-
-        prefill_configs = [
-            PrefillConfig(url, port) for url, port in lb_args.prefill_infos
-        ]
-        run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
+    prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
+    run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
 
 
 if __name__ == "__main__":
sglang/srt/disaggregation/mooncake/conn.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,
+    get_bool_env_var,
     get_free_port,
     get_int_env_var,
     get_ip,
@@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager):
             self.bootstrap_timeout = get_int_env_var(
                 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
             )
+
+            self.enable_custom_mem_pool = get_bool_env_var(
+                "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
             self.session_pool = defaultdict(requests.Session)
@@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager):
         socket.connect(endpoint)
         return socket
 
+    def _transfer_data(self, mooncake_session_id, transfer_blocks):
+        if not transfer_blocks:
+            return 0
+
+        # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
+        if self.enable_custom_mem_pool:
+            # batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
+            for src_addr, dst_addr, length in transfer_blocks:
+                status = self.engine.transfer_sync(
+                    mooncake_session_id, src_addr, dst_addr, length
+                )
+                if status != 0:
+                    return status
+            return 0
+        else:
+            src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
+            )
+
     def send_kvcache(
         self,
         mooncake_session_id: str,
@@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager):
 
         # Worker function for processing a single layer
        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
+                transfer_blocks.append((src_addr, dst_addr, length))
 
-                status = self.engine.transfer_sync(
-                    mooncake_session_id, src_addr, dst_addr, length
-                )
-                if status != 0:
-                    return status
-            return 0
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
 
         futures = [
             executor.submit(
@@ -465,21 +487,17 @@ class MooncakeKVManager(BaseKVManager):
         dst_aux_ptrs: list[int],
         dst_aux_index: int,
     ):
-        src_addr_list = []
-        dst_addr_list = []
-        length_list = []
+        transfer_blocks = []
         prefill_aux_ptrs = self.kv_args.aux_data_ptrs
         prefill_aux_item_lens = self.kv_args.aux_item_lens
+
         for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
             length = prefill_aux_item_lens[i]
             src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
             dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
-            src_addr_list.append(src_addr)
-            dst_addr_list.append(dst_addr)
-            length_list.append(length)
-        return self.engine.batch_transfer_sync(
-            mooncake_session_id, src_addr_list, dst_addr_list, length_list
-        )
+            transfer_blocks.append((src_addr, dst_addr, length))
+
+        return self._transfer_data(mooncake_session_id, transfer_blocks)
 
     def sync_status_to_decode_endpoint(
         self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
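Note: both call sites now build a list of (src_addr, dst_addr, length)
triples and hand it to _transfer_data, which issues one batched RPC
(batch_transfer_sync) or falls back to per-block transfer_sync when the
custom memory pool is enabled. A toy illustration of the coalescing step
(plain Python with made-up addresses, not Mooncake's API):

    item_len = 4096
    prefill_kv_blocks = [[0, 1, 2], [7, 8]]  # runs of contiguous KV indices
    dst_kv_blocks = [[5, 6, 7], [0, 1]]
    src_ptr, dst_ptr = 0x10000000, 0x20000000

    transfer_blocks = []
    for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
        src_addr = src_ptr + prefill_index[0] * item_len
        dst_addr = dst_ptr + decode_index[0] * item_len
        # one triple covers the whole contiguous run
        transfer_blocks.append((src_addr, dst_addr, item_len * len(prefill_index)))

    print(transfer_blocks)
    # [(268435456, 536891392, 12288), (268464128, 536870912, 8192)]
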
sglang/srt/disaggregation/prefill.py CHANGED
@@ -287,9 +287,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_disagg_prefill_inflight_queue()
 
             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()
 
             self.last_batch = batch
             # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -337,9 +335,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_disagg_prefill_inflight_queue()
 
             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()
 
             self.last_batch = batch
             # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
sglang/srt/distributed/parallel_state.py CHANGED
@@ -354,6 +354,13 @@ class GroupCoordinator:
             self.cpu_group, 1 << 22, 6
         )
 
+    def __repr__(self):
+        return (
+            f"ranks={self.ranks} rank={self.rank} local_rank={self.local_rank} use_pynccl={self.use_pynccl} "
+            f"device_group={self.device_group} cpu_group={self.cpu_group} unique_name={self.unique_name} "
+            f"world_size={self.world_size} rank_in_group={self.rank_in_group}"
+        )
+
     @property
     def first_rank(self):
         """Return the global rank of the first process in the group"""
@@ -1141,6 +1148,20 @@ def get_tp_group() -> GroupCoordinator:
     return _TP
 
 
+_MOE_EP: Optional[GroupCoordinator] = None
+_MOE_TP: Optional[GroupCoordinator] = None
+
+
+def get_moe_ep_group() -> GroupCoordinator:
+    assert _MOE_EP is not None, "expert model parallel group is not initialized"
+    return _MOE_EP
+
+
+def get_moe_tp_group() -> GroupCoordinator:
+    assert _MOE_TP is not None, "expert model parallel group is not initialized"
+    return _MOE_TP
+
+
 # kept for backward compatibility
 get_tensor_model_parallel_group = get_tp_group
 
@@ -1250,6 +1271,7 @@ def init_distributed_environment(
 
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
+    expert_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     backend: Optional[str] = None,
     duplicate_tp_group: bool = False,
@@ -1327,6 +1349,45 @@ def initialize_model_parallel(
         _TP.pynccl_comm.disabled = False
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
 
+    moe_ep_size = expert_model_parallel_size
+
+    moe_tp_size = tensor_model_parallel_size // moe_ep_size
+    global _MOE_EP
+    assert _MOE_EP is None, "expert model parallel group is already initialized"
+    group_ranks = []
+    for i in range(num_tensor_model_parallel_groups):
+        for j in range(moe_tp_size):
+            st = i * tensor_model_parallel_size + j
+            en = (i + 1) * tensor_model_parallel_size + j
+            ranks = list(range(st, en, moe_tp_size))
+            group_ranks.append(ranks)
+
+    _MOE_EP = init_model_parallel_group(
+        group_ranks,
+        get_world_group().local_rank,
+        backend,
+        use_custom_allreduce=False,
+        group_name="moe_ep",
+    )
+
+    global _MOE_TP
+    assert _MOE_TP is None, "expert model parallel group is already initialized"
+    group_ranks = []
+    for i in range(num_tensor_model_parallel_groups):
+        for j in range(moe_ep_size):
+            st = i * tensor_model_parallel_size + j * moe_tp_size
+            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+            ranks = list(range(st, en))
+            group_ranks.append(ranks)
+
+    _MOE_TP = init_model_parallel_group(
+        group_ranks,
+        get_world_group().local_rank,
+        backend,
+        use_custom_allreduce=False,
+        group_name="moe_tp",
+    )
+
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
     global _PP
@@ -1347,6 +1408,7 @@ def initialize_model_parallel(
 
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
+    expert_model_parallel_size: int,
     pipeline_model_parallel_size: int,
     backend: Optional[str] = None,
 ) -> None:
@@ -1357,7 +1419,10 @@ def ensure_model_parallel_initialized(
     backend = backend or torch.distributed.get_backend(get_world_group().device_group)
     if not model_parallel_is_initialized():
         initialize_model_parallel(
-            tensor_model_parallel_size, pipeline_model_parallel_size, backend
+            tensor_model_parallel_size,
+            expert_model_parallel_size,
+            pipeline_model_parallel_size,
+            backend,
         )
         return
 
@@ -1417,6 +1482,26 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group
 
 
+def get_moe_expert_parallel_world_size():
+    """Return world size for the moe expert parallel group."""
+    return get_moe_ep_group().world_size
+
+
+def get_moe_expert_parallel_rank():
+    """Return my rank for the moe expert parallel group."""
+    return get_moe_ep_group().rank_in_group
+
+
+def get_moe_tensor_parallel_world_size():
+    """Return world size for the moe tensor parallel group."""
+    return get_moe_tp_group().world_size
+
+
+def get_moe_tensor_parallel_rank():
+    """Return my rank for the moe tensor parallel group."""
+    return get_moe_tp_group().rank_in_group
+
+
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
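Note: a worked example of the MoE group layout with assumed sizes
(tensor_model_parallel_size=4, expert_model_parallel_size=2 on a 4-GPU world,
so moe_tp_size=2 and there is a single TP group):

    tensor_model_parallel_size, expert_model_parallel_size = 4, 2
    moe_tp_size = tensor_model_parallel_size // expert_model_parallel_size  # 2
    num_tensor_model_parallel_groups = 1

    ep_groups, tp_groups = [], []
    for i in range(num_tensor_model_parallel_groups):
        for j in range(moe_tp_size):  # EP groups stride by moe_tp_size
            st = i * tensor_model_parallel_size + j
            en = (i + 1) * tensor_model_parallel_size + j
            ep_groups.append(list(range(st, en, moe_tp_size)))
        for j in range(expert_model_parallel_size):  # TP groups are contiguous
            st = i * tensor_model_parallel_size + j * moe_tp_size
            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
            tp_groups.append(list(range(st, en)))

    print(ep_groups)  # [[0, 2], [1, 3]]
    print(tp_groups)  # [[0, 1], [2, 3]]
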
sglang/srt/entrypoints/engine.py CHANGED
@@ -648,29 +648,23 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2.7",
+            "0.2.8",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
-    def sigchld_handler(signum, frame):
-        pid, exitcode = os.waitpid(0, os.WNOHANG)
-        if exitcode != 0:
-            logger.warning(
-                f"Child process unexpectedly failed with {exitcode=}. {pid=}"
+    if True:  # Keep this check for internal code compatibility
+        # Register the signal handler.
+        # The child processes will send SIGQUIT to this process when any error happens
+        # This process then clean up the whole process tree
+        # Note: This sigquit handler is used in the launch phase, and may be replaced by
+        # the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched.
+        def launch_phase_sigquit_handler(signum, frame):
+            logger.error(
+                "Received sigquit from a child process. It usually means the child failed."
             )
+            kill_process_tree(os.getpid())
 
-    signal.signal(signal.SIGCHLD, sigchld_handler)
-
-    # Register the signal handler.
-    # The child processes will send SIGQUIT to this process when any error happens
-    # This process then clean up the whole process tree
-    def sigquit_handler(signum, frame):
-        logger.error(
-            "Received sigquit from a child process. It usually means the child failed."
-        )
-        kill_process_tree(os.getpid())
-
-    signal.signal(signal.SIGQUIT, sigquit_handler)
+    signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
 
     # Set mp start method
     mp.set_start_method("spawn", force=True)
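Note: the launch-phase handler keeps the SIGQUIT contract (children signal a
fatal error, the parent tears down the whole process tree) and drops the old
SIGCHLD handler. A minimal standalone sketch of the pattern; kill_process_tree
here is our stand-in for sglang.srt.utils.kill_process_tree, and the psutil
dependency is assumed:

    import logging
    import os
    import signal

    import psutil  # assumed dependency for walking the process tree

    logger = logging.getLogger(__name__)

    def kill_process_tree(pid: int) -> None:
        # Kill all descendants first, then the process itself.
        parent = psutil.Process(pid)
        for child in parent.children(recursive=True):
            child.kill()
        parent.kill()

    def launch_phase_sigquit_handler(signum, frame):
        logger.error("Received SIGQUIT from a child process; cleaning up.")
        kill_process_tree(os.getpid())

    signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
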
@@ -725,6 +719,7 @@ def _launch_subprocesses(
                 + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                 + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
             )
+            moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(
@@ -732,6 +727,7 @@ def _launch_subprocesses(
                     port_args,
                     gpu_id,
                     tp_rank,
+                    moe_ep_rank,
                     pp_rank,
                     None,
                     writer,