sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
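The diff hunks shown below come from sglang/srt/managers/scheduler.py (item 104 in the list above, +185 -144); the final group of hunks comes from sglang/srt/managers/scheduler_metrics_mixin.py (item 105, +22 -14) and is marked where it begins.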
@@ -29,6 +29,7 @@ from typing import Deque, Dict, List, Optional, Tuple, Union
 import psutil
 import setproctitle
 import torch
+import torch.distributed
 import zmq
 from torch.cuda import Stream as CudaStream
 from torch.cuda import StreamContext as CudaStreamContext
@@ -151,11 +152,13 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.tracing.trace import (
     process_tracing_init,
+    trace_event_batch,
     trace_set_proc_propagate_context,
     trace_set_thread_info,
     trace_slice_batch,
@@ -168,7 +171,6 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_gc_logger,
     configure_logger,
-    disable_request_logging,
     freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -177,7 +179,6 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     numa_bind_to_node,
     point_to_point_pyobj,
-    pyspy_dump_schedulers,
     require_mlp_sync,
     require_mlp_tp_gather,
     set_gpu_proc_affinity,
@@ -197,6 +198,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
 TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
+TEST_RETRACT_NO_PREFILL_BS = envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.get()
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
 
@@ -212,6 +214,7 @@ class Scheduler(
     SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
+    SchedulerMultiplexMixin,
     SchedulerRuntimeCheckerMixin,
     SchedulerPPMixin,
 ):
@@ -251,6 +254,7 @@ class Scheduler(
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
+        self.enable_pdmux = server_args.enable_pdmux
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
         self.enable_metrics_for_all_schedulers = (
@@ -284,6 +288,10 @@ class Scheduler(
         # Init inter-process communication
         self.init_sockets(server_args, port_args)
 
+        # Init pdmux context
+        if self.enable_pdmux:
+            self.init_pdmux()
+
         # Init tokenizer
         self.init_tokenizer()
 
@@ -320,8 +328,28 @@ class Scheduler(
 
         # Launch a draft worker for speculative decoding
 
-        self.launch_draft_worker(
-            gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+        draft_worker_kwargs = dict(
+            gpu_id=gpu_id,
+            tp_rank=tp_rank,
+            moe_ep_rank=moe_ep_rank,
+            server_args=server_args,
+            nccl_port=port_args.nccl_port,
+            target_worker=self.tp_worker,
+            dp_rank=dp_rank,
+        )
+
+        if server_args.speculative_draft_load_format is not None:
+            server_args.load_format = server_args.speculative_draft_load_format
+            logger.info(
+                f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
+            )
+
+        # Draft workers are looked up via `SpeculativeAlgorithm` registry; new
+        # algorithms should register their factory instead of patching this code.
+        if self.spec_algorithm.name in {"EAGLE", "EAGLE3"}:
+            draft_worker_kwargs["enable_overlap"] = self.enable_overlap
+        self.draft_worker = self.spec_algorithm.create_draft_worker(
+            **draft_worker_kwargs
         )
 
         # Dispatch the model worker
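The per-algorithm if/elif ladder is replaced above by a single `self.spec_algorithm.create_draft_worker(**draft_worker_kwargs)` call. As a rough illustration of the registry-style factory this implies (the names below are hypothetical, not sglang's actual API), the dispatch pattern can be sketched as:

    # Hypothetical sketch of a name-to-factory registry for draft workers; this is
    # only the general pattern behind `SpeculativeAlgorithm.create_draft_worker`,
    # not the sglang implementation.
    from typing import Any, Callable, Dict

    _DRAFT_WORKER_FACTORIES: Dict[str, Callable[..., Any]] = {}

    def register_draft_worker(name: str):
        def decorator(factory: Callable[..., Any]) -> Callable[..., Any]:
            _DRAFT_WORKER_FACTORIES[name] = factory
            return factory
        return decorator

    @register_draft_worker("NGRAM")
    def _make_ngram_worker(**kwargs: Any) -> Any:
        # A real factory would build and return the corresponding worker object here.
        return {"worker": "ngram", **kwargs}

    def create_draft_worker(name: str, **kwargs: Any) -> Any:
        factory = _DRAFT_WORKER_FACTORIES.get(name)
        return factory(**kwargs) if factory is not None else None

Registering one factory per algorithm keeps the scheduler free of speculative-decoding imports, which the deleted `launch_draft_worker` method (removed in a later hunk) did by hand.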
@@ -356,6 +384,17 @@ class Scheduler(
         self.pp_group = get_pp_group()
         self.world_group = get_world_group()
 
+        # With DP attention enabled, the entry rank is attn_tp_rank==0;
+        # otherwise the entry rank is TP group local rank 0.
+        # For #11910, use the CPU communication group to broadcast VLM Python objects,
+        # avoiding any coupling with CUDA streams/devices.
+        if self.server_args.enable_dp_attention:
+            self.cpu_group = self.attn_tp_cpu_group
+            self.is_entry_rank = self.attn_tp_rank == 0
+        else:
+            self.cpu_group = self.tp_cpu_group
+            self.is_entry_rank = self.tp_group.rank == 0
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         set_random_seed(self.random_seed)
 
@@ -392,6 +431,8 @@ class Scheduler(
         self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The current split prefill batch
+        self.split_prefill_batch: Optional[ScheduleBatch] = None
         # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
@@ -494,7 +535,7 @@ class Scheduler(
         )
         self.init_disaggregation()
 
-        if get_bool_env_var("SGLANG_GC_LOG"):
+        if envs.SGLANG_LOG_GC.get():
             configure_gc_logger()
 
         # Init prefill kv split size when deterministic inference is enabled with various attention backends
@@ -548,57 +589,6 @@ class Scheduler(
             ]
         )
 
-    def launch_draft_worker(
-        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
-    ):
-        if server_args.speculative_draft_load_format is not None:
-            server_args.load_format = server_args.speculative_draft_load_format
-            logger.info(
-                f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
-            )
-
-        if self.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_worker import EAGLEWorker
-            from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
-
-            WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
-
-            self.draft_worker = WorkerClass(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_standalone():
-            from sglang.srt.speculative.standalone_worker import StandaloneWorker
-
-            self.draft_worker = StandaloneWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_worker import NGRAMWorker
-
-            self.draft_worker = NGRAMWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        else:
-            self.draft_worker = None
-
     def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
         context = zmq.Context(2)
         self.idle_sleeper = None
@@ -1162,6 +1152,70 @@ class Scheduler(
             self.max_req_len - len(req.origin_input_ids) - 1,
         )
 
+    def _process_and_broadcast_mm_inputs(
+        self,
+        raw_mm_inputs: Optional[dict],
+    ):
+        """Materialize MultimodalInputs once on the entry rank and broadcast to others.
+
+        Entry rank:
+          - constructs MultimodalInputs.from_dict(raw_mm_inputs) once
+          - broadcasts to other ranks in self.cpu_group (if world_size > 1)
+
+        Non-entry ranks:
+          - receive the object via broadcast (if world_size > 1)
+          - otherwise (single-rank / no group) fall back to local from_dict
+
+        Returns:
+            MultimodalInputs | None
+        """
+        if raw_mm_inputs is None:
+            return None
+
+        group_world_size = 1
+        try:
+            if (
+                torch.distributed.is_available()
+                and torch.distributed.is_initialized()
+                and self.cpu_group is not None
+            ):
+                group_world_size = torch.distributed.get_world_size(
+                    group=self.cpu_group
+                )
+        except Exception as e:
+            logger.warning(
+                f"Failed to get world size in mm_inputs handling with {e}, fallback to 1."
+            )
+
+        # In case tp size > 1, all the Scheduler TP ranks runs the duplicated computing
+        # process in CPU which occupies the main thread CPU cycle. This computing logic
+        # merely needs to be run on TP0 and be broadcast to other TP ranks.
+        # Since the Scheduler is single-threaded, any large CPU cost will impact
+        # handling of other messages. For example, CPU hits 99.9% can significantly
+        # increase the CUDA kernel launch time.
+        if self.is_entry_rank:
+            # Only the entry rank materializes once from dict.
+            image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+            # Broadcast to other TP ranks (use src=0 within the group).
+            if group_world_size > 1:
+                obj_list = [image_inputs]
+                torch.distributed.broadcast_object_list(
+                    obj_list, src=0, group=self.cpu_group
+                )
+                image_inputs = obj_list[0]
+        else:
+            # Non-entry ranks: receive if group size > 1; otherwise materialize locally.
+            if group_world_size > 1:
+                obj_list = [None]
+                torch.distributed.broadcast_object_list(
+                    obj_list, src=0, group=self.cpu_group
+                )
+                image_inputs = obj_list[0]
+            else:
+                image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+
+        return image_inputs
+
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
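The `_process_and_broadcast_mm_inputs` helper added above relies on `torch.distributed.broadcast_object_list` over a CPU (typically gloo) process group, so only the entry rank pays the cost of building `MultimodalInputs` from its dict form. A stripped-down, standalone sketch of that broadcast idiom (not sglang code; the file name is arbitrary and it would be launched with something like `torchrun --nproc_per_node=2 sketch.py`):

    # Minimal object-broadcast pattern over a gloo (CPU) group, assuming torchrun
    # provides RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT in the environment.
    import torch.distributed as dist

    def main() -> None:
        dist.init_process_group(backend="gloo")  # CPU-only group, no CUDA coupling
        rank = dist.get_rank()

        if rank == 0:
            # The entry rank materializes the (possibly expensive) Python object once.
            obj_list = [{"mm_items": list(range(5))}]
        else:
            obj_list = [None]

        # Every rank calls the collective; afterwards obj_list[0] holds rank 0's
        # object on all ranks (it is pickled and shipped through the group).
        dist.broadcast_object_list(obj_list, src=0)
        print(f"rank {rank} got {obj_list[0]}")

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()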
@@ -1243,7 +1297,9 @@
 
         # Handle multimodal inputs
         if recv_req.mm_inputs is not None:
-            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
+            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.mm_inputs)
+
+            # The following steps are already fast, execute locally on each rank.
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -1376,7 +1432,7 @@ class Scheduler(
             self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
             req.time_stats.wait_queue_entry_time = time.perf_counter()
-            trace_slice_end("process req", req.rid, auto_next_anon=True)
+            trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True)
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             self._prefetch_kvcache(req)
             self.disagg_prefill_bootstrap_queue.add(
@@ -1466,13 +1522,14 @@ class Scheduler(
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
             priority=recv_req.priority,
+            dimensions=recv_req.dimensions,
             http_worker_ipc=recv_req.http_worker_ipc,
         )
         req.tokenizer = self.tokenizer
 
         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
+            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -1639,6 +1696,10 @@ class Scheduler(
         if need_dp_attn_preparation:
             ret = self.prepare_mlp_sync_batch(ret)
 
+        if ret:
+            attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
+            trace_event_batch("schedule", ret.reqs, attrs=attrs)
+
         return ret
 
     def get_num_allocatable_reqs(self, running_bs):
@@ -1682,6 +1743,12 @@ class Scheduler(
         # Get priority queue
         self.policy.calc_priority(self.waiting_queue)
 
+        if TEST_RETRACT and running_bs > TEST_RETRACT_NO_PREFILL_BS:
+            # If we are testing retraction and the running batch size exceeds
+            # TEST_RETRACT_NO_PREFILL_BS, we skip the prefill to keep the requests
+            # in the waiting queue.
+            return None
+
         # Prefill policy
         adder = PrefillAdder(
             self.page_size,
@@ -1848,14 +1915,14 @@ class Scheduler(
             self.num_retracted_reqs = len(retracted_reqs)
             self.new_token_ratio = new_token_ratio
             for req in reqs_to_abort:
+                abort_reason: FINISH_ABORT = req.to_finish
                 self.send_to_tokenizer.send_output(
-                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
+                    AbortReq(abort_message=abort_reason.message, rid=req.rid), req
                 )
 
             logger.info(
                 "KV cache pool is full. Retract requests. "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
-                f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
             )
 
@@ -1894,7 +1961,6 @@ class Scheduler(
 
         # Run forward
         if self.is_generation:
-
             batch_or_worker_batch = batch
 
             if self.enable_overlap or self.spec_algorithm.is_none():
@@ -1951,6 +2017,9 @@ class Scheduler(
                 # The future value, usually for next batch preparation
                 # Current implementation strictly synchronizes the seq_lens
                 batch.seq_lens = batch_result.next_draft_input.new_seq_lens
+            elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
+                batch_result = self.tp_worker.forward_batch_split_prefill(batch)
+                future_indices_or_next_token_ids = batch_result.next_token_ids
             else:
                 batch_result = self.model_worker.forward_batch_generation(
                     batch_or_worker_batch
@@ -2012,13 +2081,10 @@ class Scheduler(
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
-            if self.enable_trace:
-                trace_slice_batch("decode loop", batch.reqs)
+            trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs)
 
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
-            if self.enable_trace:
-                trace_slice_batch("prefill", batch.reqs)
 
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
@@ -2073,15 +2139,18 @@ class Scheduler(
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-            num_tokens_for_logprob = sum(
-                [
+            if local_batch.return_logprob:
+                num_tokens_for_logprob = sum(
                     # We should have at least 1 token for sample in every case.
                     max(extend_len - logprob_start_len, 1)
                     for logprob_start_len, extend_len in zip(
-                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                        local_batch.extend_logprob_start_lens,
+                        local_batch.extend_lens,
                     )
-                ]
-            )
+                )
+            else:
+                # When return_logprob = False, only need last token per request
+                num_tokens_for_logprob = local_batch.batch_size()
 
         if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
             can_cuda_graph = 1
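A small worked example of the logprob token accounting in the hunk above, with illustrative values (not taken from the diff):

    # Two requests in an extend batch, with 5 and 8 new tokens respectively.
    extend_lens = [5, 8]
    extend_logprob_start_lens = [3, 8]

    # return_logprob=True: keep at least one token per request for sampling.
    num_tokens_for_logprob = sum(
        max(extend_len - start_len, 1)
        for start_len, extend_len in zip(extend_logprob_start_lens, extend_lens)
    )
    assert num_tokens_for_logprob == (5 - 3) + 1  # 2 + 1 = 3

    # return_logprob=False (the new fast path): only the last token per request.
    num_tokens_for_logprob = len(extend_lens)  # batch_size() == 2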
@@ -2235,59 +2304,6 @@ class Scheduler(
             self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
 
-    def watchdog_thread(self):
-        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
-        self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.perf_counter()
-
-        while True:
-            current = time.perf_counter()
-            if self.cur_batch is not None:
-                if self.watchdog_last_forward_ct == self.forward_ct:
-                    if current > self.watchdog_last_time + self.watchdog_timeout:
-                        break
-                else:
-                    self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = current
-            time.sleep(self.watchdog_timeout // 2)
-
-        if not disable_request_logging():
-            # Print batch size and memory pool info to check whether there are de-sync issues.
-            if self.is_hybrid:
-                (
-                    _,
-                    _,
-                    _,
-                    _,
-                    full_available_size,
-                    full_evictable_size,
-                    swa_available_size,
-                    swa_evictable_size,
-                ) = self._get_swa_token_info()
-                info_msg = (
-                    f"{full_available_size=}, "
-                    f"{full_evictable_size=}, "
-                    f"{swa_available_size=}, "
-                    f"{swa_evictable_size=}, "
-                )
-            else:
-                _, _, available_size, evictable_size = self._get_token_info()
-                info_msg = f"{available_size=}, " f"{evictable_size=}, "
-            logger.error(
-                f"{self.cur_batch.batch_size()=}, "
-                f"{self.cur_batch.reqs=}, "
-                f"{info_msg}"
-            )
-
-        pyspy_dump_schedulers()
-        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
-        print(file=sys.stderr, flush=True)
-        print(file=sys.stdout, flush=True)
-
-        # Wait for some time so that the parent process can print the error.
-        time.sleep(5)
-        self.parent_process.send_signal(signal.SIGQUIT)
-
     def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
         success = self.flush_cache()
         return FlushCacheReqOutput(success=success)
@@ -2302,13 +2318,30 @@ class Scheduler(
             if_success = False
         return ClearHiCacheReqOutput(success=if_success)
 
-    def flush_cache(self):
-        """Flush the memory pool and cache."""
-        if (
+    def _is_no_request(self):
+        no_request = (
             len(self.waiting_queue) == 0
             and self.running_batch.is_empty()
+            and (self.last_batch is None or self.last_batch.is_empty())
+            and (self.cur_batch is None or self.cur_batch.is_empty())
+            and (not self.enable_overlap or len(self.result_queue) == 0)
             and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
-        ):
+        )
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            no_request &= (
+                len(self.disagg_prefill_bootstrap_queue.queue) == 0
+                and len(self.disagg_prefill_inflight_queue) == 0
+            )
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            no_request &= (
+                len(self.disagg_decode_prealloc_queue.queue) == 0
+                and len(self.disagg_decode_transfer_queue.queue) == 0
+            )
+        return no_request
+
+    def flush_cache(self):
+        """Flush the memory pool and cache."""
+        if self._is_no_request():
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -2322,10 +2355,10 @@ class Scheduler(
 
             self.num_generated_tokens = 0
             self.forward_ct_decode = 0
-            self.spec_num_total_accepted_tokens = 0
-            self.spec_num_total_forward_ct = 0
-            self.cum_spec_accept_length = 0
-            self.cum_spec_accept_count = 0
+            self.spec_num_accepted_tokens = 0
+            self.spec_num_forward_ct = 0
+            self.spec_total_num_accepted_tokens = 0
+            self.spec_total_num_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
@@ -2398,13 +2431,16 @@ class Scheduler(
             self.tp_worker.model_runner.graph_mem_usage, 2
         )
 
-        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+        if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
            ret["avg_spec_accept_length"] = (
-                self.cum_spec_accept_length / self.cum_spec_accept_count
+                self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
            )
        if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict
 
+        # This field is not serializable.
+        ret.pop("model_config", None)
+
        return GetInternalStateReqOutput(internal_state=ret)
 
    def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2431,12 +2467,12 @@ class Scheduler(
                 if_success = False
                 break
         if if_success:
-            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+            if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
                 avg_spec_accept_length = (
-                    self.cum_spec_accept_length / self.cum_spec_accept_count
+                    self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
                 )
                 logger.info(f"{avg_spec_accept_length=}")
-                self.cum_spec_accept_length = self.cum_spec_accept_count = 0
+                self.spec_total_num_accepted_tokens = self.spec_total_num_forward_ct = 0
             for k, v in server_args_dict.items():
                 setattr(get_global_server_args(), k, v)
             logger.info(f"Global server args updated! {get_global_server_args()=}")
@@ -2539,11 +2575,11 @@ class Scheduler(
             if not req.finished() and (
                 recv_req.abort_all or req.rid.startswith(recv_req.rid)
             ):
-                # Abort method 3: set `to_abort=True`
+                # Abort method 3: set `to_finish`
                 # The request will still run one decode forward pass.
                 # Then we reuse all existing code to clean up the KV cache allocation.
                 logger.debug(f"Abort running request. {req.rid=}")
-                req.to_abort = True
+                req.to_finish = FINISH_ABORT()
 
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
@@ -2737,10 +2773,13 @@ def run_scheduler_process(
 
     # Set up tracing
     if server_args.enable_trace:
-        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
-        if server_args.disaggregation_mode == "null":
-            thread_label = "Scheduler"
-            trace_set_thread_info(thread_label, tp_rank, dp_rank)
+        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+        thread_label = "Scheduler"
+        if server_args.disaggregation_mode == "prefill":
+            thread_label = "Prefill Scheduler"
+        elif server_args.disaggregation_mode == "decode":
+            thread_label = "Decode Scheduler"
+        trace_set_thread_info(thread_label, tp_rank, dp_rank)
 
     # Create a scheduler and run the event loop
     try:
@@ -2763,7 +2802,9 @@ def run_scheduler_process(
 
         disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
         if disaggregation_mode == DisaggregationMode.NULL:
-            if server_args.pp_size > 1:
+            if scheduler.enable_pdmux:
+                scheduler.event_loop_pdmux()
+            elif server_args.pp_size > 1:
                 scheduler.event_loop_pp()
             elif scheduler.enable_overlap:
                 scheduler.event_loop_overlap()
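The remaining hunks come from sglang/srt/managers/scheduler_metrics_mixin.py (item 105 in the list above, +22 -14).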
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional
 
 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.environ import envs
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -18,6 +19,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+LOG_FORWARD_ITERS = envs.SGLANG_LOG_FORWARD_ITERS.get()
 
 
 class KvMetrics:
@@ -39,10 +41,13 @@ class SchedulerMetricsMixin:
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
+
+        # The number of accepted tokens and forward ct for the recent `decode_log_interval` batches (for logging)
+        self.spec_num_accepted_tokens = 0
+        self.spec_num_forward_ct = 0
+        # The total number of accepted tokens and forward ct for the whole server lifetime
+        self.spec_total_num_accepted_tokens = 0
+        self.spec_total_num_forward_ct = 0
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0
 
@@ -67,8 +72,8 @@
         )
 
     def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int):
-        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-        self.spec_num_total_forward_ct += bs
+        self.spec_num_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_forward_ct += bs
         self.num_generated_tokens += num_accepted_tokens
 
     def log_prefill_stats(
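The renaming above separates interval counters (`spec_num_accepted_tokens` / `spec_num_forward_ct`, reset after every decode log) from lifetime totals (`spec_total_num_accepted_tokens` / `spec_total_num_forward_ct`, read by `get_internal_state`). A minimal standalone sketch of that two-tier counter pattern (hypothetical class, not sglang code):

    # Hypothetical sketch: the interval pair feeds the periodic decode log and is
    # reset there; the lifetime pair only ever accumulates.
    class SpecCounters:
        def __init__(self) -> None:
            self.interval_accepted = 0
            self.interval_forward = 0
            self.total_accepted = 0
            self.total_forward = 0

        def update(self, bs: int, num_accepted_tokens: int) -> None:
            # Mirrors update_spec_metrics: bs is added on top of the accepted tokens.
            self.interval_accepted += num_accepted_tokens + bs
            self.interval_forward += bs

        def log_and_reset(self) -> float:
            accept_len = self.interval_accepted / max(self.interval_forward, 1)
            self.total_accepted += self.interval_accepted
            self.total_forward += self.interval_forward
            self.interval_accepted = self.interval_forward = 0
            return accept_len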
@@ -122,8 +127,10 @@
             num_used, token_usage, _, _ = self._get_token_info()
             token_usage_msg = f"token usage: {token_usage:.2f}, "
 
+        iter_msg = f" [{self.forward_ct + 1}]" if LOG_FORWARD_ITERS else ""
+
         f = (
-            f"Prefill batch [{self.forward_ct + 1}], "
+            f"Prefill batch{iter_msg}, "
             f"#new-seq: {len(can_run_list)}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
@@ -246,27 +253,28 @@
             gap_latency / self.server_args.decode_log_interval
         )
 
-        msg = f"Decode batch [{self.forward_ct}], #running-req: {num_running_reqs}, {token_usage_msg}"
+        iter_msg = f" [{self.forward_ct}]" if LOG_FORWARD_ITERS else ""
+        msg = f"Decode batch{iter_msg}, #running-req: {num_running_reqs}, {token_usage_msg}"
 
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
             spec_accept_rate = 0
         else:
             spec_accept_length = (
-                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+                self.spec_num_accepted_tokens / self.spec_num_forward_ct
             )
             # Calculate acceptance rate: accepted tokens / total draft tokens
-            total_draft_tokens = self.spec_num_total_forward_ct * (
+            total_draft_tokens = self.spec_num_forward_ct * (
                 (self.server_args.speculative_num_steps or 0) + 1
             )
             spec_accept_rate = (
-                self.spec_num_total_accepted_tokens / total_draft_tokens
+                self.spec_num_accepted_tokens / total_draft_tokens
                 if total_draft_tokens > 0
                 else 0
             )
-            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
-            self.cum_spec_accept_count += self.spec_num_total_forward_ct
-            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens
+            self.spec_total_num_forward_ct += self.spec_num_forward_ct
+            self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0
         msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "
         cache_hit_rate = 0.0
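For the decode-log statistics above, a quick numeric illustration (values invented for the example): suppose one logging interval covered speculative batches whose sizes summed to 4 forward counts, with 10 extra tokens accepted and speculative_num_steps = 3:

    spec_num_forward_ct = 4
    spec_num_accepted_tokens = 10 + 4   # update_spec_metrics adds bs per call
    speculative_num_steps = 3

    spec_accept_length = spec_num_accepted_tokens / spec_num_forward_ct       # 3.5
    total_draft_tokens = spec_num_forward_ct * (speculative_num_steps + 1)    # 16
    spec_accept_rate = spec_num_accepted_tokens / total_draft_tokens          # 0.875

After logging, the interval counters fold into the `spec_total_num_*` totals and reset to zero, so the `avg_spec_accept_length` reported by `get_internal_state` remains a lifetime average.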