sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -29,6 +29,7 @@ from typing import Deque, Dict, List, Optional, Tuple, Union
  import psutil
  import setproctitle
  import torch
+ import torch.distributed
  import zmq
  from torch.cuda import Stream as CudaStream
  from torch.cuda import StreamContext as CudaStreamContext
@@ -151,11 +152,13 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
  from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+ from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
  from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.tracing.trace import (
      process_tracing_init,
+     trace_event_batch,
      trace_set_proc_propagate_context,
      trace_set_thread_info,
      trace_slice_batch,
@@ -168,7 +171,6 @@ from sglang.srt.utils import (
      broadcast_pyobj,
      configure_gc_logger,
      configure_logger,
-     disable_request_logging,
      freeze_gc,
      get_available_gpu_memory,
      get_bool_env_var,
@@ -177,7 +179,6 @@ from sglang.srt.utils import (
      kill_itself_when_parent_died,
      numa_bind_to_node,
      point_to_point_pyobj,
-     pyspy_dump_schedulers,
      require_mlp_sync,
      require_mlp_tp_gather,
      set_gpu_proc_affinity,
@@ -197,6 +198,7 @@ logger = logging.getLogger(__name__)
  # Test retract decode for debugging purposes
  TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
  TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
+ TEST_RETRACT_NO_PREFILL_BS = envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.get()
  GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


@@ -212,6 +214,7 @@ class Scheduler(
      SchedulerMetricsMixin,
      SchedulerDisaggregationDecodeMixin,
      SchedulerDisaggregationPrefillMixin,
+     SchedulerMultiplexMixin,
      SchedulerRuntimeCheckerMixin,
      SchedulerPPMixin,
  ):
@@ -251,6 +254,7 @@ class Scheduler(
          self.enable_lora = server_args.enable_lora
          self.max_loras_per_batch = server_args.max_loras_per_batch
          self.enable_overlap = not server_args.disable_overlap_schedule
+         self.enable_pdmux = server_args.enable_pdmux
          self.skip_tokenizer_init = server_args.skip_tokenizer_init
          self.enable_metrics = server_args.enable_metrics
          self.enable_metrics_for_all_schedulers = (
@@ -284,6 +288,10 @@ class Scheduler(
          # Init inter-process communication
          self.init_sockets(server_args, port_args)

+         # Init pdmux context
+         if self.enable_pdmux:
+             self.init_pdmux()
+
          # Init tokenizer
          self.init_tokenizer()

@@ -320,8 +328,28 @@ class Scheduler(

          # Launch a draft worker for speculative decoding

-         self.launch_draft_worker(
-             gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+         draft_worker_kwargs = dict(
+             gpu_id=gpu_id,
+             tp_rank=tp_rank,
+             moe_ep_rank=moe_ep_rank,
+             server_args=server_args,
+             nccl_port=port_args.nccl_port,
+             target_worker=self.tp_worker,
+             dp_rank=dp_rank,
+         )
+
+         if server_args.speculative_draft_load_format is not None:
+             server_args.load_format = server_args.speculative_draft_load_format
+             logger.info(
+                 f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
+             )
+
+         # Draft workers are looked up via `SpeculativeAlgorithm` registry; new
+         # algorithms should register their factory instead of patching this code.
+         if self.spec_algorithm.name in {"EAGLE", "EAGLE3"}:
+             draft_worker_kwargs["enable_overlap"] = self.enable_overlap
+         self.draft_worker = self.spec_algorithm.create_draft_worker(
+             **draft_worker_kwargs
          )

          # Dispatch the model worker
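The `create_draft_worker` call above replaces the per-algorithm `if/elif` chain of the removed `launch_draft_worker` method (deleted further below in this file). A minimal sketch of the registry pattern this implies is shown here; apart from the `create_draft_worker` name taken from the diff, the class and factory names are illustrative and not sglang's actual API:

```python
# Hypothetical sketch of registry-based draft-worker dispatch; only the
# create_draft_worker name comes from the diff above, everything else is illustrative.
from typing import Callable, Dict


class SpecAlgorithmRegistry:
    """Maps a speculative-algorithm name to a factory that builds its draft worker."""

    _factories: Dict[str, Callable[..., object]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(factory: Callable[..., object]):
            cls._factories[name] = factory
            return factory

        return decorator

    @classmethod
    def create_draft_worker(cls, name: str, **kwargs):
        # New algorithms register a factory here instead of patching the scheduler.
        factory = cls._factories.get(name)
        return factory(**kwargs) if factory is not None else None


@SpecAlgorithmRegistry.register("EAGLE")
def _build_eagle_worker(**kwargs):
    # Stand-in for constructing a real EAGLE draft worker.
    return {"kind": "eagle-draft-worker", **kwargs}


if __name__ == "__main__":
    worker = SpecAlgorithmRegistry.create_draft_worker("EAGLE", gpu_id=0, tp_rank=0)
    print(worker)  # {'kind': 'eagle-draft-worker', 'gpu_id': 0, 'tp_rank': 0}
```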
@@ -356,6 +384,17 @@ class Scheduler(
          self.pp_group = get_pp_group()
          self.world_group = get_world_group()

+         # With DP attention enabled, the entry rank is attn_tp_rank==0;
+         # otherwise the entry rank is TP group local rank 0.
+         # For #11910, use the CPU communication group to broadcast VLM Python objects,
+         # avoiding any coupling with CUDA streams/devices.
+         if self.server_args.enable_dp_attention:
+             self.cpu_group = self.attn_tp_cpu_group
+             self.is_entry_rank = self.attn_tp_rank == 0
+         else:
+             self.cpu_group = self.tp_cpu_group
+             self.is_entry_rank = self.tp_group.rank == 0
+
          self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
          set_random_seed(self.random_seed)

@@ -392,6 +431,8 @@ class Scheduler(
          self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
          # The current forward batch
          self.cur_batch: Optional[ScheduleBatch] = None
+         # The current split prefill batch
+         self.split_prefill_batch: Optional[ScheduleBatch] = None
          # The last forward batch
          self.last_batch: Optional[ScheduleBatch] = None
          self.forward_ct = 0
@@ -548,57 +589,6 @@ class Scheduler(
              ]
          )

-     def launch_draft_worker(
-         self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
-     ):
-         if server_args.speculative_draft_load_format is not None:
-             server_args.load_format = server_args.speculative_draft_load_format
-             logger.info(
-                 f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
-             )
-
-         if self.spec_algorithm.is_eagle():
-             from sglang.srt.speculative.eagle_worker import EAGLEWorker
-             from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
-
-             WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
-
-             self.draft_worker = WorkerClass(
-                 gpu_id=gpu_id,
-                 tp_rank=tp_rank,
-                 moe_ep_rank=moe_ep_rank,
-                 server_args=server_args,
-                 nccl_port=port_args.nccl_port,
-                 target_worker=self.tp_worker,
-                 dp_rank=dp_rank,
-             )
-         elif self.spec_algorithm.is_standalone():
-             from sglang.srt.speculative.standalone_worker import StandaloneWorker
-
-             self.draft_worker = StandaloneWorker(
-                 gpu_id=gpu_id,
-                 tp_rank=tp_rank,
-                 moe_ep_rank=moe_ep_rank,
-                 server_args=server_args,
-                 nccl_port=port_args.nccl_port,
-                 target_worker=self.tp_worker,
-                 dp_rank=dp_rank,
-             )
-         elif self.spec_algorithm.is_ngram():
-             from sglang.srt.speculative.ngram_worker import NGRAMWorker
-
-             self.draft_worker = NGRAMWorker(
-                 gpu_id=gpu_id,
-                 tp_rank=tp_rank,
-                 moe_ep_rank=moe_ep_rank,
-                 server_args=server_args,
-                 nccl_port=port_args.nccl_port,
-                 target_worker=self.tp_worker,
-                 dp_rank=dp_rank,
-             )
-         else:
-             self.draft_worker = None
-
      def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
          context = zmq.Context(2)
          self.idle_sleeper = None
@@ -1162,6 +1152,70 @@ class Scheduler(
              self.max_req_len - len(req.origin_input_ids) - 1,
          )

+     def _process_and_broadcast_mm_inputs(
+         self,
+         raw_mm_inputs: Optional[dict],
+     ):
+         """Materialize MultimodalInputs once on the entry rank and broadcast to others.
+
+         Entry rank:
+           - constructs MultimodalInputs.from_dict(raw_mm_inputs) once
+           - broadcasts to other ranks in self.cpu_group (if world_size > 1)
+
+         Non-entry ranks:
+           - receive the object via broadcast (if world_size > 1)
+           - otherwise (single-rank / no group) fall back to local from_dict
+
+         Returns:
+             MultimodalInputs | None
+         """
+         if raw_mm_inputs is None:
+             return None
+
+         group_world_size = 1
+         try:
+             if (
+                 torch.distributed.is_available()
+                 and torch.distributed.is_initialized()
+                 and self.cpu_group is not None
+             ):
+                 group_world_size = torch.distributed.get_world_size(
+                     group=self.cpu_group
+                 )
+         except Exception as e:
+             logger.warning(
+                 f"Failed to get world size in mm_inputs handling with {e}, fallback to 1."
+             )
+
+         # When tp size > 1, every scheduler TP rank would otherwise run this
+         # duplicated CPU computation on its main thread, even though it only
+         # needs to run on TP0 and be broadcast to the other TP ranks.
+         # Since the scheduler is single-threaded, any large CPU cost delays the
+         # handling of other messages; for example, CPU usage hitting 99.9% can
+         # significantly increase CUDA kernel launch times.
+         if self.is_entry_rank:
+             # Only the entry rank materializes once from dict.
+             image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+             # Broadcast to other TP ranks (use src=0 within the group).
+             if group_world_size > 1:
+                 obj_list = [image_inputs]
+                 torch.distributed.broadcast_object_list(
+                     obj_list, src=0, group=self.cpu_group
+                 )
+                 image_inputs = obj_list[0]
+         else:
+             # Non-entry ranks: receive if group size > 1; otherwise materialize locally.
+             if group_world_size > 1:
+                 obj_list = [None]
+                 torch.distributed.broadcast_object_list(
+                     obj_list, src=0, group=self.cpu_group
+                 )
+                 image_inputs = obj_list[0]
+             else:
+                 image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+
+         return image_inputs
+
      def handle_generate_request(
          self,
          recv_req: TokenizedGenerateReqInput,
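The helper above relies on `torch.distributed.broadcast_object_list` over a CPU communication group, so only the entry rank pays the cost of materializing the multimodal inputs while the other ranks receive the pickled object. A minimal, self-contained sketch of that broadcast pattern, assuming a gloo backend, two local processes, and a made-up payload standing in for the real `MultimodalInputs`:

```python
# Standalone sketch of the CPU-group object broadcast used above; the payload
# dict is a hypothetical stand-in for MultimodalInputs.
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    # CPU-only (gloo) process group; 29511 is an arbitrary free local port.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29511",
        rank=rank,
        world_size=world_size,
    )
    if rank == 0:
        # Entry rank builds the object once (stand-in for MultimodalInputs.from_dict).
        obj_list = [{"num_image_tokens": 576, "image_grid": (24, 24)}]
    else:
        obj_list = [None]  # placeholder, filled in by the broadcast
    dist.broadcast_object_list(obj_list, src=0)
    print(f"rank {rank} got: {obj_list[0]}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```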
@@ -1243,7 +1297,9 @@

          # Handle multimodal inputs
          if recv_req.mm_inputs is not None:
-             image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
+             image_inputs = self._process_and_broadcast_mm_inputs(recv_req.mm_inputs)
+
+             # The following steps are already fast, so execute them locally on each rank.
              # Expand a single image token into multiple dummy tokens for receiving image embeddings
              req.origin_input_ids = self.pad_input_ids_func(
                  req.origin_input_ids, image_inputs
@@ -1376,7 +1432,7 @@
              self._prefetch_kvcache(req)
              self.waiting_queue.append(req)
              req.time_stats.wait_queue_entry_time = time.perf_counter()
-             trace_slice_end("process req", req.rid, auto_next_anon=True)
+             trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True)
          elif self.disaggregation_mode == DisaggregationMode.PREFILL:
              self._prefetch_kvcache(req)
              self.disagg_prefill_bootstrap_queue.add(
@@ -1466,13 +1522,14 @@
              recv_req.sampling_params,
              token_type_ids=recv_req.token_type_ids,
              priority=recv_req.priority,
+             dimensions=recv_req.dimensions,
              http_worker_ipc=recv_req.http_worker_ipc,
          )
          req.tokenizer = self.tokenizer

          # Handle multimodal inputs
          if recv_req.image_inputs is not None:
-             image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
+             image_inputs = self._process_and_broadcast_mm_inputs(recv_req.image_inputs)
              # Expand a single image token into multiple dummy tokens for receiving image embeddings
              req.origin_input_ids = self.pad_input_ids_func(
                  req.origin_input_ids, image_inputs
@@ -1639,6 +1696,10 @@
          if need_dp_attn_preparation:
              ret = self.prepare_mlp_sync_batch(ret)

+         if ret:
+             attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
+             trace_event_batch("schedule", ret.reqs, attrs=attrs)
+
          return ret

      def get_num_allocatable_reqs(self, running_bs):
@@ -1682,6 +1743,12 @@
          # Get priority queue
          self.policy.calc_priority(self.waiting_queue)

+         if TEST_RETRACT and running_bs > TEST_RETRACT_NO_PREFILL_BS:
+             # If we are testing retraction and the running batch size exceeds
+             # TEST_RETRACT_NO_PREFILL_BS, we skip the prefill to keep the requests
+             # in the waiting queue.
+             return None
+
          # Prefill policy
          adder = PrefillAdder(
              self.page_size,
@@ -1848,14 +1915,14 @@ class Scheduler(
          self.num_retracted_reqs = len(retracted_reqs)
          self.new_token_ratio = new_token_ratio
          for req in reqs_to_abort:
+             abort_reason: FINISH_ABORT = req.to_finish
              self.send_to_tokenizer.send_output(
-                 AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
+                 AbortReq(abort_message=abort_reason.message, rid=req.rid), req
              )

          logger.info(
              "KV cache pool is full. Retract requests. "
              f"#retracted_reqs: {len(retracted_reqs)}, "
-             f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
              f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
          )

@@ -1894,7 +1961,6 @@

          # Run forward
          if self.is_generation:
-
              batch_or_worker_batch = batch

              if self.enable_overlap or self.spec_algorithm.is_none():
@@ -1951,6 +2017,9 @@
                  # The future value, usually for next batch preparation
                  # Current implementation strictly synchronizes the seq_lens
                  batch.seq_lens = batch_result.next_draft_input.new_seq_lens
+             elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
+                 batch_result = self.tp_worker.forward_batch_split_prefill(batch)
+                 future_indices_or_next_token_ids = batch_result.next_token_ids
              else:
                  batch_result = self.model_worker.forward_batch_generation(
                      batch_or_worker_batch
@@ -2012,13 +2081,10 @@
      ):
          if batch.forward_mode.is_decode():
              self.process_batch_result_decode(batch, result)
-             if self.enable_trace:
-                 trace_slice_batch("decode loop", batch.reqs)
+             trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs)

          elif batch.forward_mode.is_extend():
              self.process_batch_result_prefill(batch, result)
-             if self.enable_trace:
-                 trace_slice_batch("prefill", batch.reqs)

          elif batch.forward_mode.is_idle():
              if self.enable_overlap:
@@ -2238,59 +2304,6 @@
              self._add_request_to_queue(req)
          self.grammar_queue = self.grammar_queue[num_ready_reqs:]

-     def watchdog_thread(self):
-         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
-         self.watchdog_last_forward_ct = 0
-         self.watchdog_last_time = time.perf_counter()
-
-         while True:
-             current = time.perf_counter()
-             if self.cur_batch is not None:
-                 if self.watchdog_last_forward_ct == self.forward_ct:
-                     if current > self.watchdog_last_time + self.watchdog_timeout:
-                         break
-                 else:
-                     self.watchdog_last_forward_ct = self.forward_ct
-                     self.watchdog_last_time = current
-             time.sleep(self.watchdog_timeout // 2)
-
-         if not disable_request_logging():
-             # Print batch size and memory pool info to check whether there are de-sync issues.
-             if self.is_hybrid:
-                 (
-                     _,
-                     _,
-                     _,
-                     _,
-                     full_available_size,
-                     full_evictable_size,
-                     swa_available_size,
-                     swa_evictable_size,
-                 ) = self._get_swa_token_info()
-                 info_msg = (
-                     f"{full_available_size=}, "
-                     f"{full_evictable_size=}, "
-                     f"{swa_available_size=}, "
-                     f"{swa_evictable_size=}, "
-                 )
-             else:
-                 _, _, available_size, evictable_size = self._get_token_info()
-                 info_msg = f"{available_size=}, " f"{evictable_size=}, "
-             logger.error(
-                 f"{self.cur_batch.batch_size()=}, "
-                 f"{self.cur_batch.reqs=}, "
-                 f"{info_msg}"
-             )
-
-         pyspy_dump_schedulers()
-         logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
-         print(file=sys.stderr, flush=True)
-         print(file=sys.stdout, flush=True)
-
-         # Wait for some time so that the parent process can print the error.
-         time.sleep(5)
-         self.parent_process.send_signal(signal.SIGQUIT)
-
      def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
          success = self.flush_cache()
          return FlushCacheReqOutput(success=success)
@@ -2305,13 +2318,30 @@
              if_success = False
          return ClearHiCacheReqOutput(success=if_success)

-     def flush_cache(self):
-         """Flush the memory pool and cache."""
-         if (
+     def _is_no_request(self):
+         no_request = (
              len(self.waiting_queue) == 0
              and self.running_batch.is_empty()
+             and (self.last_batch is None or self.last_batch.is_empty())
+             and (self.cur_batch is None or self.cur_batch.is_empty())
+             and (not self.enable_overlap or len(self.result_queue) == 0)
              and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
-         ):
+         )
+         if self.disaggregation_mode == DisaggregationMode.PREFILL:
+             no_request &= (
+                 len(self.disagg_prefill_bootstrap_queue.queue) == 0
+                 and len(self.disagg_prefill_inflight_queue) == 0
+             )
+         if self.disaggregation_mode == DisaggregationMode.DECODE:
+             no_request &= (
+                 len(self.disagg_decode_prealloc_queue.queue) == 0
+                 and len(self.disagg_decode_transfer_queue.queue) == 0
+             )
+         return no_request
+
+     def flush_cache(self):
+         """Flush the memory pool and cache."""
+         if self._is_no_request():
              self.cur_batch = None
              self.last_batch = None
              self.tree_cache.reset()
@@ -2545,11 +2575,11 @@
              if not req.finished() and (
                  recv_req.abort_all or req.rid.startswith(recv_req.rid)
              ):
-                 # Abort method 3: set `to_abort=True`
+                 # Abort method 3: set `to_finish`
                  # The request will still run one decode forward pass.
                  # Then we reuse all existing code to clean up the KV cache allocation.
                  logger.debug(f"Abort running request. {req.rid=}")
-                 req.to_abort = True
+                 req.to_finish = FINISH_ABORT()

      def _pause_engine(self) -> Tuple[List[Req], int]:
          raise NotImplementedError()
@@ -2743,10 +2773,13 @@

      # Set up tracing
      if server_args.enable_trace:
-         process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
-         if server_args.disaggregation_mode == "null":
-             thread_label = "Scheduler"
-             trace_set_thread_info(thread_label, tp_rank, dp_rank)
+         process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+         thread_label = "Scheduler"
+         if server_args.disaggregation_mode == "prefill":
+             thread_label = "Prefill Scheduler"
+         elif server_args.disaggregation_mode == "decode":
+             thread_label = "Decode Scheduler"
+         trace_set_thread_info(thread_label, tp_rank, dp_rank)

      # Create a scheduler and run the event loop
      try:
@@ -2769,7 +2802,9 @@

      disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
      if disaggregation_mode == DisaggregationMode.NULL:
-         if server_args.pp_size > 1:
+         if scheduler.enable_pdmux:
+             scheduler.event_loop_pdmux()
+         elif server_args.pp_size > 1:
              scheduler.event_loop_pp()
          elif scheduler.enable_overlap:
              scheduler.event_loop_overlap()
sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -14,7 +14,13 @@ from sglang.srt.managers.io_struct import (
      BatchEmbeddingOutput,
      BatchTokenIDOutput,
  )
- from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
+ from sglang.srt.managers.schedule_batch import (
+     BaseFinishReason,
+     Req,
+     RequestStage,
+     ScheduleBatch,
+ )
+ from sglang.srt.tracing.trace import trace_slice
  from sglang.srt.utils.common import ceil_div

  if TYPE_CHECKING:
@@ -160,6 +166,14 @@ class SchedulerOutputProcessorMixin:
                              )
                              self.abort_request(AbortReq(rid=req.rid))
                          req.grammar.finished = req.finished()
+
+                     trace_slice(
+                         RequestStage.PREFILL_FORWARD,
+                         req.rid,
+                         auto_next_anon=not req.finished(),
+                         thread_finish_flag=req.finished(),
+                     )
+
                  else:
                      # being chunked reqs' prefill is not finished
                      req.is_chunked -= 1
@@ -188,6 +202,12 @@ class SchedulerOutputProcessorMixin:
                          )
                          logprob_pt += num_input_logprobs

+                 trace_slice(
+                     RequestStage.PREFILL_CHUNKED_FORWARD,
+                     req.rid,
+                     auto_next_anon=True,
+                 )
+
          else:  # embedding or reward model
              is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()

@@ -203,7 +223,10 @@
                          i
                      ].item()
              else:
-                 embeddings = embeddings.tolist()
+                 if isinstance(embeddings, torch.Tensor):
+                     embeddings = embeddings.tolist()
+                 else:
+                     embeddings = [tensor.tolist() for tensor in embeddings]

              # Check finish conditions
              for i, req in enumerate(batch.reqs):
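The new `isinstance` branch above lets the embedding path accept either a single stacked tensor or a list of per-request tensors, which can differ in length (for example when requests specify different output `dimensions`, as added to the embedding request handling earlier in this diff). A small illustration of the two shapes being converted, with made-up sizes:

```python
# Illustration only: the two embedding shapes the tolist() conversion has to handle.
import torch

# Case 1: one stacked tensor -- every request shares the same embedding size.
stacked = torch.randn(2, 4)
stacked_lists = stacked.tolist()  # list of 2 rows, each a length-4 list

# Case 2: a list of per-request tensors whose sizes may differ
# (e.g. requests asking for different embedding dimensions).
ragged = [torch.randn(4), torch.randn(8)]
ragged_lists = [t.tolist() for t in ragged]  # [length-4 list, length-8 list]

print(len(stacked_lists[0]), len(ragged_lists[1]))  # 4 8
```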
@@ -224,6 +247,13 @@
                      # being chunked reqs' prefill is not finished
                      req.is_chunked -= 1

+                 trace_slice(
+                     RequestStage.PREFILL_FORWARD,
+                     req.rid,
+                     auto_next_anon=not req.finished(),
+                     thread_finish_flag=req.finished(),
+                 )
+
          self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

      def _resolve_spec_overlap_token_ids(
@@ -727,6 +757,7 @@
          cached_tokens = []
          spec_verify_ct = []
          spec_accepted_tokens = []
+         retraction_counts = []
          output_hidden_states = None

          if return_logprob:
@@ -758,7 +789,7 @@
                  continue

              # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
-             if self.model_config.is_multimodal_gen and req.to_abort:
+             if self.model_config.is_multimodal_gen and req.to_finish:
                  continue

              if req.finished():
@@ -828,6 +859,8 @@ class SchedulerOutputProcessorMixin:
                  completion_tokens.append(len(output_ids_))
                  cached_tokens.append(req.cached_tokens)

+                 retraction_counts.append(req.retraction_count)
+
                  if not self.spec_algorithm.is_none():
                      spec_verify_ct.append(req.spec_verify_ct)
                      spec_accepted_tokens.append(req.spec_accepted_tokens)
@@ -950,6 +983,7 @@
                      http_worker_ipcs=http_worker_ipcs,
                      placeholder_tokens_idx=None,
                      placeholder_tokens_val=None,
+                     retraction_counts=retraction_counts,
                  )
              )

@@ -961,6 +995,7 @@
          embeddings = []
          prompt_tokens = []
          cached_tokens = []
+         retraction_counts = []
          for req in reqs:
              if req.finished():
                  rids.append(req.rid)
@@ -969,6 +1004,7 @@
                  embeddings.append(req.embedding)
                  prompt_tokens.append(len(req.origin_input_ids))
                  cached_tokens.append(req.cached_tokens)
+                 retraction_counts.append(req.retraction_count)
          self.send_to_detokenizer.send_output(
              BatchEmbeddingOutput(
                  finished_reasons,
@@ -979,5 +1015,6 @@
                  http_worker_ipcs=http_worker_ipcs,
                  placeholder_tokens_idx=None,
                  placeholder_tokens_val=None,
+                 retraction_counts=retraction_counts,
              )
          )
sglang/srt/managers/scheduler_pp_mixin.py

@@ -4,7 +4,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.schedule_batch import ScheduleBatch
  from sglang.srt.managers.utils import GenerationBatchResult
  from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
- from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj
+ from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj, require_mlp_sync


  class SchedulerPPMixin:
@@ -236,7 +236,12 @@
                  tmbs[mb_id] = transferred_rids

          self.process_prefill_chunk()
-         mbs[mb_id] = self.get_new_batch_prefill()
+
+         batch = self.get_new_batch_prefill()
+         if require_mlp_sync(self.server_args):
+             batch = self.prepare_mlp_sync_batch(batch)
+         mbs[mb_id] = batch
+
          self.running_mbs[mb_id] = self.running_batch

          self.cur_batch = mbs[mb_id]