sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
@@ -205,6 +205,14 @@ class ModelConfig:
             self.hf_config, "image_token_id", None
         ) or getattr(self.hf_config, "image_token_index", None)
 
+        # matryoshka embeddings
+        self.matryoshka_dimensions = getattr(
+            self.hf_config, "matryoshka_dimensions", None
+        )
+        self.is_matryoshka = self.matryoshka_dimensions or getattr(
+            self.hf_config, "is_matryoshka", False
+        )
+
     @staticmethod
     def from_server_args(
         server_args: ServerArgs,
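The hunk above only records `matryoshka_dimensions` from the HF config; how those dimensions are consumed is outside this diff. As a rough sketch (assumed, not from this package), matryoshka-trained embedding models stay usable when truncated to one of the listed prefix lengths and re-normalized:

import torch
import torch.nn.functional as F

def truncate_embedding(emb: torch.Tensor, dim: int) -> torch.Tensor:
    # Keep the leading `dim` components, then re-normalize; `dim` would be
    # one of model_config.matryoshka_dimensions (this helper is hypothetical).
    return F.normalize(emb[..., :dim], p=2, dim=-1)

emb = torch.randn(4, 1024)                 # full-size embeddings
print(truncate_embedding(emb, 256).shape)  # torch.Size([4, 256])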
@@ -358,6 +366,13 @@ class ModelConfig:
                 self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
                 self.v_head_dim = self.hf_text_config.v_head_dim
                 self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
+            elif "KimiLinearForCausalLM" in self.hf_config.architectures:
+                self.head_dim = 72
+                self.attention_arch = AttentionArch.MLA
+                self.kv_lora_rank = self.hf_config.kv_lora_rank
+                self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+                self.v_head_dim = self.hf_config.v_head_dim
+                self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
             else:
                 if (
                     "MistralModel" in self.hf_config.architectures
@@ -582,14 +597,20 @@ class ModelConfig:
             return
 
         # Check if ModelOpt quantization is specified
-        modelopt_quantization_specified = self.quantization in [
+        _MODELOPT_QUANTIZATION_METHODS = [
             "modelopt",
             "modelopt_fp8",
             "modelopt_fp4",
         ]
+        modelopt_quantization_specified = (
+            self.quantization in _MODELOPT_QUANTIZATION_METHODS
+        )
 
         if not modelopt_quantization_specified:
-            raise ValueError(
+            raise ValueError(
+                "quantize_and_serve requires ModelOpt quantization (set with --quantization "
+                f"{{{', '.join(sorted(_MODELOPT_QUANTIZATION_METHODS))}}})"
+            )
 
         # quantize_and_serve is disabled due to compatibility issues
         raise NotImplementedError(
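Note the triple braces in the new message: `{{` and `}}` are literal braces in an f-string, so the joined method list renders inside a pair of braces. A standalone check:

methods = ["modelopt", "modelopt_fp8", "modelopt_fp4"]
print(f"{{{', '.join(sorted(methods))}}}")
# prints: {modelopt, modelopt_fp4, modelopt_fp8}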
@@ -613,6 +634,7 @@ class ModelConfig:
             "petit_nvfp4",
             "quark",
             "mxfp4",
+            "auto-round",
         ]
         optimized_quantization_methods = [
             "fp8",
@@ -634,6 +656,7 @@ class ModelConfig:
             "petit_nvfp4",
         ]
         compatible_quantization_methods = {
+            "modelopt_fp8": ["modelopt"],
             "modelopt_fp4": ["modelopt"],
             "petit_nvfp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
sglang/srt/debug_utils/tensor_dump_forward_hook.py
ADDED
@@ -0,0 +1,149 @@
+"""
+This file provides a function `register_forward_hook_for_model` that registers a forward hook on every operator of the model.
+After registration, during model inference, all tensors generated throughout the forward pass will be recorded.
+
+Usage:
+Specify the output directory for dumping tensors using the argument `--debug-tensor-dump-output-folder`.
+A separate directory will be created for each GPU rank, named in the format `f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}"`.
+Each complete forward pass of the model generates a `.pt` file named `f"Pass{pass_num}.pt"`, which can be loaded using `torch.load`.
+The file contains a series of key-value pairs, where the keys correspond to operator names in the model
+(similar to those in model.safetensors.index.json), and the values are the outputs produced by the respective operators.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+logger = logging.getLogger(__name__)
+
+
+class TensorDumper:
+    def __init__(
+        self, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
+    ):
+        self._dump_layers = dump_layers
+        self._forward_pass_id = 0
+        self._pid = os.getpid()
+        self._current_tensors = {}
+        self._base_dir = Path(dump_dir)
+        rank = tp_size * pp_rank + tp_rank
+        self._process_dir = (
+            self._base_dir / f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{self._pid}"
+        )
+        self._process_dir.mkdir(parents=True, exist_ok=True)
+
+    def get_dump_dir(self):
+        return str(self._process_dir)
+
+    def add_tensor(self, name, tensor_item):
+        if isinstance(tensor_item, (tuple, list)):
+            tensors = [t.cpu() for t in tensor_item if t is not None]
+            if len(tensors) == 1:
+                self._current_tensors[name] = tensors[0]
+            else:
+                self._current_tensors[name] = tensors
+        elif isinstance(tensor_item, torch.Tensor):
+            self._current_tensors[name] = tensor_item.cpu()
+        elif isinstance(tensor_item, LogitsProcessorOutput):
+            self._current_tensors[name] = tensor_item.next_token_logits.cpu()
+        elif isinstance(tensor_item, ForwardBatch):
+            self._current_tensors[name + ".forward_batch_info.input_ids"] = (
+                tensor_item.input_ids.cpu()
+            )
+            self._current_tensors[name + ".forward_batch_info.seq_lens"] = (
+                tensor_item.seq_lens.cpu()
+            )
+            self._current_tensors[name + ".forward_batch_info.positions"] = (
+                tensor_item.positions.cpu()
+            )
+        elif isinstance(tensor_item, PPProxyTensors):
+            for tensor_name in tensor_item.tensors.keys():
+                self._current_tensors[name + ".pp_proxy_tensors." + tensor_name] = (
+                    tensor_item.tensors[tensor_name].cpu()
+                )
+        else:
+            logger.warning(f"Unsupported type: {type(tensor_item)}: {tensor_item}")
+
+    def dump_current_tensors(self):
+        if len(self._current_tensors) == 0:
+            return
+        tensor_file_for_pass = self._process_dir / f"Pass{self._forward_pass_id:05d}.pt"
+        logger.info(
+            f"Dump {self._forward_pass_id:05d}th pass to {tensor_file_for_pass}"
+        )
+        torch.save(self._current_tensors, str(tensor_file_for_pass))
+        self._current_tensors = {}
+        self._forward_pass_id += 1
+
+    def _add_hook_recursive(
+        self, model, prefix, top_level_module_name, layers_module_name
+    ):
+        model_top_level_module_matched = False
+        layers_prefix = top_level_module_name + "." + layers_module_name
+        for name, module in model._modules.items():
+            top_level_model = False
+            if len(prefix) == 0:
+                cur_name = name
+                if cur_name == top_level_module_name:
+                    model_top_level_module_matched = True
+                    top_level_model = True
+            else:
+                cur_name = prefix + "." + name
+            if self._dump_layers > 0 and name.isdigit() and prefix == layers_prefix:
+                # If we only need n layers, skip the rest of the layers.
+                # Most models' layout is like model.layers.0.
+                cur_layer = int(name)
+                if cur_layer >= self._dump_layers:
+                    continue
+            if module is not None:
+                _, sub_count = self._add_hook_recursive(
+                    module, cur_name, top_level_module_name, layers_module_name
+                )
+                if sub_count == 0 or top_level_model:
+                    # Avoid duplicated output hooks, e.g. self_attn may contain:
+                    # self_attn.qkv_proj, self_attn.attn & self_attn.o_proj.
+                    # Therefore, we do not need to add output hooks for self_attn,
+                    # since the output of self_attn should be the same as self_attn.o_proj.
+                    module.register_forward_hook(
+                        self._dump_hook(cur_name, top_level_model)
+                    )
+        return model_top_level_module_matched, len(model._modules.items())
+
+    def _dump_hook(self, tensor_name, do_dump):
+        def inner_dump_hook(module, input, output):
+            if do_dump:
+                # This is the top-level model, so we will record the input for it.
+                for item in input:
+                    if isinstance(item, ForwardBatch):
+                        self.add_tensor(tensor_name, item)
+                self.dump_current_tensors()
+            if output is not None:
+                self.add_tensor(tensor_name, output)
+
+        return inner_dump_hook
+
+
+def register_forward_hook_for_model(
+    model, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
+):
+    tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank)
+    # Most models have a layout like:
+    # XxxxForCausalLM
+    #     (model): XxxxModel
+    #         (layers): ModuleList
+    # If the model is not constructed with this layout,
+    # environment variables can be used to specify the module names.
+    top_level_module_name = os.getenv("TENSOR_DUMP_TOP_LEVEL_MODULE_NAME", "model")
+    layers_module_name = os.getenv("TENSOR_DUMP_LAYERS_MODULE_NAME", "layers")
+    model_top_level_module_matched, _ = tensor_dumper._add_hook_recursive(
+        model, "", top_level_module_name, layers_module_name
+    )
+    assert (
+        model_top_level_module_matched
+    ), f"model should have a module named {top_level_module_name}"
+    return tensor_dumper
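Per the module docstring, each `Pass*.pt` file is a plain dict saved with `torch.save`, keyed by operator name. A minimal inspection sketch (the path is illustrative; real directories follow the `TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}` pattern described above):

import torch

dump = torch.load("dump/TP0_PP0_Rank0_pid12345/Pass00000.pt", map_location="cpu")
for name, value in dump.items():
    if isinstance(value, torch.Tensor):
        print(name, tuple(value.shape), value.dtype)
    else:
        # Operators with multiple outputs are stored as lists of tensors.
        print(name, [tuple(t.shape) for t in value])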
sglang/srt/disaggregation/decode.py
CHANGED

@@ -58,6 +58,11 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.tracing.trace import (
+    trace_event_batch,
+    trace_slice_batch,
+    trace_slice_end,
+)
 from sglang.srt.utils import get_int_env_var, require_mlp_sync
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
@@ -313,6 +318,7 @@ class DecodePreallocQueue:
             )
 
             req.add_latency(RequestStage.DECODE_PREPARE)
+            trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
             self.queue.append(
                 DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
             )
@@ -521,13 +527,15 @@ class DecodePreallocQueue:
             decode_req.kv_receiver.init(
                 page_indices, decode_req.metadata_buffer_index, state_indices
             )
-            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
             decode_req.req.time_stats.decode_transfer_queue_entry_time = (
                 time.perf_counter()
             )
             decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
+            trace_slice_end(
+                RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True
+            )
 
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -765,8 +773,12 @@ class DecodeTransferQueue:
                 indices_to_remove.add(i)
                 decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
 
-                # special handling for
-
+                # special handling for corner cases
+                should_finish = (
+                    decode_req.req.sampling_params.max_new_tokens == 1
+                    or output_id in decode_req.req.eos_token_ids
+                )
+                if should_finish:
                     # finish immediately
                     decode_req.req.time_stats.forward_entry_time = (
                         decode_req.req.time_stats.completion_time
@@ -776,8 +788,19 @@ class DecodeTransferQueue:
                         [decode_req.req], decode_req.req.return_logprob
                     )
                     self.tree_cache.cache_finished_req(decode_req.req)
+                    trace_slice_end(
+                        RequestStage.DECODE_QUICK_FINISH,
+                        decode_req.req.rid,
+                        thread_finish_flag=True,
+                    )
                 else:
                     transferred_reqs.append(decode_req.req)
+                    trace_slice_end(
+                        RequestStage.DECODE_TRANSFERRED,
+                        decode_req.req.rid,
+                        auto_next_anon=True,
+                    )
+
             elif poll in [
                 KVPoll.Bootstrapping,
                 KVPoll.WaitingForInput,
@@ -823,6 +846,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
+                    trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                     if prepare_mlp_sync_flag:
                         self._prepare_idle_batch_and_run(None)
                 else:
@@ -872,6 +896,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
+                    trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                    if prepare_mlp_sync_flag:
                         batch_, batch_result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
@@ -954,6 +979,9 @@ class SchedulerDisaggregationDecodeMixin:
             self.running_batch = self.update_running_batch(self.running_batch)
             ret = self.running_batch if not self.running_batch.is_empty() else None
 
+        if ret:
+            attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
+            trace_event_batch("schedule", ret.reqs, attrs=attrs)
         return ret
 
     def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
@@ -1009,6 +1037,9 @@ class SchedulerDisaggregationDecodeMixin:
         return new_batch
 
     def process_decode_queue(self: Scheduler):
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            self.decode_offload_manager.check_offload_progress()
+
         # try to resume retracted requests if there is enough space for another `num_reserved_decode_tokens` decode steps
         resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
         self.waiting_queue.extend(resumed_reqs)
@@ -1031,6 +1062,3 @@ class SchedulerDisaggregationDecodeMixin:
                 self.disagg_decode_transfer_queue.pop_transferred()
             )  # the requests which kv has arrived
             self.waiting_queue.extend(alloc_reqs)
-
-        if self.server_args.disaggregation_decode_enable_offload_kvcache:
-            self.decode_offload_manager.check_offload_progress()
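The decode-side edits above repeat two instrumentation patterns. Condensed for reference (names are taken from the diff; the exact behavior of `auto_next_anon` is inferred from these call sites, not documented here, and both helpers are hypothetical):

from sglang.srt.tracing.trace import trace_event_batch, trace_slice_end

def finish_stage(req, stage):
    # Close the trace slice for `stage`; auto_next_anon appears to open an
    # anonymous follow-up slice so time until the next named stage is kept.
    req.add_latency(stage)
    trace_slice_end(stage, req.rid, auto_next_anon=True)

def record_schedule(batch):
    # Emit one "schedule" event per request in the batch, tagged with a
    # batch id and batch size, mirroring the scheduler call sites above.
    attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
    trace_event_batch("schedule", batch.reqs, attrs=attrs)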
sglang/srt/disaggregation/nixl/conn.py
CHANGED

@@ -231,8 +231,8 @@ class NixlKVManager(CommonKVManager):
             ]
             for k in keys_to_remove:
                 del self.connection_pool[k]
-            if failed_bootstrap_addr in self.
-                del self.
+            if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
+                del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
             if failed_bootstrap_addr in self.prefill_dp_size_table:
                 del self.prefill_dp_size_table[failed_bootstrap_addr]
             if failed_bootstrap_addr in self.prefill_pp_size_table:
         @@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import ( 
     | 
|
| 
       53 
53 
     | 
    
         
             
                NSATokenToKVPool,
         
     | 
| 
       54 
54 
     | 
    
         
             
                SWAKVPool,
         
     | 
| 
       55 
55 
     | 
    
         
             
            )
         
     | 
| 
      
 56 
     | 
    
         
            +
            from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end
         
     | 
| 
       56 
57 
     | 
    
         
             
            from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
         
     | 
| 
       57 
58 
     | 
    
         | 
| 
       58 
59 
     | 
    
         
             
            if TYPE_CHECKING:
         
     | 
| 
         @@ -198,6 +199,7 @@ class PrefillBootstrapQueue: 
     | 
|
| 
       198 
199 
     | 
    
         
             
                    self._process_req(req)
         
     | 
| 
       199 
200 
     | 
    
         
             
                    req.add_latency(RequestStage.PREFILL_PREPARE)
         
     | 
| 
       200 
201 
     | 
    
         
             
                    self.queue.append(req)
         
     | 
| 
      
 202 
     | 
    
         
            +
                    trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True)
         
     | 
| 
       201 
203 
     | 
    
         | 
| 
       202 
204 
     | 
    
         
             
                def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
         
     | 
| 
       203 
205 
     | 
    
         
             
                    for req in reqs:
         
     | 
| 
         @@ -289,6 +291,10 @@ class PrefillBootstrapQueue: 
     | 
|
| 
       289 
291 
     | 
    
         
             
                        req.time_stats.wait_queue_entry_time = time.perf_counter()
         
     | 
| 
       290 
292 
     | 
    
         
             
                        req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
         
     | 
| 
       291 
293 
     | 
    
         | 
| 
      
 294 
     | 
    
         
            +
                        trace_slice_end(
         
     | 
| 
      
 295 
     | 
    
         
            +
                            RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True
         
     | 
| 
      
 296 
     | 
    
         
            +
                        )
         
     | 
| 
      
 297 
     | 
    
         
            +
             
     | 
| 
       292 
298 
     | 
    
         
             
                    self.queue = [
         
     | 
| 
       293 
299 
     | 
    
         
             
                        entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
         
     | 
| 
       294 
300 
     | 
    
         
             
                    ]
         
     | 
| 
         @@ -316,6 +322,9 @@ class SchedulerDisaggregationPrefillMixin: 
     | 
|
| 
       316 
322 
     | 
    
         
             
                        )
         
     | 
| 
       317 
323 
     | 
    
         
             
                        self.process_prefill_chunk()
         
     | 
| 
       318 
324 
     | 
    
         
             
                        batch = self.get_new_batch_prefill()
         
     | 
| 
      
 325 
     | 
    
         
            +
                        if batch:
         
     | 
| 
      
 326 
     | 
    
         
            +
                            attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
         
     | 
| 
      
 327 
     | 
    
         
            +
                            trace_event_batch("schedule", batch.reqs, attrs=attrs)
         
     | 
| 
       319 
328 
     | 
    
         | 
| 
       320 
329 
     | 
    
         
             
                        if require_mlp_sync(self.server_args):
         
     | 
| 
       321 
330 
     | 
    
         
             
                            batch = self.prepare_mlp_sync_batch(batch)
         
     | 
| 
         @@ -348,6 +357,9 @@ class SchedulerDisaggregationPrefillMixin: 
     | 
|
| 
       348 
357 
     | 
    
         
             
                        )
         
     | 
| 
       349 
358 
     | 
    
         
             
                        self.process_prefill_chunk()
         
     | 
| 
       350 
359 
     | 
    
         
             
                        batch = self.get_new_batch_prefill()
         
     | 
| 
      
 360 
     | 
    
         
            +
                        if batch:
         
     | 
| 
      
 361 
     | 
    
         
            +
                            attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
         
     | 
| 
      
 362 
     | 
    
         
            +
                            trace_event_batch("schedule", batch.reqs, attrs=attrs)
         
     | 
| 
       351 
363 
     | 
    
         | 
| 
       352 
364 
     | 
    
         
             
                        if require_mlp_sync(self.server_args):
         
     | 
| 
       353 
365 
     | 
    
         
             
                            batch = self.prepare_mlp_sync_batch(batch)
         
     | 
| 
@@ -423,6 +435,7 @@ class SchedulerDisaggregationPrefillMixin:
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 req.add_latency(RequestStage.PREFILL_FORWARD)
+                trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True)
                 self.disagg_prefill_inflight_queue.append(req)
                 if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
                     req.output_topk_p = batch.spec_info.topk_p[i]
@@ -487,6 +500,9 @@ class SchedulerDisaggregationPrefillMixin:

             if self.enable_overlap:
                 self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
+            trace_slice(
+                RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True
+            )

         self.maybe_send_health_check_signal()
@@ -558,6 +574,9 @@ class SchedulerDisaggregationPrefillMixin:
                 req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
                 self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
                 req.metadata_buffer_index = -1
+                trace_slice(
+                    RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True
+                )

         self.disagg_prefill_inflight_queue = undone_reqs
@@ -569,7 +588,7 @@ class SchedulerDisaggregationPrefillMixin:
         """
         polls = poll_and_all_reduce(
             [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
-            self.tp_worker.
+            self.tp_worker.get_attention_tp_cpu_group(),
         )

         transferred_rids: List[str] = []
@@ -703,8 +722,11 @@ class SchedulerDisaggregationPrefillMixin:
         else:
             data = None

-        if self.
+        if self.attn_tp_size != 1:
             data = broadcast_pyobj(
-                data,
+                data,
+                self.attn_tp_group.rank,
+                self.attn_tp_cpu_group,
+                src=self.attn_tp_group.ranks[0],
             )
         return data
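Note on the tracing calls added above: `trace_event_batch` marks the scheduling of a whole batch, while `trace_slice` closes the current per-request span and (with `auto_next_anon=True`) opens an anonymous follow-up, so the prefill forward, chunked forward, and KV-transfer stages chain into one timeline per request. A minimal sketch of that slice-chaining contract, with stand-in internals (illustrative only, not sglang's tracer implementation):

```python
import time
from typing import Any, Optional

# Stand-in state: start timestamp of the open slice for each request id (rid).
_slice_start: dict[str, float] = {}

def trace_slice(stage: Any, rid: str, auto_next_anon: bool = False,
                thread_finish_flag: bool = False) -> None:
    # Close the slice currently open for `rid`, labeling it with `stage`.
    now = time.perf_counter()
    start = _slice_start.pop(rid, now)
    print(f"[trace] {rid} {stage}: {(now - start) * 1e3:.2f} ms")
    # Unless this thread is finished with the request, open the next slice.
    if auto_next_anon and not thread_finish_flag:
        _slice_start[rid] = now

def trace_event_batch(name: str, reqs: list, attrs: Optional[dict] = None) -> None:
    # Emit one instantaneous event per request in the scheduled batch.
    for req in reqs:
        print(f"[trace] {req.rid} event={name} attrs={attrs}")
```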
sglang/srt/distributed/device_communicators/custom_all_reduce.py
CHANGED

@@ -18,6 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     is_weak_contiguous,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
+from sglang.srt.environ import envs
 from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0

 logger = logging.getLogger(__name__)
@@ -210,6 +211,7 @@ class CustomAllreduce:
             self.register_buffer(self.buffer)

         self.disabled = False
+        self.tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()

     @staticmethod
     def create_shared_buffer(
@@ -394,7 +396,7 @@ class CustomAllreduce:
             if _is_hip:
                 return self.all_reduce_reg(input)
             else:
-                return self.all_reduce(input, registered=True)
+                return self.all_reduce(input, registered=not self.tms_cudagraph)
         else:
             # If warm up, mimic the allocation pattern since custom
             # allreduce is out-of-place.
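The two custom_all_reduce.py changes read SGLANG_MEMORY_SAVER_CUDA_GRAPH once at init and, during CUDA-graph capture on non-HIP devices, skip the registered-buffer fast path when that mode is on. A sketch of the gating pattern (the env handling here is a simplified stand-in for sglang.srt.environ, and the rationale in the comment is an assumption, not stated in the diff):

```python
import os

class CustomAllreduceSketch:
    def __init__(self) -> None:
        # Stand-in for envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get(). Assumption:
        # when the memory-saver CUDA-graph mode is active, graph-owned buffers
        # may be released and re-created between capture and replay, so
        # pre-registered IPC buffer addresses cannot be relied upon.
        flag = os.getenv("SGLANG_MEMORY_SAVER_CUDA_GRAPH", "false").lower()
        self.tms_cudagraph = flag in ("1", "true")

    def all_reduce_capture_path(self, input_):
        # Mirrors the changed line: fall back to the unregistered path
        # whenever the memory saver owns the CUDA-graph pool.
        return self.all_reduce(input_, registered=not self.tms_cudagraph)

    def all_reduce(self, input_, registered: bool):
        raise NotImplementedError  # kernel dispatch elided in this sketch
```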
sglang/srt/distributed/parallel_state.py
CHANGED

@@ -68,7 +68,7 @@ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)

 @dataclass
 class GraphCaptureContext:
-    stream: torch.cuda.Stream
+    stream: torch.get_device_module().Stream


 @dataclass
@@ -498,7 +498,7 @@ class GroupCoordinator:
                 maybe_pynccl_context = nullcontext()
             else:
                 maybe_pynccl_context = pynccl_comm.change_state(
-                    enable=True, stream=torch.cuda.current_stream()
+                    enable=True, stream=torch.get_device_module().current_stream()
                 )

             pymscclpp_comm = self.pymscclpp_comm
@@ -555,7 +555,7 @@ class GroupCoordinator:
             and input_.symmetric_memory
         ):
             with self.pynccl_comm.change_state(
-                enable=True, stream=torch.cuda.current_stream()
+                enable=True, stream=torch.get_device_module().current_stream()
             ):
                 self.pynccl_comm.all_reduce(input_)
                 return input_
@@ -655,7 +655,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm

-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for reduce_scatterv"
@@ -779,7 +781,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm

-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for all_gatherv"
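All of the parallel_state.py hunks swap hard-coded `torch.cuda` stream references for `torch.get_device_module()`, which resolves to the active accelerator backend (`torch.cuda` on CUDA/ROCm builds, `torch.xpu` on XPU, and so on). A small self-contained illustration:

```python
import torch

# torch.get_device_module() returns the backend module for the current
# accelerator, so Stream and current_stream() work without naming CUDA.
if torch.cuda.is_available():  # guard so the snippet runs anywhere
    dev = torch.get_device_module()
    stream = dev.Stream()
    with dev.stream(stream):
        y = torch.ones(1024, device="cuda").sum()
    stream.synchronize()
    print(y.item())
```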
    
sglang/srt/entrypoints/engine.py
CHANGED

@@ -143,10 +143,13 @@ class Engine(EngineBase):

         # Enable tracing
         if server_args.enable_trace:
-            process_tracing_init(server_args.
-
-
-
+            process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+            thread_label = "Tokenizer"
+            if server_args.disaggregation_mode == "prefill":
+                thread_label = "Prefill Tokenizer"
+            elif server_args.disaggregation_mode == "decode":
+                thread_label = "Decode Tokenizer"
+            trace_set_thread_info(thread_label)

         try:
             self.loop = asyncio.get_running_loop()
@@ -312,6 +315,7 @@ class Engine(EngineBase):
         image_data: Optional[MultimodalDataInputFormat] = None,
         audio_data: Optional[MultimodalDataInputFormat] = None,
         video_data: Optional[MultimodalDataInputFormat] = None,
+        dimensions: Optional[int] = None,
     ) -> Dict:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
@@ -322,6 +326,7 @@ class Engine(EngineBase):
             image_data=image_data,
             audio_data=audio_data,
             video_data=video_data,
+            dimensions=dimensions,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
         ret = self.loop.run_until_complete(generator.__anext__())
@@ -333,6 +338,7 @@ class Engine(EngineBase):
         image_data: Optional[MultimodalDataInputFormat] = None,
         audio_data: Optional[MultimodalDataInputFormat] = None,
         video_data: Optional[MultimodalDataInputFormat] = None,
+        dimensions: Optional[int] = None,
     ) -> Dict:
         """
         Asynchronous version of encode method.
@@ -345,6 +351,7 @@ class Engine(EngineBase):
             image_data=image_data,
             audio_data=audio_data,
             video_data=video_data,
+            dimensions=dimensions,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
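The four hunks above plumb a new optional `dimensions` argument through `Engine.encode` and `Engine.async_encode` into the embedding request, matching the `dimensions` field of the OpenAI embeddings API. A hypothetical usage sketch (the model path is a placeholder and the snippet assumes an embedding model that supports reduced output dimensions):

```python
import sglang as sgl

# Placeholder model path; any embedding model served with is_embedding=True.
engine = sgl.Engine(model_path="path/to/embedding-model", is_embedding=True)

# Request a truncated (Matryoshka-style) 256-dim embedding, if supported.
ret = engine.encode("The quick brown fox", dimensions=256)
print(len(ret["embedding"]))

engine.shutdown()
```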
@@ -670,7 +677,8 @@ class Engine(EngineBase):
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
+    if "NCCL_CUMEM_ENABLE" not in os.environ:
+        os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
     if not server_args.enable_symm_mem:
         os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
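The `_set_envs_and_config` change stops clobbering a user-supplied `NCCL_CUMEM_ENABLE`: the default is now applied only when the variable is absent. `os.environ.setdefault` expresses the same guard in one line:

```python
import os

# Equivalent to the guarded assignment above: a value exported by the user
# before launch wins over the server default. "0" here stands in for
# str(int(server_args.enable_symm_mem)).
os.environ.setdefault("NCCL_CUMEM_ENABLE", "0")
```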
sglang/srt/entrypoints/http_server.py
CHANGED

@@ -220,9 +220,12 @@ async def lifespan(fast_api_app: FastAPI):

     # Init tracing
     if server_args.enable_trace:
-        process_tracing_init(server_args.
-        if server_args.disaggregation_mode == "
-
+        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "prefill":
+            thread_label = "Prefill" + thread_label
+        elif server_args.disaggregation_mode == "decode":
+            thread_label = "Decode" + thread_label
+        trace_set_thread_info(thread_label)

     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
@@ -1168,6 +1171,8 @@ async def available_models():
     """Show available models. OpenAI-compatible endpoint."""
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
+
+    # Add base model
     for served_model_name in served_model_names:
         model_cards.append(
             ModelCard(
@@ -1176,6 +1181,20 @@ async def available_models():
                 max_model_len=_global_state.tokenizer_manager.model_config.context_len,
             )
         )
+
+    # Add loaded LoRA adapters
+    if _global_state.tokenizer_manager.server_args.enable_lora:
+        lora_registry = _global_state.tokenizer_manager.lora_registry
+        for _, lora_ref in lora_registry.get_all_adapters().items():
+            model_cards.append(
+                ModelCard(
+                    id=lora_ref.lora_name,
+                    root=lora_ref.lora_path,
+                    parent=served_model_names[0],
+                    max_model_len=None,
+                )
+            )
+
     return ModelList(data=model_cards)


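With the two `available_models` hunks, `/v1/models` now lists loaded LoRA adapters alongside the base model, each carrying the new `parent` field that points back at the base. A quick client-side sketch (assumes a server started with --enable-lora and adapters loaded; the port is sglang's default and may differ in your deployment):

```python
import requests

# Fetch the model list and show each card with its parent, if any.
models = requests.get("http://localhost:30000/v1/models").json()
for card in models["data"]:
    print(card["id"], "parent:", card.get("parent"))
```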
sglang/srt/entrypoints/openai/protocol.py
CHANGED

@@ -37,7 +37,11 @@ from pydantic import (
     model_validator,
 )
 from typing_extensions import Literal
-
+
+try:
+    from xgrammar import StructuralTag
+except:
+    StructuralTag = Any

 from sglang.utils import convert_json_schema_to_str

@@ -54,6 +58,7 @@ class ModelCard(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: str = "sglang"
     root: Optional[str] = None
+    parent: Optional[str] = None
     max_model_len: Optional[int] = None


@@ -108,6 +113,7 @@ class UsageInfo(BaseModel):

 class StreamOptions(BaseModel):
     include_usage: Optional[bool] = False
+    continuous_usage_stats: Optional[bool] = False


 class JsonSchemaResponseFormat(BaseModel):
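`continuous_usage_stats` on `StreamOptions` mirrors the field of the same name in other OpenAI-compatible servers: when set together with `include_usage`, usage statistics are (presumably) attached to every streamed chunk rather than only the final one. A request-payload sketch:

```python
# Field names come from the diff; the per-chunk semantics described in the
# comment are assumed from the conventional meaning of this option.
payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
    "stream_options": {
        "include_usage": True,
        "continuous_usage_stats": True,
    },
}
```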