sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff shows the content of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/managers/scheduler.py CHANGED

@@ -29,6 +29,7 @@ from typing import Deque, Dict, List, Optional, Tuple, Union
 import psutil
 import setproctitle
 import torch
+import torch.distributed
 import zmq
 from torch.cuda import Stream as CudaStream
 from torch.cuda import StreamContext as CudaStreamContext
@@ -151,11 +152,13 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.tracing.trace import (
     process_tracing_init,
+    trace_event_batch,
     trace_set_proc_propagate_context,
     trace_set_thread_info,
     trace_slice_batch,
@@ -168,7 +171,6 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_gc_logger,
     configure_logger,
-    disable_request_logging,
     freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -177,7 +179,6 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     numa_bind_to_node,
     point_to_point_pyobj,
-    pyspy_dump_schedulers,
     require_mlp_sync,
     require_mlp_tp_gather,
     set_gpu_proc_affinity,
@@ -197,6 +198,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
 TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
+TEST_RETRACT_NO_PREFILL_BS = envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.get()
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
 
@@ -212,6 +214,7 @@ class Scheduler(
     SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
+    SchedulerMultiplexMixin,
     SchedulerRuntimeCheckerMixin,
     SchedulerPPMixin,
 ):
@@ -251,6 +254,7 @@ class Scheduler(
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
+        self.enable_pdmux = server_args.enable_pdmux
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
         self.enable_metrics_for_all_schedulers = (
@@ -284,6 +288,10 @@ class Scheduler(
         # Init inter-process communication
         self.init_sockets(server_args, port_args)
 
+        # Init pdmux context
+        if self.enable_pdmux:
+            self.init_pdmux()
+
         # Init tokenizer
         self.init_tokenizer()
 
@@ -320,8 +328,28 @@ class Scheduler(
 
         # Launch a draft worker for speculative decoding
 
-        self.launch_draft_worker(
-            gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+        draft_worker_kwargs = dict(
+            gpu_id=gpu_id,
+            tp_rank=tp_rank,
+            moe_ep_rank=moe_ep_rank,
+            server_args=server_args,
+            nccl_port=port_args.nccl_port,
+            target_worker=self.tp_worker,
+            dp_rank=dp_rank,
+        )
+
+        if server_args.speculative_draft_load_format is not None:
+            server_args.load_format = server_args.speculative_draft_load_format
+            logger.info(
+                f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
+            )
+
+        # Draft workers are looked up via `SpeculativeAlgorithm` registry; new
+        # algorithms should register their factory instead of patching this code.
+        if self.spec_algorithm.name in {"EAGLE", "EAGLE3"}:
+            draft_worker_kwargs["enable_overlap"] = self.enable_overlap
+        self.draft_worker = self.spec_algorithm.create_draft_worker(
+            **draft_worker_kwargs
         )
 
         # Dispatch the model worker
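
Note: the `create_draft_worker` call above replaces the hard-coded if/elif chain of `launch_draft_worker` (removed further down). A minimal sketch of such a name-keyed factory registry; the names `_DRAFT_WORKER_FACTORIES` and `register_draft_worker` are illustrative, not sglang's actual internals:

    from typing import Callable, Dict, Optional

    _DRAFT_WORKER_FACTORIES: Dict[str, Callable[..., object]] = {}

    def register_draft_worker(name: str):
        # Decorator that registers a draft-worker factory under an algorithm name.
        def decorator(factory: Callable[..., object]):
            _DRAFT_WORKER_FACTORIES[name] = factory
            return factory
        return decorator

    @register_draft_worker("EAGLE")
    def _make_eagle_worker(**kwargs):
        ...  # construct and return an EAGLE draft worker here

    def create_draft_worker(name: str, **kwargs) -> Optional[object]:
        # Unknown algorithms yield no draft worker instead of raising.
        factory = _DRAFT_WORKER_FACTORIES.get(name)
        return factory(**kwargs) if factory is not None else None

New speculative algorithms then add one `@register_draft_worker(...)` entry rather than editing the scheduler.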
@@ -356,6 +384,17 @@ class Scheduler(
         self.pp_group = get_pp_group()
         self.world_group = get_world_group()
 
+        # With DP attention enabled, the entry rank is attn_tp_rank==0;
+        # otherwise the entry rank is TP group local rank 0.
+        # For #11910, use the CPU communication group to broadcast VLM Python objects,
+        # avoiding any coupling with CUDA streams/devices.
+        if self.server_args.enable_dp_attention:
+            self.cpu_group = self.attn_tp_cpu_group
+            self.is_entry_rank = self.attn_tp_rank == 0
+        else:
+            self.cpu_group = self.tp_cpu_group
+            self.is_entry_rank = self.tp_group.rank == 0
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         set_random_seed(self.random_seed)
 
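
Note: the "CPU communication group" referenced above is a process group whose collectives run on CPU (typically the gloo backend), so Python objects can be exchanged without touching CUDA streams or devices. A minimal standalone sketch, assuming torch.distributed is already initialized; this is not sglang's actual group setup:

    import torch.distributed as dist

    # A gloo group performs collectives on CPU, independent of any CUDA state.
    cpu_group = dist.new_group(
        ranks=list(range(dist.get_world_size())), backend="gloo"
    )
    is_entry_rank = dist.get_rank(group=cpu_group) == 0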
@@ -392,6 +431,8 @@ class Scheduler(
         self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The current split prefill batch
+        self.split_prefill_batch: Optional[ScheduleBatch] = None
         # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
@@ -548,57 +589,6 @@ class Scheduler(
             ]
         )
 
-    def launch_draft_worker(
-        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
-    ):
-        if server_args.speculative_draft_load_format is not None:
-            server_args.load_format = server_args.speculative_draft_load_format
-            logger.info(
-                f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
-            )
-
-        if self.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_worker import EAGLEWorker
-            from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
-
-            WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
-
-            self.draft_worker = WorkerClass(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_standalone():
-            from sglang.srt.speculative.standalone_worker import StandaloneWorker
-
-            self.draft_worker = StandaloneWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_worker import NGRAMWorker
-
-            self.draft_worker = NGRAMWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        else:
-            self.draft_worker = None
-
     def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
         context = zmq.Context(2)
         self.idle_sleeper = None
@@ -1162,6 +1152,70 @@ class Scheduler(
             self.max_req_len - len(req.origin_input_ids) - 1,
         )
 
+    def _process_and_broadcast_mm_inputs(
+        self,
+        raw_mm_inputs: Optional[dict],
+    ):
+        """Materialize MultimodalInputs once on the entry rank and broadcast to others.
+
+        Entry rank:
+        - constructs MultimodalInputs.from_dict(raw_mm_inputs) once
+        - broadcasts to other ranks in self.cpu_group (if world_size > 1)
+
+        Non-entry ranks:
+        - receive the object via broadcast (if world_size > 1)
+        - otherwise (single-rank / no group) fall back to local from_dict
+
+        Returns:
+            MultimodalInputs | None
+        """
+        if raw_mm_inputs is None:
+            return None
+
+        group_world_size = 1
+        try:
+            if (
+                torch.distributed.is_available()
+                and torch.distributed.is_initialized()
+                and self.cpu_group is not None
+            ):
+                group_world_size = torch.distributed.get_world_size(
+                    group=self.cpu_group
+                )
+        except Exception as e:
+            logger.warning(
+                f"Failed to get world size in mm_inputs handling with {e}, fallback to 1."
+            )
+
+        # In case tp size > 1, all the Scheduler TP ranks runs the duplicated computing
+        # process in CPU which occupies the main thread CPU cycle. This computing logic
+        # merely needs to be run on TP0 and be broadcast to other TP ranks.
+        # Since the Scheduler is single-threaded, any large CPU cost will impact
+        # handling of other messages. For example, CPU hits 99.9% can significantly
+        # increase the CUDA kernel launch time.
+        if self.is_entry_rank:
+            # Only the entry rank materializes once from dict.
+            image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+            # Broadcast to other TP ranks (use src=0 within the group).
+            if group_world_size > 1:
+                obj_list = [image_inputs]
+                torch.distributed.broadcast_object_list(
+                    obj_list, src=0, group=self.cpu_group
+                )
+                image_inputs = obj_list[0]
+        else:
+            # Non-entry ranks: receive if group size > 1; otherwise materialize locally.
+            if group_world_size > 1:
+                obj_list = [None]
+                torch.distributed.broadcast_object_list(
+                    obj_list, src=0, group=self.cpu_group
+                )
+                image_inputs = obj_list[0]
+            else:
+                image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+
+        return image_inputs
+
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
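
Note: as a standalone illustration of the pattern above (the entry rank materializes the object once, the others receive a copy), a minimal sketch using only public torch.distributed APIs; `broadcast_from_entry` is a hypothetical helper, and `src=0` mirrors the diff's assumption that the entry rank is the group's first rank:

    import torch.distributed as dist

    def broadcast_from_entry(obj, cpu_group, is_entry_rank):
        # Entry rank sends `obj`; every other rank receives a copy of it.
        if dist.get_world_size(group=cpu_group) == 1:
            return obj
        # broadcast_object_list pickles on the source rank and unpickles on the
        # others, so arbitrary Python objects work over a CPU (gloo) group.
        buf = [obj if is_entry_rank else None]
        dist.broadcast_object_list(buf, src=0, group=cpu_group)
        return buf[0]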
@@ -1243,7 +1297,9 @@ class Scheduler(
 
         # Handle multimodal inputs
         if recv_req.mm_inputs is not None:
-            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
+            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.mm_inputs)
+
+            # The following steps are already fast, execute locally on each rank.
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -1376,7 +1432,7 @@ class Scheduler(
             self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
             req.time_stats.wait_queue_entry_time = time.perf_counter()
-            trace_slice_end(
+            trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True)
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             self._prefetch_kvcache(req)
             self.disagg_prefill_bootstrap_queue.add(
@@ -1466,13 +1522,14 @@ class Scheduler(
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
             priority=recv_req.priority,
+            dimensions=recv_req.dimensions,
             http_worker_ipc=recv_req.http_worker_ipc,
         )
         req.tokenizer = self.tokenizer
 
         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
+            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -1639,6 +1696,10 @@ class Scheduler(
         if need_dp_attn_preparation:
             ret = self.prepare_mlp_sync_batch(ret)
 
+        if ret:
+            attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
+            trace_event_batch("schedule", ret.reqs, attrs=attrs)
+
         return ret
 
     def get_num_allocatable_reqs(self, running_bs):
@@ -1682,6 +1743,12 @@ class Scheduler(
         # Get priority queue
         self.policy.calc_priority(self.waiting_queue)
 
+        if TEST_RETRACT and running_bs > TEST_RETRACT_NO_PREFILL_BS:
+            # If we are testing retraction and the running batch size exceeds
+            # TEST_RETRACT_NO_PREFILL_BS, we skip the prefill to keep the requests
+            # in the waiting queue.
+            return None
+
         # Prefill policy
         adder = PrefillAdder(
             self.page_size,
         @@ -1848,14 +1915,14 @@ class Scheduler( 
     | 
|
| 
       1848 
1915 
     | 
    
         
             
                        self.num_retracted_reqs = len(retracted_reqs)
         
     | 
| 
       1849 
1916 
     | 
    
         
             
                        self.new_token_ratio = new_token_ratio
         
     | 
| 
       1850 
1917 
     | 
    
         
             
                        for req in reqs_to_abort:
         
     | 
| 
      
 1918 
     | 
    
         
            +
                            abort_reason: FINISH_ABORT = req.to_finish
         
     | 
| 
       1851 
1919 
     | 
    
         
             
                            self.send_to_tokenizer.send_output(
         
     | 
| 
       1852 
     | 
    
         
            -
                                AbortReq(abort_reason 
     | 
| 
      
 1920 
     | 
    
         
            +
                                AbortReq(abort_message=abort_reason.message, rid=req.rid), req
         
     | 
| 
       1853 
1921 
     | 
    
         
             
                            )
         
     | 
| 
       1854 
1922 
     | 
    
         | 
| 
       1855 
1923 
     | 
    
         
             
                        logger.info(
         
     | 
| 
       1856 
1924 
     | 
    
         
             
                            "KV cache pool is full. Retract requests. "
         
     | 
| 
       1857 
1925 
     | 
    
         
             
                            f"#retracted_reqs: {len(retracted_reqs)}, "
         
     | 
| 
       1858 
     | 
    
         
            -
                            f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
         
     | 
| 
       1859 
1926 
     | 
    
         
             
                            f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
         
     | 
| 
       1860 
1927 
     | 
    
         
             
                        )
         
     | 
| 
       1861 
1928 
     | 
    
         | 
| 
         @@ -1894,7 +1961,6 @@ class Scheduler( 
     | 
|
| 
       1894 
1961 
     | 
    
         | 
| 
       1895 
1962 
     | 
    
         
             
                    # Run forward
         
     | 
| 
       1896 
1963 
     | 
    
         
             
                    if self.is_generation:
         
     | 
| 
       1897 
     | 
    
         
            -
             
     | 
| 
       1898 
1964 
     | 
    
         
             
                        batch_or_worker_batch = batch
         
     | 
| 
       1899 
1965 
     | 
    
         | 
| 
       1900 
1966 
     | 
    
         
             
                        if self.enable_overlap or self.spec_algorithm.is_none():
         
     | 
| 
         @@ -1951,6 +2017,9 @@ class Scheduler( 
     | 
|
| 
       1951 
2017 
     | 
    
         
             
                                # The future value, usually for next batch preparation
         
     | 
| 
       1952 
2018 
     | 
    
         
             
                                # Current implementation strictly synchronizes the seq_lens
         
     | 
| 
       1953 
2019 
     | 
    
         
             
                                batch.seq_lens = batch_result.next_draft_input.new_seq_lens
         
     | 
| 
      
 2020 
     | 
    
         
            +
                        elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
         
     | 
| 
      
 2021 
     | 
    
         
            +
                            batch_result = self.tp_worker.forward_batch_split_prefill(batch)
         
     | 
| 
      
 2022 
     | 
    
         
            +
                            future_indices_or_next_token_ids = batch_result.next_token_ids
         
     | 
| 
       1954 
2023 
     | 
    
         
             
                        else:
         
     | 
| 
       1955 
2024 
     | 
    
         
             
                            batch_result = self.model_worker.forward_batch_generation(
         
     | 
| 
       1956 
2025 
     | 
    
         
             
                                batch_or_worker_batch
         
     | 
| 
         @@ -2012,13 +2081,10 @@ class Scheduler( 
     | 
|
| 
       2012 
2081 
     | 
    
         
             
                ):
         
     | 
| 
       2013 
2082 
     | 
    
         
             
                    if batch.forward_mode.is_decode():
         
     | 
| 
       2014 
2083 
     | 
    
         
             
                        self.process_batch_result_decode(batch, result)
         
     | 
| 
       2015 
     | 
    
         
            -
                         
     | 
| 
       2016 
     | 
    
         
            -
                            trace_slice_batch("decode loop", batch.reqs)
         
     | 
| 
      
 2084 
     | 
    
         
            +
                        trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs)
         
     | 
| 
       2017 
2085 
     | 
    
         | 
| 
       2018 
2086 
     | 
    
         
             
                    elif batch.forward_mode.is_extend():
         
     | 
| 
       2019 
2087 
     | 
    
         
             
                        self.process_batch_result_prefill(batch, result)
         
     | 
| 
       2020 
     | 
    
         
            -
                        if self.enable_trace:
         
     | 
| 
       2021 
     | 
    
         
            -
                            trace_slice_batch("prefill", batch.reqs)
         
     | 
| 
       2022 
2088 
     | 
    
         | 
| 
       2023 
2089 
     | 
    
         
             
                    elif batch.forward_mode.is_idle():
         
     | 
| 
       2024 
2090 
     | 
    
         
             
                        if self.enable_overlap:
         
     | 
| 
         @@ -2238,59 +2304,6 @@ class Scheduler( 
     | 
|
| 
       2238 
2304 
     | 
    
         
             
                        self._add_request_to_queue(req)
         
     | 
| 
       2239 
2305 
     | 
    
         
             
                    self.grammar_queue = self.grammar_queue[num_ready_reqs:]
         
     | 
| 
       2240 
2306 
     | 
    
         | 
| 
       2241 
     | 
    
         
            -
                def watchdog_thread(self):
         
     | 
| 
       2242 
     | 
    
         
            -
                    """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         
     | 
| 
       2243 
     | 
    
         
            -
                    self.watchdog_last_forward_ct = 0
         
     | 
| 
       2244 
     | 
    
         
            -
                    self.watchdog_last_time = time.perf_counter()
         
     | 
| 
       2245 
     | 
    
         
            -
             
     | 
| 
       2246 
     | 
    
         
            -
                    while True:
         
     | 
| 
-            current = time.perf_counter()
-            if self.cur_batch is not None:
-                if self.watchdog_last_forward_ct == self.forward_ct:
-                    if current > self.watchdog_last_time + self.watchdog_timeout:
-                        break
-                else:
-                    self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = current
-            time.sleep(self.watchdog_timeout // 2)
-
-        if not disable_request_logging():
-            # Print batch size and memory pool info to check whether there are de-sync issues.
-            if self.is_hybrid:
-                (
-                    _,
-                    _,
-                    _,
-                    _,
-                    full_available_size,
-                    full_evictable_size,
-                    swa_available_size,
-                    swa_evictable_size,
-                ) = self._get_swa_token_info()
-                info_msg = (
-                    f"{full_available_size=}, "
-                    f"{full_evictable_size=}, "
-                    f"{swa_available_size=}, "
-                    f"{swa_evictable_size=}, "
-                )
-            else:
-                _, _, available_size, evictable_size = self._get_token_info()
-                info_msg = f"{available_size=}, " f"{evictable_size=}, "
-            logger.error(
-                f"{self.cur_batch.batch_size()=}, "
-                f"{self.cur_batch.reqs=}, "
-                f"{info_msg}"
-            )
-
-        pyspy_dump_schedulers()
-        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
-        print(file=sys.stderr, flush=True)
-        print(file=sys.stdout, flush=True)
-
-        # Wait for some time so that the parent process can print the error.
-        time.sleep(5)
-        self.parent_process.send_signal(signal.SIGQUIT)
-
     def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
         success = self.flush_cache()
         return FlushCacheReqOutput(success=success)
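Note on the hunk above: the removed lines are the body of the scheduler's watchdog loop, which fires only when the forward counter stops advancing while a batch is in flight, then dumps queue/memory-pool state and signals the parent process. Below is a minimal, self-contained sketch of the same heartbeat pattern; the names are illustrative and not the sglang API.

```python
import time

class Watchdog:
    """Minimal sketch of the heartbeat pattern in the removed block.

    The worker bumps `forward_ct` once per forward pass; the watchdog only
    trips if the counter stops moving while a batch is in flight.
    """

    def __init__(self, timeout: float):
        self.timeout = timeout
        self.forward_ct = 0      # bumped by the worker on every step
        self.busy = False        # True while a batch is being processed
        self._last_ct = 0
        self._last_time = time.perf_counter()

    def poll(self) -> bool:
        """One polling step; returns True if the watchdog should fire."""
        now = time.perf_counter()
        if self.busy:
            if self._last_ct == self.forward_ct:
                if now > self._last_time + self.timeout:
                    return True  # stuck: no progress within the timeout
            else:                # progress was made; reset the clock
                self._last_ct = self.forward_ct
                self._last_time = now
        return False

wd = Watchdog(timeout=300.0)
wd.busy = True
print(wd.poll())  # False: the clock was just reset
```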
@@ -2305,13 +2318,30 @@ class Scheduler(
         if_success = False
         return ClearHiCacheReqOutput(success=if_success)
 
-    def flush_cache(self):
-        """Flush the memory pool and cache."""
-        if (
+    def _is_no_request(self):
+        no_request = (
             len(self.waiting_queue) == 0
             and self.running_batch.is_empty()
+            and (self.last_batch is None or self.last_batch.is_empty())
+            and (self.cur_batch is None or self.cur_batch.is_empty())
+            and (not self.enable_overlap or len(self.result_queue) == 0)
             and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
-        ):
+        )
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            no_request &= (
+                len(self.disagg_prefill_bootstrap_queue.queue) == 0
+                and len(self.disagg_prefill_inflight_queue) == 0
+            )
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            no_request &= (
+                len(self.disagg_decode_prealloc_queue.queue) == 0
+                and len(self.disagg_decode_transfer_queue.queue) == 0
+            )
+        return no_request
+
+    def flush_cache(self):
+        """Flush the memory pool and cache."""
+        if self._is_no_request():
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
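Note on the hunk above: the idle check that used to live inline in `flush_cache` is factored into `_is_no_request()` and extended to cover the last/current batch, the overlap-mode result queue, and the prefill/decode disaggregation queues. A toy sketch of the gating idea, with hypothetical queue names standing in for the scheduler's state:

```python
from dataclasses import dataclass, field

# Flushing the memory pool is only safe when every queue that may still
# reference KV-cache entries is empty. Field names are illustrative.
@dataclass
class SchedulerState:
    waiting: list = field(default_factory=list)
    running: list = field(default_factory=list)
    inflight_queues: list = field(default_factory=list)  # e.g. disaggregation queues

def is_no_request(s: SchedulerState) -> bool:
    return (
        len(s.waiting) == 0
        and len(s.running) == 0
        and all(len(q) == 0 for q in s.inflight_queues)
    )

def flush_cache(s: SchedulerState) -> bool:
    if not is_no_request(s):
        return False  # refuse to flush while requests are in flight
    # ... reset tree cache / memory pools here ...
    return True

print(flush_cache(SchedulerState()))                   # True
print(flush_cache(SchedulerState(waiting=["req-1"])))  # False
```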
@@ -2545,11 +2575,11 @@ class Scheduler(
                 if not req.finished() and (
                     recv_req.abort_all or req.rid.startswith(recv_req.rid)
                 ):
-                    # Abort method 3: set `to_abort=True`
+                    # Abort method 3: set `to_finish`
                     # The request will still run one decode forward pass.
                     # Then we reuse all existing code to clean up the KV cache allocation.
                     logger.debug(f"Abort running request. {req.rid=}")
-                    req.to_abort = True
+                    req.to_finish = FINISH_ABORT()
 
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
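Note on the hunk above: aborting a running request now sets `req.to_finish = FINISH_ABORT()` instead of a boolean flag, so the request runs one more decode pass and is cleaned up by the ordinary completion path. A rough sketch of this deferred-abort pattern; the classes below are illustrative, not sglang types:

```python
# The request is not torn down here; it is marked finished-with-abort and
# the regular post-step cleanup path releases its KV cache.
class FinishAbort:
    message = "Aborted"

class Request:
    def __init__(self, rid: str):
        self.rid = rid
        self.to_finish = None  # set instead of destroying the request

    def finished(self) -> bool:
        return self.to_finish is not None

def abort_running(req: Request) -> None:
    # The request will still run one more forward pass; the scheduler's
    # normal finish handling then reclaims its resources.
    req.to_finish = FinishAbort()

req = Request("abc123")
abort_running(req)
print(req.finished())  # True
```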
@@ -2743,10 +2773,13 @@ def run_scheduler_process(
 
     # Set up tracing
     if server_args.enable_trace:
-        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
-        if server_args.disaggregation_mode == "null":
-            thread_label = "Scheduler"
-            trace_set_thread_info(thread_label, tp_rank, dp_rank)
+        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+        thread_label = "Scheduler"
+        if server_args.disaggregation_mode == "prefill":
+            thread_label = "Prefill Scheduler"
+        elif server_args.disaggregation_mode == "decode":
+            thread_label = "Decode Scheduler"
+        trace_set_thread_info(thread_label, tp_rank, dp_rank)
 
     # Create a scheduler and run the event loop
     try:
@@ -2769,7 +2802,9 @@ def run_scheduler_process(
 
         disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
         if disaggregation_mode == DisaggregationMode.NULL:
-            if server_args.pp_size > 1:
+            if scheduler.enable_pdmux:
+                scheduler.event_loop_pdmux()
+            elif server_args.pp_size > 1:
                 scheduler.event_loop_pp()
             elif scheduler.enable_overlap:
                 scheduler.event_loop_overlap()
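Note on the hunk above: in the non-disaggregated (NULL) branch, PD-multiplexing now takes precedence over pipeline parallelism, which in turn takes precedence over the overlap loop. A small sketch of the dispatch order; the `"event_loop_normal"` fallback is an assumption, since the final `else` lies outside this hunk:

```python
# Mirror of the selection logic: first matching condition wins.
def pick_event_loop(enable_pdmux: bool, pp_size: int, enable_overlap: bool) -> str:
    if enable_pdmux:
        return "event_loop_pdmux"
    elif pp_size > 1:
        return "event_loop_pp"
    elif enable_overlap:
        return "event_loop_overlap"
    return "event_loop_normal"  # assumed fallback, not shown in the hunk

print(pick_event_loop(enable_pdmux=False, pp_size=2, enable_overlap=True))
# -> event_loop_pp
```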
@@ -14,7 +14,13 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOutput,
     BatchTokenIDOutput,
 )
-from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    BaseFinishReason,
+    Req,
+    RequestStage,
+    ScheduleBatch,
+)
+from sglang.srt.tracing.trace import trace_slice
 from sglang.srt.utils.common import ceil_div
 
 if TYPE_CHECKING:
@@ -160,6 +166,14 @@ class SchedulerOutputProcessorMixin:
                         )
                         self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()
+
+                trace_slice(
+                    RequestStage.PREFILL_FORWARD,
+                    req.rid,
+                    auto_next_anon=not req.finished(),
+                    thread_finish_flag=req.finished(),
+                )
+
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
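Note on the hunk above: a `trace_slice` call is emitted when a request completes a prefill forward pass; `auto_next_anon` chains an anonymous follow-up slice for requests that are still running, while `thread_finish_flag` closes out tracing for finished ones. The toy tracer below illustrates the slice-chaining idea only; it is not the `sglang.srt.tracing.trace` API:

```python
import time

_open_slices: dict[str, float] = {}

def trace_slice(stage: str, rid: str, auto_next_anon: bool = False) -> None:
    # Close the slice currently open for this request (or start "now" if none).
    start = _open_slices.pop(rid, time.perf_counter())
    print(f"{rid}: {stage} took {time.perf_counter() - start:.6f}s")
    if auto_next_anon:
        # Chain: the next stage's slice begins where this one ended.
        _open_slices[rid] = time.perf_counter()

trace_slice("prefill_forward", "req-1", auto_next_anon=True)
trace_slice("decode_step", "req-1")
```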
@@ -188,6 +202,12 @@ class SchedulerOutputProcessorMixin:
                             )
                         logprob_pt += num_input_logprobs
 
+                trace_slice(
+                    RequestStage.PREFILL_CHUNKED_FORWARD,
+                    req.rid,
+                    auto_next_anon=True,
+                )
+
         else:  # embedding or reward model
             is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
 
@@ -203,7 +223,10 @@ class SchedulerOutputProcessorMixin:
                     i
                 ].item()
             else:
-                embeddings = embeddings.tolist()
+                if isinstance(embeddings, torch.Tensor):
+                    embeddings = embeddings.tolist()
+                else:
+                    embeddings = [tensor.tolist() for tensor in embeddings]
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
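Note on the hunk above: embedding outputs may arrive either as one stacked tensor or as a list of per-request tensors, and both now get normalized to plain Python lists before serialization. A standalone sketch of the same normalization:

```python
import torch

def embeddings_to_lists(embeddings):
    if isinstance(embeddings, torch.Tensor):
        return embeddings.tolist()           # one stacked [batch, dim] tensor
    return [t.tolist() for t in embeddings]  # list of per-request tensors

print(embeddings_to_lists(torch.ones(2, 3)))
print(embeddings_to_lists([torch.zeros(3), torch.ones(3)]))
```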
@@ -224,6 +247,13 @@ class SchedulerOutputProcessorMixin:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
 
+            trace_slice(
+                RequestStage.PREFILL_FORWARD,
+                req.rid,
+                auto_next_anon=not req.finished(),
+                thread_finish_flag=req.finished(),
+            )
+
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
     def _resolve_spec_overlap_token_ids(
@@ -727,6 +757,7 @@ class SchedulerOutputProcessorMixin:
         cached_tokens = []
         spec_verify_ct = []
         spec_accepted_tokens = []
+        retraction_counts = []
         output_hidden_states = None
 
         if return_logprob:
@@ -758,7 +789,7 @@ class SchedulerOutputProcessorMixin:
                 continue
 
             # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
-            if self.model_config.is_multimodal_gen and req.to_abort:
+            if self.model_config.is_multimodal_gen and req.to_finish:
                 continue
 
             if req.finished():
@@ -828,6 +859,8 @@ class SchedulerOutputProcessorMixin:
                 completion_tokens.append(len(output_ids_))
                 cached_tokens.append(req.cached_tokens)
 
+                retraction_counts.append(req.retraction_count)
+
                 if not self.spec_algorithm.is_none():
                     spec_verify_ct.append(req.spec_verify_ct)
                     spec_accepted_tokens.append(req.spec_accepted_tokens)
@@ -950,6 +983,7 @@ class SchedulerOutputProcessorMixin:
                     http_worker_ipcs=http_worker_ipcs,
                     placeholder_tokens_idx=None,
                     placeholder_tokens_val=None,
+                    retraction_counts=retraction_counts,
                 )
             )
 
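Note on the hunks above and below: this release threads a per-request `retraction_count` through the output path, collecting one value per finished request and attaching the column to both `BatchTokenIDOutput` and `BatchEmbeddingOutput`. A toy sketch of the collect-and-attach pattern; the dataclasses here are illustrative:

```python
from dataclasses import dataclass

@dataclass
class Req:
    rid: str
    retraction_count: int

@dataclass
class BatchOutput:
    rids: list
    retraction_counts: list

def collect(reqs: list) -> BatchOutput:
    rids, retraction_counts = [], []
    for req in reqs:
        rids.append(req.rid)
        retraction_counts.append(req.retraction_count)  # parallel column
    return BatchOutput(rids=rids, retraction_counts=retraction_counts)

print(collect([Req("a", 0), Req("b", 2)]))
```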
@@ -961,6 +995,7 @@ class SchedulerOutputProcessorMixin:
         embeddings = []
         prompt_tokens = []
         cached_tokens = []
+        retraction_counts = []
         for req in reqs:
             if req.finished():
                 rids.append(req.rid)
@@ -969,6 +1004,7 @@ class SchedulerOutputProcessorMixin:
                 embeddings.append(req.embedding)
                 prompt_tokens.append(len(req.origin_input_ids))
                 cached_tokens.append(req.cached_tokens)
+                retraction_counts.append(req.retraction_count)
         self.send_to_detokenizer.send_output(
             BatchEmbeddingOutput(
                 finished_reasons,
@@ -979,5 +1015,6 @@ class SchedulerOutputProcessorMixin:
                 http_worker_ipcs=http_worker_ipcs,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
+                retraction_counts=retraction_counts,
             )
         )
@@ -4,7 +4,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.utils import GenerationBatchResult
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
-from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj
+from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj, require_mlp_sync
 
 
 class SchedulerPPMixin:
@@ -236,7 +236,12 @@ class SchedulerPPMixin:
                 tmbs[mb_id] = transferred_rids
 
                 self.process_prefill_chunk()
-                mbs[mb_id] = self.get_new_batch_prefill()
+
+                batch = self.get_new_batch_prefill()
+                if require_mlp_sync(self.server_args):
+                    batch = self.prepare_mlp_sync_batch(batch)
+                mbs[mb_id] = batch
+
                 self.running_mbs[mb_id] = self.running_batch
 
                 self.cur_batch = mbs[mb_id]
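Note on the hunk above: the pipeline-parallel prefill path now routes each freshly formed batch through `prepare_mlp_sync_batch` whenever `require_mlp_sync(self.server_args)` holds, before storing it in the microbatch slot. A minimal sketch of the gating, with stand-in callables for the sglang helpers:

```python
from typing import Callable, Optional

def schedule_prefill(
    get_new_batch: Callable[[], Optional[list]],
    needs_mlp_sync: bool,
    prepare: Callable[[Optional[list]], Optional[list]],
) -> Optional[list]:
    batch = get_new_batch()
    if needs_mlp_sync:
        # Even an empty/None batch is exchanged so all ranks agree per step.
        batch = prepare(batch)
    return batch

print(schedule_prefill(lambda: ["req-1"], True, lambda b: b or []))  # ['req-1']
```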