sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +73 -14
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +30 -7
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -12
 - sglang/srt/entrypoints/engine.py +31 -20
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +94 -94
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +23 -2
 - sglang/srt/eplb/expert_distribution.py +64 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +19 -8
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +24 -1
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +35 -6
 - sglang/srt/layers/logits_processor.py +9 -20
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +78 -289
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +35 -10
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +13 -84
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +130 -46
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +29 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +185 -144
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +165 -78
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +253 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +55 -14
 - sglang/srt/model_executor/model_runner.py +77 -170
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +296 -78
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +29 -197
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +23 -2
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +35 -5
 - sglang/srt/models/qwen3_moe.py +18 -12
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +459 -199
 - sglang/srt/single_batch_overlap.py +2 -4
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +142 -74
 - sglang/srt/utils/hf_transformers_utils.py +38 -12
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
 - sglang/srt/models/vila.py +0 -306
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
sglang/srt/configs/model_config.py CHANGED

@@ -205,6 +205,14 @@ class ModelConfig:
             self.hf_config, "image_token_id", None
         ) or getattr(self.hf_config, "image_token_index", None)

+        # matryoshka embeddings
+        self.matryoshka_dimensions = getattr(
+            self.hf_config, "matryoshka_dimensions", None
+        )
+        self.is_matryoshka = self.matryoshka_dimensions or getattr(
+            self.hf_config, "is_matryoshka", False
+        )
+
     @staticmethod
     def from_server_args(
         server_args: ServerArgs,
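
The new matryoshka_dimensions / is_matryoshka fields back Matryoshka-style embeddings, where a full embedding vector can be truncated to one of several smaller dimensions and re-normalized. A minimal sketch of that truncation step (illustrative only; the helper name and sizes below are not sglang code):

    import torch

    def truncate_matryoshka_embedding(emb: torch.Tensor, dim: int) -> torch.Tensor:
        # Keep the leading `dim` components and re-normalize to unit length.
        truncated = emb[..., :dim]
        return truncated / truncated.norm(dim=-1, keepdim=True)

    full = torch.randn(1, 1024)                        # full-size embedding from the model
    small = truncate_matryoshka_embedding(full, 256)   # e.g. one entry of matryoshka_dimensions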

@@ -358,6 +366,13 @@ class ModelConfig:
             self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
             self.v_head_dim = self.hf_text_config.v_head_dim
             self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
+        elif "KimiLinearForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 72
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_config.v_head_dim
+            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
         else:
             if (
                 "MistralModel" in self.hf_config.architectures

@@ -535,7 +550,7 @@ class ModelConfig:
                 quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
         return quant_cfg

-    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> Optional[dict]:
         """Parse ModelOpt quantization config and return the appropriate quant_method."""
         json_quant_configs = quant_config_dict["quantization"]
         quant_algo = json_quant_configs.get("quant_algo", None)

@@ -547,8 +562,7 @@ class ModelConfig:
         elif quant_algo and "FP8" in quant_algo:
             return {"quant_method": "modelopt_fp8"}
         else:
-            …
-            return {"quant_method": "modelopt_fp8"}
+            return None

     def _is_already_quantized(self) -> bool:
         """Check if the model is already quantized based on config files."""

@@ -583,14 +597,20 @@ class ModelConfig:
             return

         # Check if ModelOpt quantization is specified
-        …
+        _MODELOPT_QUANTIZATION_METHODS = [
             "modelopt",
             "modelopt_fp8",
             "modelopt_fp4",
         ]
+        modelopt_quantization_specified = (
+            self.quantization in _MODELOPT_QUANTIZATION_METHODS
+        )

         if not modelopt_quantization_specified:
-            raise ValueError(…
+            raise ValueError(
+                "quantize_and_serve requires ModelOpt quantization (set with --quantization "
+                f"{{{', '.join(sorted(_MODELOPT_QUANTIZATION_METHODS))}}})"
+            )

         # quantize_and_serve is disabled due to compatibility issues
         raise NotImplementedError(
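
The doubled braces in the added f-string are literal braces, so the error raised above reads: "quantize_and_serve requires ModelOpt quantization (set with --quantization {modelopt, modelopt_fp4, modelopt_fp8})". A quick check of the brace escaping:

    methods = ["modelopt", "modelopt_fp8", "modelopt_fp4"]
    print(f"{{{', '.join(sorted(methods))}}}")
    # prints: {modelopt, modelopt_fp4, modelopt_fp8}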

@@ -614,6 +634,7 @@ class ModelConfig:
             "petit_nvfp4",
             "quark",
             "mxfp4",
+            "auto-round",
         ]
         optimized_quantization_methods = [
             "fp8",

@@ -635,6 +656,7 @@ class ModelConfig:
             "petit_nvfp4",
         ]
         compatible_quantization_methods = {
+            "modelopt_fp8": ["modelopt"],
             "modelopt_fp4": ["modelopt"],
             "petit_nvfp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],

@@ -806,7 +828,7 @@ def _get_and_verify_dtype(
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "…
+    config_dtype = getattr(config, "dtype", None)
     if isinstance(config_dtype, str):
         config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
     if config_dtype is None:
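
The lookup now reads config.dtype (the attribute newer transformers configs expose in place of torch_dtype) and maps string values through _STR_DTYPE_TO_TORCH_DTYPE. A rough sketch of the resolution logic, with an illustrative subset of that table (not sglang's actual mapping or default):

    import torch

    _STR_DTYPE_TO_TORCH_DTYPE = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }

    def resolve_config_dtype(config, default=torch.float16) -> torch.dtype:
        dtype = getattr(config, "dtype", None)  # may be a torch.dtype, a string, or None
        if isinstance(dtype, str):
            dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(dtype)
        return dtype if dtype is not None else default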

@@ -915,12 +937,13 @@ multimodal_model_archs = [
     "InternVLChatModel",
     "InternS1ForConditionalGeneration",
     "Phi4MMForCausalLM",
-    "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
     "POINTSV15ChatModel",
     "DotsVLMForCausalLM",
     "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",
+    "NVILAForConditionalGeneration",
+    "NVILALiteForConditionalGeneration",
     "DeepseekOCRForCausalLM",
 ]
    
sglang/srt/debug_utils/tensor_dump_forward_hook.py ADDED

@@ -0,0 +1,149 @@
+"""
+This file provides a function `register_forward_hook_for_model` that registers a forward hook on every operator of the model.
+After registration, during model inference, all tensors generated throughout the forward pass will be recorded.
+
+Usage:
+Specify the output directory for dumping tensors using the argument `--debug-tensor-dump-output-folder`.
+A separate directory will be created for each GPU rank, named in the format `f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}"`.
+Each complete forward pass of the model generates a `.pt` file named `f"Pass{pass_num}.pt"`, which can be loaded using `torch.load`.
+The file contains a series of key-value pairs, where the keys correspond to operator names in the model
+(similar to those in model.safetensors.index.json), and the values are the outputs produced by the respective operators.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+logger = logging.getLogger(__name__)
+
+
+class TensorDumper:
+    def __init__(
+        self, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
+    ):
+        self._dump_layers = dump_layers
+        self._forward_pass_id = 0
+        self._pid = os.getpid()
+        self._current_tensors = {}
+        self._base_dir = Path(dump_dir)
+        rank = tp_size * pp_rank + tp_rank
+        self._process_dir = (
+            self._base_dir / f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{self._pid}"
+        )
+        self._process_dir.mkdir(parents=True, exist_ok=True)
+
+    def get_dump_dir(self):
+        return str(self._process_dir)
+
+    def add_tensor(self, name, tensor_item):
+        if isinstance(tensor_item, (tuple, list)):
+            tensors = [t.cpu() for t in tensor_item if t is not None]
+            if len(tensors) == 1:
+                self._current_tensors[name] = tensors[0]
+            else:
+                self._current_tensors[name] = tensors
+        elif isinstance(tensor_item, torch.Tensor):
+            self._current_tensors[name] = tensor_item.cpu()
+        elif isinstance(tensor_item, LogitsProcessorOutput):
+            self._current_tensors[name] = tensor_item.next_token_logits.cpu()
+        elif isinstance(tensor_item, ForwardBatch):
+            self._current_tensors[name + ".forward_batch_info.input_ids"] = (
+                tensor_item.input_ids.cpu()
+            )
+            self._current_tensors[name + ".forward_batch_info.seq_lens"] = (
+                tensor_item.seq_lens.cpu()
+            )
+            self._current_tensors[name + ".forward_batch_info.positions"] = (
+                tensor_item.positions.cpu()
+            )
+        elif isinstance(tensor_item, PPProxyTensors):
+            for tensor_name in tensor_item.tensors.keys():
+                self._current_tensors[name + ".pp_proxy_tensors." + tensor_name] = (
+                    tensor_item.tensors[tensor_name].cpu()
+                )
+        else:
+            logger.warning(f"Unsupported type: {type(tensor_item)}: {tensor_item}")
+
+    def dump_current_tensors(self):
+        if len(self._current_tensors) == 0:
+            return
+        tensor_file_for_pass = self._process_dir / f"Pass{self._forward_pass_id:05d}.pt"
+        logger.info(
+            f"Dump {self._forward_pass_id:05d}th pass to {tensor_file_for_pass}"
+        )
+        torch.save(self._current_tensors, str(tensor_file_for_pass))
+        self._current_tensors = {}
+        self._forward_pass_id += 1
+
+    def _add_hook_recursive(
+        self, model, prefix, top_level_module_name, layers_module_name
+    ):
+        model_top_level_module_matched = False
+        layers_prefix = top_level_module_name + "." + layers_module_name
+        for name, module in model._modules.items():
+            top_level_model = False
+            if len(prefix) == 0:
+                cur_name = name
+                if cur_name == top_level_module_name:
+                    model_top_level_module_matched = True
+                    top_level_model = True
+            else:
+                cur_name = prefix + "." + name
+            if self._dump_layers > 0 and name.isdigit() and prefix == layers_prefix:
+                # If we only need n layers, skip the reset layers.
+                # Most models' layout is like model.layers.0.
+                cur_layer = int(name)
+                if cur_layer >= self._dump_layers:
+                    continue
+            if module is not None:
+                _, sub_count = self._add_hook_recursive(
+                    module, cur_name, top_level_module_name, layers_module_name
+                )
+                if sub_count == 0 or top_level_model:
+                    # Avoid duplicated output hooks, e.g. self_attn may contain:
+                    # self_attn.qkv_proj, self_attn.attn & self_attn.o_proj.
+                    # Therefore, we do not need to add output hooks for self_attn,
+                    # since the output of self_attn should be the same to self_attn.o_proj.
+                    module.register_forward_hook(
+                        self._dump_hook(cur_name, top_level_model)
+                    )
+        return model_top_level_module_matched, len(model._modules.items())
+
+    def _dump_hook(self, tensor_name, do_dump):
+        def inner_dump_hook(module, input, output):
+            if do_dump:
+                # This is the top-level model, so we will record the input for it.
+                for item in input:
+                    if isinstance(item, ForwardBatch):
+                        self.add_tensor(tensor_name, item)
+                self.dump_current_tensors()
+            if output is not None:
+                self.add_tensor(tensor_name, output)
+
+        return inner_dump_hook
+
+
+def register_forward_hook_for_model(
+    model, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
+):
+    tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank)
+    # Most models have the layerout like:
+    # XxxxForCausalLM
+    #     (model): XxxxModel
+    #         (layers): ModuleList
+    # If the model is not constructed with this layout,
+    # environment variable can be used to specify the module names.
+    top_level_module_name = os.getenv("TENSOR_DUMP_TOP_LEVEL_MODULE_NAME", "model")
+    layers_module_name = os.getenv("TENSOR_DUMP_LAYERS_MODULE_NAME", "layers")
+    model_top_level_module_matched, _ = tensor_dumper._add_hook_recursive(
+        model, "", top_level_module_name, layers_module_name
+    )
+    assert (
+        model_top_level_module_matched
+    ), f"model should have a module named {top_level_module_name}"
+    return tensor_dumper
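
Per the module docstring, each forward pass is saved as an ordinary torch.save dict keyed by operator name, so a dump can be inspected directly. The folder below is an illustrative value passed to --debug-tensor-dump-output-folder, not a default:

    import torch

    path = "/tmp/sglang_dumps/TP0_PP0_Rank0_pid12345/Pass00000.pt"
    tensors = torch.load(path, map_location="cpu")

    for name, value in tensors.items():
        if isinstance(value, torch.Tensor):
            print(name, tuple(value.shape), value.dtype)
        else:
            # Lists of tensors are stored for operators with multiple outputs.
            print(name, type(value))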

sglang/srt/disaggregation/decode.py CHANGED

@@ -58,6 +58,11 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.tracing.trace import (
+    trace_event_batch,
+    trace_slice_batch,
+    trace_slice_end,
+)
 from sglang.srt.utils import get_int_env_var, require_mlp_sync
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

@@ -313,6 +318,7 @@ class DecodePreallocQueue:
             )

             req.add_latency(RequestStage.DECODE_PREPARE)
+            trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
             self.queue.append(
                 DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
             )

@@ -521,13 +527,15 @@ class DecodePreallocQueue:
             decode_req.kv_receiver.init(
                 page_indices, decode_req.metadata_buffer_index, state_indices
             )
-            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
             decode_req.req.time_stats.decode_transfer_queue_entry_time = (
                 time.perf_counter()
             )
             decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
+            trace_slice_end(
+                RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True
+            )

         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove

@@ -765,8 +773,12 @@ class DecodeTransferQueue:
                 indices_to_remove.add(i)
                 decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()

-                # special handling for …
-                …
+                # special handling for corner cases
+                should_finish = (
+                    decode_req.req.sampling_params.max_new_tokens == 1
+                    or output_id in decode_req.req.eos_token_ids
+                )
+                if should_finish:
                     # finish immediately
                     decode_req.req.time_stats.forward_entry_time = (
                         decode_req.req.time_stats.completion_time

@@ -776,8 +788,19 @@ class DecodeTransferQueue:
                         [decode_req.req], decode_req.req.return_logprob
                     )
                     self.tree_cache.cache_finished_req(decode_req.req)
+                    trace_slice_end(
+                        RequestStage.DECODE_QUICK_FINISH,
+                        decode_req.req.rid,
+                        thread_finish_flag=True,
+                    )
                 else:
                     transferred_reqs.append(decode_req.req)
+                    trace_slice_end(
+                        RequestStage.DECODE_TRANSFERRED,
+                        decode_req.req.rid,
+                        auto_next_anon=True,
+                    )
+
             elif poll in [
                 KVPoll.Bootstrapping,
                 KVPoll.WaitingForInput,

@@ -823,6 +846,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
+                    trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                     if prepare_mlp_sync_flag:
                         self._prepare_idle_batch_and_run(None)
                 else:

@@ -872,6 +896,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
+                    trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                     if prepare_mlp_sync_flag:
                         batch_, batch_result = self._prepare_idle_batch_and_run(
                             None, delay_process=True

@@ -954,6 +979,9 @@ class SchedulerDisaggregationDecodeMixin:
             self.running_batch = self.update_running_batch(self.running_batch)
             ret = self.running_batch if not self.running_batch.is_empty() else None

+        if ret:
+            attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
+            trace_event_batch("schedule", ret.reqs, attrs=attrs)
         return ret

     def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:

@@ -1009,6 +1037,9 @@ class SchedulerDisaggregationDecodeMixin:
         return new_batch

     def process_decode_queue(self: Scheduler):
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            self.decode_offload_manager.check_offload_progress()
+
         # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
         resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
         self.waiting_queue.extend(resumed_reqs)

@@ -1031,6 +1062,3 @@ class SchedulerDisaggregationDecodeMixin:
                 self.disagg_decode_transfer_queue.pop_transferred()
             )  # the requests which kv has arrived
             self.waiting_queue.extend(alloc_reqs)
-
-        if self.server_args.disaggregation_decode_enable_offload_kvcache:
-            self.decode_offload_manager.check_offload_progress()

sglang/srt/disaggregation/nixl/conn.py CHANGED

@@ -231,8 +231,8 @@ class NixlKVManager(CommonKVManager):
             ]
             for k in keys_to_remove:
                 del self.connection_pool[k]
-            if failed_bootstrap_addr in self.…
-                del self.…
+            if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
+                del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
             if failed_bootstrap_addr in self.prefill_dp_size_table:
                 del self.prefill_dp_size_table[failed_bootstrap_addr]
             if failed_bootstrap_addr in self.prefill_pp_size_table:
@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
     NSATokenToKVPool,
     SWAKVPool,
 )
+from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end
 from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
 
 if TYPE_CHECKING:
@@ -198,6 +199,7 @@ class PrefillBootstrapQueue:
         self._process_req(req)
         req.add_latency(RequestStage.PREFILL_PREPARE)
         self.queue.append(req)
+        trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True)
 
     def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
         for req in reqs:
@@ -289,6 +291,10 @@ class PrefillBootstrapQueue:
             req.time_stats.wait_queue_entry_time = time.perf_counter()
             req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
 
+            trace_slice_end(
+                RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True
+            )
+
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
         ]
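The `trace_slice_end(..., auto_next_anon=True)` calls added in these two hunks close the named per-request stage (PREFILL_PREPARE, PREFILL_BOOTSTRAP, ...) and immediately start timing the next, still-unnamed one. A minimal sketch of that chaining idea, using a hypothetical stand-in rather than sglang's actual tracer:

```python
# Sketch only: a simplified stand-in for per-request slice tracing. The real
# trace_slice / trace_slice_end live in sglang.srt.tracing.trace and their
# semantics may differ; this just illustrates auto_next_anon=True.
import time
from typing import Dict, Tuple

_open_slices: Dict[str, Tuple[str, int]] = {}  # rid -> (slice name, start ns)


def trace_slice_end(name: str, rid: str, auto_next_anon: bool = False) -> None:
    # Close the currently open slice for this request, labelling it with `name`.
    _, start_ns = _open_slices.pop(rid, ("<anon>", time.time_ns()))
    print(f"rid={rid} slice={name} dur_ns={time.time_ns() - start_ns}")
    if auto_next_anon:
        # Immediately open an anonymous slice so the next stage is timed from
        # the end of this one; a later trace_slice_end gives it its name.
        _open_slices[rid] = ("<anon>", time.time_ns())


# One request flowing through two prefill stages, as in the hunks above.
_open_slices["req-0"] = ("PREFILL_PREPARE", time.time_ns())
trace_slice_end("PREFILL_PREPARE", "req-0", auto_next_anon=True)
trace_slice_end("PREFILL_BOOTSTRAP", "req-0", auto_next_anon=True)
```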
@@ -316,6 +322,9 @@ class SchedulerDisaggregationPrefillMixin:
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+            if batch:
+                attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
+                trace_event_batch("schedule", batch.reqs, attrs=attrs)
 
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
@@ -348,6 +357,9 @@ class SchedulerDisaggregationPrefillMixin:
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+            if batch:
+                attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
+                trace_event_batch("schedule", batch.reqs, attrs=attrs)
 
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
@@ -423,6 +435,7 @@ class SchedulerDisaggregationPrefillMixin:
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 req.add_latency(RequestStage.PREFILL_FORWARD)
+                trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True)
                 self.disagg_prefill_inflight_queue.append(req)
                 if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
                     req.output_topk_p = batch.spec_info.topk_p[i]
@@ -487,6 +500,9 @@ class SchedulerDisaggregationPrefillMixin:
 
                 if self.enable_overlap:
                     self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
+                trace_slice(
+                    RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True
+                )
 
         self.maybe_send_health_check_signal()
 
@@ -558,6 +574,9 @@ class SchedulerDisaggregationPrefillMixin:
                 req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
                 self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
                 req.metadata_buffer_index = -1
+                trace_slice(
+                    RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True
+                )
 
         self.disagg_prefill_inflight_queue = undone_reqs
 
@@ -569,7 +588,7 @@ class SchedulerDisaggregationPrefillMixin:
         """
         polls = poll_and_all_reduce(
             [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
-            self.tp_worker.
+            self.tp_worker.get_attention_tp_cpu_group(),
        )
 
         transferred_rids: List[str] = []
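This hunk only changes which CPU group `poll_and_all_reduce` runs over (the attention-TP group). One plausible shape for such a helper, shown purely as an assumption since sglang's actual implementation is not part of this diff: each rank turns its local KV-sender poll results into integers and a MIN all-reduce over a gloo group makes every rank agree on the least-advanced status.

```python
# Sketch only: an illustrative analogue of poll_and_all_reduce, not sglang's code.
# Must be invoked by every rank of an already-initialized gloo (CPU) group.
from typing import List

import torch
import torch.distributed as dist


def poll_and_all_reduce_sketch(local_polls: List[int], cpu_group) -> List[int]:
    t = torch.tensor(local_polls, dtype=torch.int64)
    # MIN: a transfer counts as finished only once *every* rank reports it finished.
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=cpu_group)
    return t.tolist()
```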
@@ -703,8 +722,11 @@ class SchedulerDisaggregationPrefillMixin:
         else:
             data = None
 
-        if self.
+        if self.attn_tp_size != 1:
             data = broadcast_pyobj(
-                data,
+                data,
+                self.attn_tp_group.rank,
+                self.attn_tp_cpu_group,
+                src=self.attn_tp_group.ranks[0],
             )
         return data
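The hunk above routes the Python-object broadcast through the attention-TP group: rank 0 of that group produces `data`, the rest receive it. The same "first rank of the sub-group decides, everyone else receives" pattern, expressed with `torch.distributed.broadcast_object_list` instead of sglang's `broadcast_pyobj` helper (whose argument order is only what the diff shows), as a sketch:

```python
# Sketch only: group/rank names are placeholders, not sglang attributes.
import torch.distributed as dist


def broadcast_within_group(data, group, group_ranks):
    # Only the first rank of the sub-group holds a meaningful `data`; after the
    # broadcast every rank in the group holds the same object.
    box = [data]
    dist.broadcast_object_list(box, src=group_ranks[0], group=group)
    return box[0]


if __name__ == "__main__":
    # Single-process gloo group, just to make the sketch executable.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1
    )
    print(broadcast_within_group({"transferred_rids": ["req-0"]},
                                 dist.group.WORLD, [0]))
    dist.destroy_process_group()
```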
@@ -18,6 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     is_weak_contiguous,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
+from sglang.srt.environ import envs
 from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0
 
 logger = logging.getLogger(__name__)
@@ -210,6 +211,7 @@ class CustomAllreduce:
             self.register_buffer(self.buffer)
 
         self.disabled = False
+        self.tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()
 
     @staticmethod
     def create_shared_buffer(
@@ -394,7 +396,7 @@ class CustomAllreduce:
                 if _is_hip:
                     return self.all_reduce_reg(input)
                 else:
-                    return self.all_reduce(input, registered=
+                    return self.all_reduce(input, registered=not self.tms_cudagraph)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
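Taken together, the two CustomAllreduce hunks read the `SGLANG_MEMORY_SAVER_CUDA_GRAPH` flag once at construction and later use it to decide whether the graph-captured all-reduce may assume pre-registered buffers. A minimal sketch of that gating, using plain `os.environ` instead of sglang's typed `envs.*` accessor; the class and method names below are illustrative, not sglang's:

```python
# Sketch only: reading the flag from the environment and gating a keyword
# argument on it the way the hunk above gates `registered=`.
import os


def _env_bool(name: str, default: bool = False) -> bool:
    return os.environ.get(name, str(default)).strip().lower() in ("1", "true", "yes")


class AllReduceSketch:
    def __init__(self) -> None:
        # When the memory-saver CUDA-graph mode is on, fall back to the
        # unregistered (copy-in) path instead of the registered-buffer path.
        self.tms_cudagraph = _env_bool("SGLANG_MEMORY_SAVER_CUDA_GRAPH")

    def all_reduce(self, x, registered: bool = True):
        return ("registered" if registered else "unregistered", x)

    def dispatch(self, x):
        return self.all_reduce(x, registered=not self.tms_cudagraph)


print(AllReduceSketch().dispatch([1.0, 2.0]))
```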
@@ -68,7 +68,7 @@ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
 
 @dataclass
 class GraphCaptureContext:
-    stream: torch.
+    stream: torch.get_device_module().Stream
 
 
 @dataclass
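This and the following GroupCoordinator hunks replace device-specific stream references with `torch.get_device_module()`, which in recent PyTorch returns the active accelerator module (e.g. `torch.cuda`), so `.Stream` and `.current_stream()` stay backend-agnostic. A small sketch, assuming a PyTorch version that ships `torch.get_device_module()`:

```python
# Sketch only: what the torch.get_device_module() annotations above resolve to.
import torch

dev_mod = torch.get_device_module()   # torch.cuda on a CUDA machine, torch.cpu otherwise
print(dev_mod.__name__, dev_mod.Stream)

# On the hot path this is what the replaced stream queries become: the same
# call, but routed through whichever backend module is active.
stream = dev_mod.current_stream()
print(stream)
```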
@@ -340,17 +340,10 @@ class GroupCoordinator:
         self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
-            if torch_compile is not None and torch_compile:
-                # For piecewise CUDA graph, the requirement for custom allreduce is larger to
-                # avoid illegal cuda memory access.
-                ca_max_size = 256 * 1024 * 1024
-            else:
-                ca_max_size = 8 * 1024 * 1024
             try:
                 self.ca_comm = CustomAllreduce(
                     group=self.cpu_group,
                     device=self.device,
-                    max_size=ca_max_size,
                 )
             except Exception as e:
                 logger.warning(
@@ -505,7 +498,7 @@ class GroupCoordinator:
             maybe_pynccl_context = nullcontext()
         else:
             maybe_pynccl_context = pynccl_comm.change_state(
-                enable=True, stream=torch.
+                enable=True, stream=torch.get_device_module().current_stream()
             )
 
         pymscclpp_comm = self.pymscclpp_comm
@@ -562,7 +555,7 @@ class GroupCoordinator:
             and input_.symmetric_memory
         ):
             with self.pynccl_comm.change_state(
-                enable=True, stream=torch.
+                enable=True, stream=torch.get_device_module().current_stream()
             ):
                 self.pynccl_comm.all_reduce(input_)
                 return input_
@@ -662,7 +655,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
 
-        with pynccl_comm.change_state(
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for reduce_scatterv"
@@ -786,7 +781,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
 
-        with pynccl_comm.change_state(
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for all_gatherv"
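The reduce_scatterv and all_gatherv hunks above spell out the `change_state(...)` call so the pynccl communicator is explicitly enabled on the current device stream for the duration of the collective and restored afterwards. A simplified stand-in context manager (not the real PyNcclCommunicator) showing that enable-on-this-stream / restore-on-exit pattern:

```python
# Sketch only: a hypothetical stand-in for PyNcclCommunicator.change_state.
from contextlib import contextmanager


class CommSketch:
    def __init__(self) -> None:
        self.enabled = False
        self.stream = None

    @contextmanager
    def change_state(self, enable: bool, stream=None):
        prev_enabled, prev_stream = self.enabled, self.stream
        self.enabled, self.stream = enable, stream
        try:
            yield self
        finally:
            # Restore whatever state the caller had, even if the collective raised.
            self.enabled, self.stream = prev_enabled, prev_stream

    def all_reduce(self, x):
        assert self.enabled, "communicator must be enabled inside change_state()"
        return x


comm = CommSketch()
with comm.change_state(enable=True, stream="current-device-stream"):
    print(comm.all_reduce([1, 2, 3]))
print(comm.enabled)  # False again after the context exits
```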