sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +56 -12
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/model_config.py +5 -5
 - sglang/srt/distributed/parallel_state.py +0 -7
 - sglang/srt/entrypoints/engine.py +18 -15
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +75 -94
 - sglang/srt/environ.py +16 -2
 - sglang/srt/eplb/expert_distribution.py +30 -0
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/flashattention_backend.py +12 -2
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
 - sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +1 -0
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +19 -4
 - sglang/srt/layers/logits_processor.py +5 -0
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +79 -272
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +4 -4
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/quantization/__init__.py +3 -5
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +13 -1
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/managers/io_struct.py +3 -0
 - sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
 - sglang/srt/managers/scheduler.py +21 -15
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/tokenizer_manager.py +11 -19
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +82 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/model_executor/forward_batch_info.py +44 -3
 - sglang/srt/model_executor/model_runner.py +1 -149
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/models/deepseek_v2.py +147 -44
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v_moe.py +29 -196
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +22 -1
 - sglang/srt/models/qwen3.py +34 -4
 - sglang/srt/models/qwen3_moe.py +2 -4
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/server_args.py +365 -186
 - sglang/srt/single_batch_overlap.py +2 -7
 - sglang/srt/utils/common.py +87 -42
 - sglang/srt/utils/hf_transformers_utils.py +7 -3
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
 - sglang/srt/models/vila.py +0 -306
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
 
| 
         @@ -85,7 +85,7 @@ def execute_sbo( 
     | 
|
| 
       85 
85 
     | 
    
         
             
                    _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
         
     | 
| 
       86 
86 
     | 
    
         
             
                )
         
     | 
| 
       87 
87 
     | 
    
         | 
| 
       88 
     | 
    
         
            -
                 
     | 
| 
      
 88 
     | 
    
         
            +
                combine_input = experts.run_moe_core(
         
     | 
| 
       89 
89 
     | 
    
         
             
                    dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
         
     | 
| 
       90 
90 
     | 
    
         
             
                )
         
     | 
| 
       91 
91 
     | 
    
         
             
                if (e := meta_overlap_args.get("record_event_after_down")) is not None:
         
     | 
| 
         @@ -98,12 +98,7 @@ def execute_sbo( 
     | 
|
| 
       98 
98 
     | 
    
         
             
                    ):
         
     | 
| 
       99 
99 
     | 
    
         
             
                        forward_shared_experts()
         
     | 
| 
       100 
100 
     | 
    
         | 
| 
       101 
     | 
    
         
            -
                hidden_states = experts.dispatcher.combine(
         
     | 
| 
       102 
     | 
    
         
            -
                    hidden_states=hidden_states,
         
     | 
| 
       103 
     | 
    
         
            -
                    topk_ids=dispatch_output.topk_ids,
         
     | 
| 
       104 
     | 
    
         
            -
                    topk_weights=dispatch_output.topk_weights,
         
     | 
| 
       105 
     | 
    
         
            -
                    overlap_args=combine_overlap_args,
         
     | 
| 
       106 
     | 
    
         
            -
                )
         
     | 
| 
      
 101 
     | 
    
         
            +
                hidden_states = experts.dispatcher.combine(combine_input=combine_input)
         
     | 
| 
       107 
102 
     | 
    
         | 
| 
       108 
103 
     | 
    
         
             
                return hidden_states
         
     | 
| 
       109 
104 
     | 
    
         | 
    
        sglang/srt/utils/common.py
    CHANGED
    
    | 
         @@ -56,7 +56,6 @@ from json import JSONDecodeError 
     | 
|
| 
       56 
56 
     | 
    
         
             
            from multiprocessing.reduction import ForkingPickler
         
     | 
| 
       57 
57 
     | 
    
         
             
            from pathlib import Path
         
     | 
| 
       58 
58 
     | 
    
         
             
            from typing import (
         
     | 
| 
       59 
     | 
    
         
            -
                TYPE_CHECKING,
         
     | 
| 
       60 
59 
     | 
    
         
             
                Any,
         
     | 
| 
       61 
60 
     | 
    
         
             
                Callable,
         
     | 
| 
       62 
61 
     | 
    
         
             
                Dict,
         
     | 
| 
         @@ -94,9 +93,6 @@ from typing_extensions import Literal 
     | 
|
| 
       94 
93 
     | 
    
         
             
            from sglang.srt.environ import envs
         
     | 
| 
       95 
94 
     | 
    
         
             
            from sglang.srt.metrics.func_timer import enable_func_timer
         
     | 
| 
       96 
95 
     | 
    
         | 
| 
       97 
     | 
    
         
            -
            if TYPE_CHECKING:
         
     | 
| 
       98 
     | 
    
         
            -
                from sglang.srt.layers.quantization.base_config import QuantizeMethodBase
         
     | 
| 
       99 
     | 
    
         
            -
             
     | 
| 
       100 
96 
     | 
    
         
             
            logger = logging.getLogger(__name__)
         
     | 
| 
       101 
97 
     | 
    
         | 
| 
       102 
98 
     | 
    
         
             
            show_time_cost = False
         
     | 
| 
         @@ -138,6 +134,7 @@ def is_xpu() -> bool: 
     | 
|
| 
       138 
134 
     | 
    
         
             
                return hasattr(torch, "xpu") and torch.xpu.is_available()
         
     | 
| 
       139 
135 
     | 
    
         | 
| 
       140 
136 
     | 
    
         | 
| 
      
 137 
     | 
    
         
            +
            @lru_cache(maxsize=1)
         
     | 
| 
       141 
138 
     | 
    
         
             
            def is_npu() -> bool:
         
     | 
| 
       142 
139 
     | 
    
         
             
                return hasattr(torch, "npu") and torch.npu.is_available()
         
     | 
| 
       143 
140 
     | 
    
         | 
| 
         @@ -1069,32 +1066,6 @@ def monkey_patch_p2p_access_check(): 
     | 
|
| 
       1069 
1066 
     | 
    
         
             
                setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)
         
     | 
| 
       1070 
1067 
     | 
    
         | 
| 
       1071 
1068 
     | 
    
         | 
| 
       1072 
     | 
    
         
            -
            def monkey_patch_vllm_gguf_config():
         
     | 
| 
       1073 
     | 
    
         
            -
                try:
         
     | 
| 
       1074 
     | 
    
         
            -
                    from vllm.model_executor.layers.quantization.gguf import (
         
     | 
| 
       1075 
     | 
    
         
            -
                        GGUFConfig,
         
     | 
| 
       1076 
     | 
    
         
            -
                        GGUFEmbeddingMethod,
         
     | 
| 
       1077 
     | 
    
         
            -
                        GGUFLinearMethod,
         
     | 
| 
       1078 
     | 
    
         
            -
                    )
         
     | 
| 
       1079 
     | 
    
         
            -
                except ImportError:
         
     | 
| 
       1080 
     | 
    
         
            -
                    return
         
     | 
| 
       1081 
     | 
    
         
            -
             
     | 
| 
       1082 
     | 
    
         
            -
                from sglang.srt.layers.linear import LinearBase
         
     | 
| 
       1083 
     | 
    
         
            -
                from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
         
     | 
| 
       1084 
     | 
    
         
            -
             
     | 
| 
       1085 
     | 
    
         
            -
                def get_quant_method_with_embedding_replaced(
         
     | 
| 
       1086 
     | 
    
         
            -
                    self, layer: torch.nn.Module, prefix: str
         
     | 
| 
       1087 
     | 
    
         
            -
                ) -> Optional[QuantizeMethodBase]:
         
     | 
| 
       1088 
     | 
    
         
            -
                    if isinstance(layer, LinearBase):
         
     | 
| 
       1089 
     | 
    
         
            -
                        return GGUFLinearMethod(self)
         
     | 
| 
       1090 
     | 
    
         
            -
                    elif isinstance(layer, VocabParallelEmbedding):
         
     | 
| 
       1091 
     | 
    
         
            -
                        # patch to own VocabParallelEmbedding
         
     | 
| 
       1092 
     | 
    
         
            -
                        return GGUFEmbeddingMethod(self)
         
     | 
| 
       1093 
     | 
    
         
            -
                    return None
         
     | 
| 
       1094 
     | 
    
         
            -
             
     | 
| 
       1095 
     | 
    
         
            -
                setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
         
     | 
| 
       1096 
     | 
    
         
            -
             
     | 
| 
       1097 
     | 
    
         
            -
             
     | 
| 
       1098 
1069 
     | 
    
         
             
            def set_ulimit(target_soft_limit=65535):
         
     | 
| 
       1099 
1070 
     | 
    
         
             
                # number of open files
         
     | 
| 
       1100 
1071 
     | 
    
         
             
                resource_type = resource.RLIMIT_NOFILE
         
     | 
| 
         @@ -1131,9 +1102,9 @@ def add_api_key_middleware(app, api_key: str): 
     | 
|
| 
       1131 
1102 
     | 
    
         
             
                async def authentication(request, call_next):
         
     | 
| 
       1132 
1103 
     | 
    
         
             
                    if request.method == "OPTIONS":
         
     | 
| 
       1133 
1104 
     | 
    
         
             
                        return await call_next(request)
         
     | 
| 
       1134 
     | 
    
         
            -
                    if request.url.path.startswith("/health") 
     | 
| 
       1135 
     | 
    
         
            -
                         
     | 
| 
       1136 
     | 
    
         
            -
                     
     | 
| 
      
 1105 
     | 
    
         
            +
                    if request.url.path.startswith("/health") or request.url.path.startswith(
         
     | 
| 
      
 1106 
     | 
    
         
            +
                        "/metrics"
         
     | 
| 
      
 1107 
     | 
    
         
            +
                    ):
         
     | 
| 
       1137 
1108 
     | 
    
         
             
                        return await call_next(request)
         
     | 
| 
       1138 
1109 
     | 
    
         
             
                    if request.headers.get("Authorization") != "Bearer " + api_key:
         
     | 
| 
       1139 
1110 
     | 
    
         
             
                        return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
         
     | 
| 
         @@ -2106,7 +2077,7 @@ class MultiprocessingSerializer: 
     | 
|
| 
       2106 
2077 
     | 
    
         | 
| 
       2107 
2078 
     | 
    
         
             
                    if output_str:
         
     | 
| 
       2108 
2079 
     | 
    
         
             
                        # Convert bytes to base64-encoded string
         
     | 
| 
       2109 
     | 
    
         
            -
                        pybase64.b64encode(output).decode("utf-8")
         
     | 
| 
      
 2080 
     | 
    
         
            +
                        output = pybase64.b64encode(output).decode("utf-8")
         
     | 
| 
       2110 
2081 
     | 
    
         | 
| 
       2111 
2082 
     | 
    
         
             
                    return output
         
     | 
| 
       2112 
2083 
     | 
    
         | 
| 
         @@ -2125,7 +2096,78 @@ class MultiprocessingSerializer: 
     | 
|
| 
       2125 
2096 
     | 
    
         
             
                        # Decode base64 string to bytes
         
     | 
| 
       2126 
2097 
     | 
    
         
             
                        data = pybase64.b64decode(data, validate=True)
         
     | 
| 
       2127 
2098 
     | 
    
         | 
| 
       2128 
     | 
    
         
            -
                    return  
     | 
| 
      
 2099 
     | 
    
         
            +
                    return SafeUnpickler(io.BytesIO(data)).load()
         
     | 
| 
      
 2100 
     | 
    
         
            +
             
     | 
| 
      
 2101 
     | 
    
         
            +
             
     | 
| 
      
 2102 
     | 
    
         
            +
            class SafeUnpickler(pickle.Unpickler):
         
     | 
| 
      
 2103 
     | 
    
         
            +
                ALLOWED_MODULE_PREFIXES = {
         
     | 
| 
      
 2104 
     | 
    
         
            +
                    # --- Python types ---
         
     | 
| 
      
 2105 
     | 
    
         
            +
                    "builtins.",
         
     | 
| 
      
 2106 
     | 
    
         
            +
                    "collections.",
         
     | 
| 
      
 2107 
     | 
    
         
            +
                    "copyreg.",
         
     | 
| 
      
 2108 
     | 
    
         
            +
                    "functools.",
         
     | 
| 
      
 2109 
     | 
    
         
            +
                    "itertools.",
         
     | 
| 
      
 2110 
     | 
    
         
            +
                    "operator.",
         
     | 
| 
      
 2111 
     | 
    
         
            +
                    "types.",
         
     | 
| 
      
 2112 
     | 
    
         
            +
                    "weakref.",
         
     | 
| 
      
 2113 
     | 
    
         
            +
                    # --- PyTorch types ---
         
     | 
| 
      
 2114 
     | 
    
         
            +
                    "torch.",
         
     | 
| 
      
 2115 
     | 
    
         
            +
                    "torch._tensor.",
         
     | 
| 
      
 2116 
     | 
    
         
            +
                    "torch.storage.",
         
     | 
| 
      
 2117 
     | 
    
         
            +
                    "torch.nn.parameter.",
         
     | 
| 
      
 2118 
     | 
    
         
            +
                    "torch.autograd.function.",
         
     | 
| 
      
 2119 
     | 
    
         
            +
                    # --- torch distributed ---
         
     | 
| 
      
 2120 
     | 
    
         
            +
                    "torch.distributed.",
         
     | 
| 
      
 2121 
     | 
    
         
            +
                    "torch.distributed._shard.",
         
     | 
| 
      
 2122 
     | 
    
         
            +
                    "torch.distributed._composable.",
         
     | 
| 
      
 2123 
     | 
    
         
            +
                    "torch._C._distributed_c10d.",
         
     | 
| 
      
 2124 
     | 
    
         
            +
                    "torch._C._distributed_fsdp.",
         
     | 
| 
      
 2125 
     | 
    
         
            +
                    "torch.distributed.optim.",
         
     | 
| 
      
 2126 
     | 
    
         
            +
                    # --- multiprocessing ---
         
     | 
| 
      
 2127 
     | 
    
         
            +
                    "multiprocessing.resource_sharer.",
         
     | 
| 
      
 2128 
     | 
    
         
            +
                    "multiprocessing.reduction.",
         
     | 
| 
      
 2129 
     | 
    
         
            +
                    "pickletools.",
         
     | 
| 
      
 2130 
     | 
    
         
            +
                    # --- PEFT / LoRA ---
         
     | 
| 
      
 2131 
     | 
    
         
            +
                    "peft.",
         
     | 
| 
      
 2132 
     | 
    
         
            +
                    "transformers.",
         
     | 
| 
      
 2133 
     | 
    
         
            +
                    "huggingface_hub.",
         
     | 
| 
      
 2134 
     | 
    
         
            +
                    # --- SGLang & Unitest ---
         
     | 
| 
      
 2135 
     | 
    
         
            +
                    "sglang.srt.weight_sync.tensor_bucket.",
         
     | 
| 
      
 2136 
     | 
    
         
            +
                    "sglang.srt.model_executor.model_runner.",
         
     | 
| 
      
 2137 
     | 
    
         
            +
                    "sglang.srt.layers.",
         
     | 
| 
      
 2138 
     | 
    
         
            +
                    "sglang.srt.utils.",
         
     | 
| 
      
 2139 
     | 
    
         
            +
                }
         
     | 
| 
      
 2140 
     | 
    
         
            +
             
     | 
| 
      
 2141 
     | 
    
         
            +
                DENY_CLASSES = {
         
     | 
| 
      
 2142 
     | 
    
         
            +
                    ("builtins", "eval"),
         
     | 
| 
      
 2143 
     | 
    
         
            +
                    ("builtins", "exec"),
         
     | 
| 
      
 2144 
     | 
    
         
            +
                    ("builtins", "compile"),
         
     | 
| 
      
 2145 
     | 
    
         
            +
                    ("os", "system"),
         
     | 
| 
      
 2146 
     | 
    
         
            +
                    ("subprocess", "Popen"),
         
     | 
| 
      
 2147 
     | 
    
         
            +
                    ("subprocess", "run"),
         
     | 
| 
      
 2148 
     | 
    
         
            +
                    ("codecs", "decode"),
         
     | 
| 
      
 2149 
     | 
    
         
            +
                    ("types", "CodeType"),
         
     | 
| 
      
 2150 
     | 
    
         
            +
                    ("types", "FunctionType"),
         
     | 
| 
      
 2151 
     | 
    
         
            +
                }
         
     | 
| 
      
 2152 
     | 
    
         
            +
             
     | 
| 
      
 2153 
     | 
    
         
            +
                def find_class(self, module, name):
         
     | 
| 
      
 2154 
     | 
    
         
            +
                    # Block deterministic attacks
         
     | 
| 
      
 2155 
     | 
    
         
            +
                    if (module, name) in self.DENY_CLASSES:
         
     | 
| 
      
 2156 
     | 
    
         
            +
                        raise RuntimeError(
         
     | 
| 
      
 2157 
     | 
    
         
            +
                            f"Blocked unsafe class loading ({module}.{name}), "
         
     | 
| 
      
 2158 
     | 
    
         
            +
                            f"to prevent exploitation of CVE-2025-10164"
         
     | 
| 
      
 2159 
     | 
    
         
            +
                        )
         
     | 
| 
      
 2160 
     | 
    
         
            +
                    # Allowlist of safe-to-load modules.
         
     | 
| 
      
 2161 
     | 
    
         
            +
                    if any(
         
     | 
| 
      
 2162 
     | 
    
         
            +
                        (module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES
         
     | 
| 
      
 2163 
     | 
    
         
            +
                    ):
         
     | 
| 
      
 2164 
     | 
    
         
            +
                        return super().find_class(module, name)
         
     | 
| 
      
 2165 
     | 
    
         
            +
             
     | 
| 
      
 2166 
     | 
    
         
            +
                    # Block everything else. (Potential attack surface)
         
     | 
| 
      
 2167 
     | 
    
         
            +
                    raise RuntimeError(
         
     | 
| 
      
 2168 
     | 
    
         
            +
                        f"Blocked unsafe class loading ({module}.{name}), "
         
     | 
| 
      
 2169 
     | 
    
         
            +
                        f"to prevent exploitation of CVE-2025-10164"
         
     | 
| 
      
 2170 
     | 
    
         
            +
                    )
         
     | 
| 
       2129 
2171 
     | 
    
         | 
| 
       2130 
2172 
     | 
    
         | 
| 
       2131 
2173 
     | 
    
         
             
            def debug_timing(func):
         
     | 
| 
         @@ -2578,17 +2620,12 @@ def get_local_ip_auto(fallback: str = None) -> str: 
     | 
|
| 
       2578 
2620 
     | 
    
         
             
                raise ValueError("Can not get local ip")
         
     | 
| 
       2579 
2621 
     | 
    
         | 
| 
       2580 
2622 
     | 
    
         | 
| 
       2581 
     | 
    
         
            -
            def is_page_size_one(server_args):
         
     | 
| 
       2582 
     | 
    
         
            -
                return server_args.page_size == 1
         
     | 
| 
       2583 
     | 
    
         
            -
             
     | 
| 
       2584 
     | 
    
         
            -
             
     | 
| 
       2585 
2623 
     | 
    
         
             
            # TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
         
     | 
| 
       2586 
2624 
     | 
    
         
             
            # TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
         
     | 
| 
       2587 
2625 
     | 
    
         
             
            def is_no_spec_infer_or_topk_one(server_args):
         
     | 
| 
       2588 
2626 
     | 
    
         
             
                return server_args.speculative_eagle_topk is None or (
         
     | 
| 
       2589 
     | 
    
         
            -
                    server_args.speculative_eagle_topk  
     | 
| 
       2590 
     | 
    
         
            -
                    and server_args. 
     | 
| 
       2591 
     | 
    
         
            -
                    and is_page_size_one(server_args)
         
     | 
| 
      
 2627 
     | 
    
         
            +
                    server_args.speculative_eagle_topk == 1
         
     | 
| 
      
 2628 
     | 
    
         
            +
                    and (server_args.page_size == 1 or server_args.page_size is None)
         
     | 
| 
       2592 
2629 
     | 
    
         
             
                )
         
     | 
| 
       2593 
2630 
     | 
    
         | 
| 
       2594 
2631 
     | 
    
         | 
| 
         @@ -3528,3 +3565,11 @@ def cached_triton_kernel(key_fn=None): 
     | 
|
| 
       3528 
3565 
     | 
    
         
             
                    return CachedKernel(fn, key_fn)
         
     | 
| 
       3529 
3566 
     | 
    
         | 
| 
       3530 
3567 
     | 
    
         
             
                return decorator
         
     | 
| 
      
 3568 
     | 
    
         
            +
             
     | 
| 
      
 3569 
     | 
    
         
            +
             
     | 
| 
      
 3570 
     | 
    
         
            +
            # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
         
     | 
| 
      
 3571 
     | 
    
         
            +
            def calc_diff(x, y):
         
     | 
| 
      
 3572 
     | 
    
         
            +
                x, y = x.double(), y.double()
         
     | 
| 
      
 3573 
     | 
    
         
            +
                denominator = (x * x + y * y).sum()
         
     | 
| 
      
 3574 
     | 
    
         
            +
                sim = 2 * (x * y).sum() / denominator
         
     | 
| 
      
 3575 
     | 
    
         
            +
                return 1 - sim
         
     | 
| 
         @@ -197,10 +197,14 @@ def get_config( 
     | 
|
| 
       197 
197 
     | 
    
         
             
                    config = AutoConfig.from_pretrained(
         
     | 
| 
       198 
198 
     | 
    
         
             
                        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
         
     | 
| 
       199 
199 
     | 
    
         
             
                    )
         
     | 
| 
       200 
     | 
    
         
            -
                    if  
     | 
| 
      
 200 
     | 
    
         
            +
                    if (
         
     | 
| 
      
 201 
     | 
    
         
            +
                        getattr(config, "auto_map", None) is not None
         
     | 
| 
      
 202 
     | 
    
         
            +
                        and config.auto_map.get("AutoModel")
         
     | 
| 
      
 203 
     | 
    
         
            +
                        == "modeling_deepseekocr.DeepseekOCRForCausalLM"
         
     | 
| 
      
 204 
     | 
    
         
            +
                    ):
         
     | 
| 
       201 
205 
     | 
    
         
             
                        config.model_type = "deepseek-ocr"
         
     | 
| 
       202 
     | 
    
         
            -
                        #  
     | 
| 
       203 
     | 
    
         
            -
                        #  
     | 
| 
      
 206 
     | 
    
         
            +
                        # TODO: Remove this workaround when AutoConfig correctly identifies deepseek-ocr.
         
     | 
| 
      
 207 
     | 
    
         
            +
                        # Hugging Face's AutoConfig currently misidentifies it as deepseekvl2.
         
     | 
| 
       204 
208 
     | 
    
         | 
| 
       205 
209 
     | 
    
         
             
                except ValueError as e:
         
     | 
| 
       206 
210 
     | 
    
         
             
                    if not "deepseek_v32" in str(e):
         
     | 
| 
         @@ -17,7 +17,7 @@ import dataclasses 
     | 
|
| 
       17 
17 
     | 
    
         
             
            import json
         
     | 
| 
       18 
18 
     | 
    
         
             
            import os
         
     | 
| 
       19 
19 
     | 
    
         
             
            import random
         
     | 
| 
       20 
     | 
    
         
            -
            from typing import List
         
     | 
| 
      
 20 
     | 
    
         
            +
            from typing import Any, Dict, List, Optional
         
     | 
| 
       21 
21 
     | 
    
         | 
| 
       22 
22 
     | 
    
         
             
            import requests
         
     | 
| 
       23 
23 
     | 
    
         | 
| 
         @@ -78,6 +78,7 @@ class BenchArgs: 
     | 
|
| 
       78 
78 
     | 
    
         
             
                            "single",
         
     | 
| 
       79 
79 
     | 
    
         
             
                            "prefix",
         
     | 
| 
       80 
80 
     | 
    
         
             
                            "radix_cache",
         
     | 
| 
      
 81 
     | 
    
         
            +
                            "p_vs_d",
         
     | 
| 
       81 
82 
     | 
    
         
             
                        ],
         
     | 
| 
       82 
83 
     | 
    
         
             
                    )
         
     | 
| 
       83 
84 
     | 
    
         
             
                    parser.add_argument("--profile", action="store_true")
         
     | 
| 
         @@ -94,18 +95,21 @@ class BenchArgs: 
     | 
|
| 
       94 
95 
     | 
    
         | 
| 
       95 
96 
     | 
    
         
             
            def send_single(
         
     | 
| 
       96 
97 
     | 
    
         
             
                args,
         
     | 
| 
       97 
     | 
    
         
            -
                batch_size: int = 1,
         
     | 
| 
       98 
98 
     | 
    
         
             
                profile: bool = False,
         
     | 
| 
       99 
99 
     | 
    
         
             
                profile_steps: int = 3,
         
     | 
| 
       100 
100 
     | 
    
         
             
                profile_by_stage: bool = False,
         
     | 
| 
       101 
101 
     | 
    
         
             
                return_full_response: bool = False,
         
     | 
| 
       102 
102 
     | 
    
         
             
                input_ids: List[int] = None,
         
     | 
| 
      
 103 
     | 
    
         
            +
                prompt: List[str] = None,
         
     | 
| 
       103 
104 
     | 
    
         
             
                max_new_tokens: int = None,
         
     | 
| 
      
 105 
     | 
    
         
            +
                extra_params: Optional[Dict[str, Any]] = None,
         
     | 
| 
      
 106 
     | 
    
         
            +
                pick_first_result: bool = True,
         
     | 
| 
       104 
107 
     | 
    
         
             
            ):
         
     | 
| 
       105 
108 
     | 
    
         
             
                base_url = f"http://{args.host}:{args.port}"
         
     | 
| 
       106 
109 
     | 
    
         | 
| 
       107 
110 
     | 
    
         
             
                # Use input_ids if provided, otherwise use text prompts
         
     | 
| 
       108 
111 
     | 
    
         
             
                if input_ids is not None:
         
     | 
| 
      
 112 
     | 
    
         
            +
                    assert prompt is None
         
     | 
| 
       109 
113 
     | 
    
         
             
                    json_data = {
         
     | 
| 
       110 
114 
     | 
    
         
             
                        "input_ids": input_ids,
         
     | 
| 
       111 
115 
     | 
    
         
             
                        "sampling_params": {
         
     | 
| 
         @@ -120,9 +124,10 @@ def send_single( 
     | 
|
| 
       120 
124 
     | 
    
         
             
                        },
         
     | 
| 
       121 
125 
     | 
    
         
             
                        "return_logprob": args.return_logprob,
         
     | 
| 
       122 
126 
     | 
    
         
             
                        "stream": args.stream,
         
     | 
| 
      
 127 
     | 
    
         
            +
                        **(extra_params or {}),
         
     | 
| 
       123 
128 
     | 
    
         
             
                    }
         
     | 
| 
       124 
129 
     | 
    
         
             
                else:
         
     | 
| 
       125 
     | 
    
         
            -
                     
     | 
| 
      
 130 
     | 
    
         
            +
                    assert input_ids is None
         
     | 
| 
       126 
131 
     | 
    
         
             
                    json_data = {
         
     | 
| 
       127 
132 
     | 
    
         
             
                        "text": prompt,
         
     | 
| 
       128 
133 
     | 
    
         
             
                        "sampling_params": {
         
     | 
| 
         @@ -137,6 +142,7 @@ def send_single( 
     | 
|
| 
       137 
142 
     | 
    
         
             
                        },
         
     | 
| 
       138 
143 
     | 
    
         
             
                        "return_logprob": args.return_logprob,
         
     | 
| 
       139 
144 
     | 
    
         
             
                        "stream": args.stream,
         
     | 
| 
      
 145 
     | 
    
         
            +
                        **(extra_params or {}),
         
     | 
| 
       140 
146 
     | 
    
         
             
                    }
         
     | 
| 
       141 
147 
     | 
    
         | 
| 
       142 
148 
     | 
    
         
             
                if args.sampling_seed is not None:
         
     | 
| 
         @@ -169,7 +175,8 @@ def send_single( 
     | 
|
| 
       169 
175 
     | 
    
         
             
                else:
         
     | 
| 
       170 
176 
     | 
    
         
             
                    ret = response.json()
         
     | 
| 
       171 
177 
     | 
    
         | 
| 
       172 
     | 
    
         
            -
                 
     | 
| 
      
 178 
     | 
    
         
            +
                if pick_first_result:
         
     | 
| 
      
 179 
     | 
    
         
            +
                    ret = ret[0] if isinstance(ret, list) else ret
         
     | 
| 
       173 
180 
     | 
    
         | 
| 
       174 
181 
     | 
    
         
             
                if return_full_response:
         
     | 
| 
       175 
182 
     | 
    
         
             
                    return ret
         
     | 
| 
         @@ -177,7 +184,9 @@ def send_single( 
     | 
|
| 
       177 
184 
     | 
    
         
             
                    return ret["text"]
         
     | 
| 
       178 
185 
     | 
    
         | 
| 
       179 
186 
     | 
    
         | 
| 
       180 
     | 
    
         
            -
            def send_prefix( 
     | 
| 
      
 187 
     | 
    
         
            +
            def send_prefix(
         
     | 
| 
      
 188 
     | 
    
         
            +
                args, batch_size: int, prompts: List[str], return_full_response: bool = False
         
     | 
| 
      
 189 
     | 
    
         
            +
            ):
         
     | 
| 
       181 
190 
     | 
    
         
             
                requests.post(f"http://{args.host}:{args.port}/flush_cache")
         
     | 
| 
       182 
191 
     | 
    
         | 
| 
       183 
192 
     | 
    
         
             
                batch_data = []
         
     | 
| 
         @@ -212,11 +221,157 @@ def send_prefix(args, batch_size: int, prompts: List[str]): 
     | 
|
| 
       212 
221 
     | 
    
         
             
                    print(ret)
         
     | 
| 
       213 
222 
     | 
    
         
             
                    return -1, -1, -1
         
     | 
| 
       214 
223 
     | 
    
         | 
| 
       215 
     | 
    
         
            -
                 
     | 
| 
       216 
     | 
    
         
            -
             
     | 
| 
       217 
     | 
    
         
            -
                    ret_dict[ 
     | 
| 
      
 224 
     | 
    
         
            +
                if return_full_response:
         
     | 
| 
      
 225 
     | 
    
         
            +
                    # Return full responses grouped by prompt index
         
     | 
| 
      
 226 
     | 
    
         
            +
                    ret_dict = {i: [] for i in range(len(prompts))}
         
     | 
| 
      
 227 
     | 
    
         
            +
                    for i in range(batch_size):
         
     | 
| 
      
 228 
     | 
    
         
            +
                        ret_dict[sampled_indices[i]].append(ret[i])
         
     | 
| 
      
 229 
     | 
    
         
            +
                    return ret_dict
         
     | 
| 
      
 230 
     | 
    
         
            +
                else:
         
     | 
| 
      
 231 
     | 
    
         
            +
                    # Return only text grouped by prompt index
         
     | 
| 
      
 232 
     | 
    
         
            +
                    ret_dict = {i: [] for i in range(len(prompts))}
         
     | 
| 
      
 233 
     | 
    
         
            +
                    for i in range(batch_size):
         
     | 
| 
      
 234 
     | 
    
         
            +
                        ret_dict[sampled_indices[i]].append(ret[i]["text"])
         
     | 
| 
      
 235 
     | 
    
         
            +
                    return ret_dict
         
     | 
| 
      
 236 
     | 
    
         
            +
             
     | 
| 
      
 237 
     | 
    
         
            +
             
     | 
| 
      
 238 
     | 
    
         
            +
            def compare_logprobs(logprobs1, logprobs2, tolerance=0):
         
     | 
| 
      
 239 
     | 
    
         
            +
                """Compare two logprobs sequences with a tolerance."""
         
     | 
| 
      
 240 
     | 
    
         
            +
                if len(logprobs1) != len(logprobs2):
         
     | 
| 
      
 241 
     | 
    
         
            +
                    return False, f"Length mismatch: {len(logprobs1)} vs {len(logprobs2)}"
         
     | 
| 
      
 242 
     | 
    
         
            +
             
     | 
| 
      
 243 
     | 
    
         
            +
                for i, (lp1, lp2) in enumerate(zip(logprobs1, logprobs2)):
         
     | 
| 
      
 244 
     | 
    
         
            +
                    # Each element is [logprob, token_id]
         
     | 
| 
      
 245 
     | 
    
         
            +
                    if lp1[1] != lp2[1]:
         
     | 
| 
      
 246 
     | 
    
         
            +
                        return False, f"Token ID mismatch at position {i}: {lp1[1]} vs {lp2[1]}"
         
     | 
| 
      
 247 
     | 
    
         
            +
                    if abs(lp1[0] - lp2[0]) > tolerance:
         
     | 
| 
      
 248 
     | 
    
         
            +
                        return (
         
     | 
| 
      
 249 
     | 
    
         
            +
                            False,
         
     | 
| 
      
 250 
     | 
    
         
            +
                            f"Logprob mismatch at position {i}: {lp1[0]} vs {lp2[0]} (diff: {abs(lp1[0] - lp2[0])})",
         
     | 
| 
      
 251 
     | 
    
         
            +
                        )
         
     | 
| 
      
 252 
     | 
    
         
            +
             
     | 
| 
      
 253 
     | 
    
         
            +
                return True, "Logprobs match"
         
     | 
| 
      
 254 
     | 
    
         
            +
             
     | 
| 
       218 
255 
     | 
    
         | 
| 
       219 
     | 
    
         
            -
             
     | 
| 
      
 256 
     | 
    
         
            +
            def _test_mode_p_vs_d(args, batch_size):
         
     | 
| 
      
 257 
     | 
    
         
            +
                print()
         
     | 
| 
      
 258 
     | 
    
         
            +
                print(f"Execute: test p_vs_d {batch_size=}")
         
     | 
| 
      
 259 
     | 
    
         
            +
             
     | 
| 
      
 260 
     | 
    
         
            +
                random.seed(42)
         
     | 
| 
      
 261 
     | 
    
         
            +
                args.return_logprob = True
         
     | 
| 
      
 262 
     | 
    
         
            +
                query_extra_params = {
         
     | 
| 
      
 263 
     | 
    
         
            +
                    "logprob_start_len": 0,
         
     | 
| 
      
 264 
     | 
    
         
            +
                    "return_text_in_logprobs": True,
         
     | 
| 
      
 265 
     | 
    
         
            +
                }
         
     | 
| 
      
 266 
     | 
    
         
            +
             
     | 
| 
      
 267 
     | 
    
         
            +
                def _create_prompts():
         
     | 
| 
      
 268 
     | 
    
         
            +
                    ans = [PROMPT_1, PROMPT_2]
         
     | 
| 
      
 269 
     | 
    
         
            +
                    for i in range(batch_size - len(ans)):
         
     | 
| 
      
 270 
     | 
    
         
            +
                        end = random.randrange(1, 4096)
         
     | 
| 
      
 271 
     | 
    
         
            +
                        if random.random() < 0.5:
         
     | 
| 
      
 272 
     | 
    
         
            +
                            begin = 0
         
     | 
| 
      
 273 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 274 
     | 
    
         
            +
                            begin = random.randrange(0, end)
         
     | 
| 
      
 275 
     | 
    
         
            +
                        ans.append(LONG_PROMPT[begin:end])
         
     | 
| 
      
 276 
     | 
    
         
            +
                    return ans[:batch_size]
         
     | 
| 
      
 277 
     | 
    
         
            +
             
     | 
| 
      
 278 
     | 
    
         
            +
                # warmup + flush
         
     | 
| 
      
 279 
     | 
    
         
            +
                send_single(args, input_ids=[1] * 64, max_new_tokens=65, return_full_response=True)
         
     | 
| 
      
 280 
     | 
    
         
            +
                requests.post(f"http://{args.host}:{args.port}/flush_cache")
         
     | 
| 
      
 281 
     | 
    
         
            +
             
     | 
| 
      
 282 
     | 
    
         
            +
                prompts = _create_prompts()
         
     | 
| 
      
 283 
     | 
    
         
            +
             
     | 
| 
      
 284 
     | 
    
         
            +
                resp_a = send_single(
         
     | 
| 
      
 285 
     | 
    
         
            +
                    args,
         
     | 
| 
      
 286 
     | 
    
         
            +
                    prompt=prompts,
         
     | 
| 
      
 287 
     | 
    
         
            +
                    max_new_tokens=args.max_new_tokens,
         
     | 
| 
      
 288 
     | 
    
         
            +
                    return_full_response=True,
         
     | 
| 
      
 289 
     | 
    
         
            +
                    pick_first_result=False,
         
     | 
| 
      
 290 
     | 
    
         
            +
                    extra_params=query_extra_params,
         
     | 
| 
      
 291 
     | 
    
         
            +
                )
         
     | 
| 
      
 292 
     | 
    
         
            +
                info_a = _extract_ids_and_logprobs(resp_a)
         
     | 
| 
      
 293 
     | 
    
         
            +
             
     | 
| 
      
 294 
     | 
    
         
            +
                requests.post(f"http://{args.host}:{args.port}/flush_cache")
         
     | 
| 
      
 295 
     | 
    
         
            +
             
     | 
| 
      
 296 
     | 
    
         
            +
                resp_b = send_single(
         
     | 
| 
      
 297 
     | 
    
         
            +
                    args,
         
     | 
| 
      
 298 
     | 
    
         
            +
                    input_ids=[x["io"].token_ids for x in info_a],
         
     | 
| 
      
 299 
     | 
    
         
            +
                    max_new_tokens=1,
         
     | 
| 
      
 300 
     | 
    
         
            +
                    return_full_response=True,
         
     | 
| 
      
 301 
     | 
    
         
            +
                    pick_first_result=False,
         
     | 
| 
      
 302 
     | 
    
         
            +
                    extra_params=query_extra_params,
         
     | 
| 
      
 303 
     | 
    
         
            +
                )
         
     | 
| 
      
 304 
     | 
    
         
            +
                info_b = _extract_ids_and_logprobs(resp_b)
         
     | 
| 
      
 305 
     | 
    
         
            +
             
     | 
| 
      
 306 
     | 
    
         
            +
                ans = []
         
     | 
| 
      
 307 
     | 
    
         
            +
                for i, (info_a_item, info_b_item) in enumerate(zip(info_a, info_b, strict=True)):
         
     | 
| 
      
 308 
     | 
    
         
            +
                    print(f"Compare sequence {i} in batch...")
         
     | 
| 
      
 309 
     | 
    
         
            +
                    correct = TokenIdsAndLogprobs.compare(info_a_item["io"], info_b_item["input"])
         
     | 
| 
      
 310 
     | 
    
         
            +
                    ans.append(int(correct))
         
     | 
| 
      
 311 
     | 
    
         
            +
             
     | 
| 
      
 312 
     | 
    
         
            +
                return ans
         
     | 
| 
      
 313 
     | 
    
         
            +
             
     | 
| 
      
 314 
     | 
    
         
            +
             
     | 
| 
      
 315 
     | 
    
         
            +
            @dataclasses.dataclass
         
     | 
| 
      
 316 
     | 
    
         
            +
            class TokenIdsAndLogprobs:
         
     | 
| 
      
 317 
     | 
    
         
            +
                token_ids: List[int]
         
     | 
| 
      
 318 
     | 
    
         
            +
                logprobs: List[float]
         
     | 
| 
      
 319 
     | 
    
         
            +
             
     | 
| 
      
 320 
     | 
    
         
            +
                def __add__(self, other):
         
     | 
| 
      
 321 
     | 
    
         
            +
                    return TokenIdsAndLogprobs(
         
     | 
| 
      
 322 
     | 
    
         
            +
                        token_ids=self.token_ids + other.token_ids,
         
     | 
| 
      
 323 
     | 
    
         
            +
                        logprobs=self.logprobs + other.logprobs,
         
     | 
| 
      
 324 
     | 
    
         
            +
                    )
         
     | 
| 
      
 325 
     | 
    
         
            +
             
     | 
| 
      
 326 
     | 
    
         
            +
                @classmethod
         
     | 
| 
      
 327 
     | 
    
         
            +
                def compare(cls, a: "TokenIdsAndLogprobs", b: "TokenIdsAndLogprobs"):
         
     | 
| 
      
 328 
     | 
    
         
            +
                    assert len(a.token_ids) == len(b.token_ids)
         
     | 
| 
      
 329 
     | 
    
         
            +
                    token_match = a.token_ids == b.token_ids
         
     | 
| 
      
 330 
     | 
    
         
            +
                    logprobs_match = a.logprobs == b.logprobs
         
     | 
| 
      
 331 
     | 
    
         
            +
             
     | 
| 
      
 332 
     | 
    
         
            +
                    if token_match:
         
     | 
| 
      
 333 
     | 
    
         
            +
                        print(f"Token match: {a.token_ids}")
         
     | 
| 
      
 334 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 335 
     | 
    
         
            +
                        print(f"❗Token mismatch: {a.token_ids=} {b.token_ids=}")
         
     | 
| 
      
 336 
     | 
    
         
            +
             
     | 
| 
      
 337 
     | 
    
         
            +
                    if logprobs_match:
         
     | 
| 
      
 338 
     | 
    
         
            +
                        print(f"Logprobs match:", a.logprobs)
         
     | 
| 
      
 339 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 340 
     | 
    
         
            +
                        print(f"❗Logprobs mismatch")
         
     | 
| 
      
 341 
     | 
    
         
            +
                        print(
         
     | 
| 
      
 342 
     | 
    
         
            +
                            "    A:   ",
         
     | 
| 
      
 343 
     | 
    
         
            +
                            [f"{x:.10f}" if x is not None else "None" for x in a.logprobs],
         
     | 
| 
      
 344 
     | 
    
         
            +
                        )
         
     | 
| 
      
 345 
     | 
    
         
            +
                        print(
         
     | 
| 
      
 346 
     | 
    
         
            +
                            "    B:   ",
         
     | 
| 
      
 347 
     | 
    
         
            +
                            [f"{x:.10f}" if x is not None else "None" for x in b.logprobs],
         
     | 
| 
      
 348 
     | 
    
         
            +
                        )
         
     | 
| 
      
 349 
     | 
    
         
            +
                        diff = [
         
     | 
| 
      
 350 
     | 
    
         
            +
                            abs(x - y) if x is not None else float("nan")
         
     | 
| 
      
 351 
     | 
    
         
            +
                            for x, y in zip(a.logprobs, b.logprobs)
         
     | 
| 
      
 352 
     | 
    
         
            +
                        ]
         
     | 
| 
      
 353 
     | 
    
         
            +
                        print("    Diff:", [f"{x:.10e}" for x in diff])
         
     | 
| 
      
 354 
     | 
    
         
            +
             
     | 
| 
      
 355 
     | 
    
         
            +
                    return token_match and logprobs_match
         
     | 
| 
      
 356 
     | 
    
         
            +
             
     | 
| 
      
 357 
     | 
    
         
            +
             
     | 
| 
      
 358 
     | 
    
         
            +
            def _extract_ids_and_logprobs(responses):
         
     | 
| 
      
 359 
     | 
    
         
            +
                def _extract_part(response, name):
         
     | 
| 
      
 360 
     | 
    
         
            +
                    token_ids, logprobs = [], []
         
     | 
| 
      
 361 
     | 
    
         
            +
                    for item in response["meta_info"][name]:
         
     | 
| 
      
 362 
     | 
    
         
            +
                        logprob, token_id, text = item
         
     | 
| 
      
 363 
     | 
    
         
            +
                        token_ids.append(token_id)
         
     | 
| 
      
 364 
     | 
    
         
            +
                        logprobs.append(logprob)
         
     | 
| 
      
 365 
     | 
    
         
            +
                    return TokenIdsAndLogprobs(token_ids=token_ids, logprobs=logprobs)
         
     | 
| 
      
 366 
     | 
    
         
            +
             
     | 
| 
      
 367 
     | 
    
         
            +
                def _extract_one_response(response):
         
     | 
| 
      
 368 
     | 
    
         
            +
                    input = _extract_part(response, "input_token_logprobs")
         
     | 
| 
      
 369 
     | 
    
         
            +
                    output = _extract_part(response, "output_token_logprobs")
         
     | 
| 
      
 370 
     | 
    
         
            +
                    return dict(input=input, output=output, io=input + output)
         
     | 
| 
      
 371 
     | 
    
         
            +
             
     | 
| 
      
 372 
     | 
    
         
            +
                if not isinstance(responses, list):
         
     | 
| 
      
 373 
     | 
    
         
            +
                    responses = [responses]
         
     | 
| 
      
 374 
     | 
    
         
            +
                return [_extract_one_response(x) for x in responses]
         
     | 
| 
       220 
375 
     | 
    
         | 
| 
       221 
376 
     | 
    
         | 
| 
       222 
377 
     | 
    
         
             
            def test_deterministic(args):
         
     | 
| 
         @@ -225,7 +380,7 @@ def test_deterministic(args): 
     | 
|
| 
       225 
380 
     | 
    
         
             
                    texts = []
         
     | 
| 
       226 
381 
     | 
    
         
             
                    for i in range(1, args.n_trials + 1):
         
     | 
| 
       227 
382 
     | 
    
         
             
                        batch_size = i
         
     | 
| 
       228 
     | 
    
         
            -
                        text = send_single(args,  
     | 
| 
      
 383 
     | 
    
         
            +
                        text = send_single(args, args.profile, prompt=[PROMPT_1] * batch_size)
         
     | 
| 
       229 
384 
     | 
    
         
             
                        text = text.replace("\n", " ")
         
     | 
| 
       230 
385 
     | 
    
         
             
                        print(f"Trial {i} with batch size {batch_size}: {text}")
         
     | 
| 
       231 
386 
     | 
    
         
             
                        texts.append(text)
         
     | 
| 
         @@ -238,15 +393,28 @@ def test_deterministic(args): 
     | 
|
| 
       238 
393 
     | 
    
         
             
                    num_prompts = len(len_prefix)
         
     | 
| 
       239 
394 
     | 
    
         
             
                    outputs = {i: [] for i in range(4)}
         
     | 
| 
       240 
395 
     | 
    
         
             
                    prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
         
     | 
| 
      
 396 
     | 
    
         
            +
             
     | 
| 
      
 397 
     | 
    
         
            +
                    # If return_logprob is enabled, store full responses for comparison
         
     | 
| 
      
 398 
     | 
    
         
            +
                    if args.return_logprob:
         
     | 
| 
      
 399 
     | 
    
         
            +
                        full_responses = {i: [] for i in range(4)}
         
     | 
| 
      
 400 
     | 
    
         
            +
             
     | 
| 
       241 
401 
     | 
    
         
             
                    for i in range(args.n_start, args.n_start + args.n_trials):
         
     | 
| 
       242 
402 
     | 
    
         
             
                        batch_size = i
         
     | 
| 
       243 
     | 
    
         
            -
                        ret_dict = send_prefix( 
     | 
| 
      
 403 
     | 
    
         
            +
                        ret_dict = send_prefix(
         
     | 
| 
      
 404 
     | 
    
         
            +
                            args, batch_size, prompts, return_full_response=args.return_logprob
         
     | 
| 
      
 405 
     | 
    
         
            +
                        )
         
     | 
| 
       244 
406 
     | 
    
         
             
                        msg = f"Testing Trial {i} with batch size {batch_size},"
         
     | 
| 
       245 
407 
     | 
    
         
             
                        for i in range(num_prompts):
         
     | 
| 
       246 
408 
     | 
    
         
             
                            msg += f" # prefix length {len_prefix[i]}: {len(ret_dict[i])},"
         
     | 
| 
       247 
409 
     | 
    
         
             
                        print(msg)
         
     | 
| 
       248 
410 
     | 
    
         
             
                        for i in range(num_prompts):
         
     | 
| 
       249 
     | 
    
         
            -
                             
     | 
| 
      
 411 
     | 
    
         
            +
                            if args.return_logprob:
         
     | 
| 
      
 412 
     | 
    
         
            +
                                # Store full response for logprob comparison
         
     | 
| 
      
 413 
     | 
    
         
            +
                                full_responses[i].extend(ret_dict[i])
         
     | 
| 
      
 414 
     | 
    
         
            +
                                # Extract text for determinism check
         
     | 
| 
      
 415 
     | 
    
         
            +
                                outputs[i].extend([resp["text"] for resp in ret_dict[i]])
         
     | 
| 
      
 416 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 417 
     | 
    
         
            +
                                outputs[i].extend(ret_dict[i])
         
     | 
| 
       250 
418 
     | 
    
         | 
| 
       251 
419 
     | 
    
         
             
                    for i in range(num_prompts):
         
     | 
| 
       252 
420 
     | 
    
         
             
                        print(
         
     | 
| 
         @@ -256,6 +424,54 @@ def test_deterministic(args): 
     | 
|
| 
       256 
424 
     | 
    
         
             
                    results = []
         
     | 
| 
       257 
425 
     | 
    
         
             
                    for i in range(num_prompts):
         
     | 
| 
       258 
426 
     | 
    
         
             
                        results.append(len(set(outputs[i])))
         
     | 
| 
      
 427 
     | 
    
         
            +
             
     | 
| 
      
 428 
     | 
    
         
            +
                    # If logprobs are enabled, compare them across different batch sizes
         
     | 
| 
      
 429 
     | 
    
         
            +
                    if args.return_logprob:
         
     | 
| 
      
 430 
     | 
    
         
            +
                        print(f"\n{'='*60}")
         
     | 
| 
      
 431 
     | 
    
         
            +
                        print("Logprobs Comparison Across Batch Sizes")
         
     | 
| 
      
 432 
     | 
    
         
            +
                        print("=" * 60)
         
     | 
| 
      
 433 
     | 
    
         
            +
             
     | 
| 
      
 434 
     | 
    
         
            +
                        logprob_results = []
         
     | 
| 
      
 435 
     | 
    
         
            +
                        for prompt_idx in range(num_prompts):
         
     | 
| 
      
 436 
     | 
    
         
            +
                            print(
         
     | 
| 
      
 437 
     | 
    
         
            +
                                f"\nPrompt {prompt_idx} (prefix length {len_prefix[prompt_idx]}):"
         
     | 
| 
      
 438 
     | 
    
         
            +
                            )
         
     | 
| 
      
 439 
     | 
    
         
            +
                            responses = full_responses[prompt_idx]
         
     | 
| 
      
 440 
     | 
    
         
            +
             
     | 
| 
      
 441 
     | 
    
         
            +
                            if len(responses) < 2:
         
     | 
| 
      
 442 
     | 
    
         
            +
                                continue
         
     | 
| 
      
 443 
     | 
    
         
            +
             
     | 
| 
      
 444 
     | 
    
         
            +
                            # Compare all responses against the first one
         
     | 
| 
      
 445 
     | 
    
         
            +
                            reference = responses[0]
         
     | 
| 
      
 446 
     | 
    
         
            +
                            all_match = True
         
     | 
| 
      
 447 
     | 
    
         
            +
                            mismatches = []
         
     | 
| 
      
 448 
     | 
    
         
            +
             
     | 
| 
      
 449 
     | 
    
         
            +
                            for j, resp in enumerate(responses[1:], start=1):
         
     | 
| 
      
 450 
     | 
    
         
            +
                                ref_logprobs = reference["meta_info"]["output_token_logprobs"]
         
     | 
| 
      
 451 
     | 
    
         
            +
                                resp_logprobs = resp["meta_info"]["output_token_logprobs"]
         
     | 
| 
      
 452 
     | 
    
         
            +
             
     | 
| 
      
 453 
     | 
    
         
            +
                                match, msg = compare_logprobs(ref_logprobs, resp_logprobs)
         
     | 
| 
      
 454 
     | 
    
         
            +
             
     | 
| 
      
 455 
     | 
    
         
            +
                                if not match:
         
     | 
| 
      
 456 
     | 
    
         
            +
                                    print(f"  ✗ Sample {j+1}: {msg}")
         
     | 
| 
      
 457 
     | 
    
         
            +
                                    mismatches.append((j + 1, msg))
         
     | 
| 
      
 458 
     | 
    
         
            +
                                    all_match = False
         
     | 
| 
      
 459 
     | 
    
         
            +
             
     | 
| 
      
 460 
     | 
    
         
            +
                            if all_match:
         
     | 
| 
      
 461 
     | 
    
         
            +
                                print(f"  ✓ All {len(responses)} samples have identical logprobs")
         
     | 
| 
      
 462 
     | 
    
         
            +
                                logprob_results.append(1)
         
     | 
| 
      
 463 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 464 
     | 
    
         
            +
                                print(
         
     | 
| 
      
 465 
     | 
    
         
            +
                                    f"  ✗ Found {len(mismatches)} mismatches out of {len(responses)} samples"
         
     | 
| 
      
 466 
     | 
    
         
            +
                                )
         
     | 
| 
      
 467 
     | 
    
         
            +
                                logprob_results.append(0)
         
     | 
| 
      
 468 
     | 
    
         
            +
             
     | 
| 
      
 469 
     | 
    
         
            +
                        print(f"\n{'='*60}")
         
     | 
| 
      
 470 
     | 
    
         
            +
                        if all(r == 1 for r in logprob_results):
         
     | 
| 
      
 471 
     | 
    
         
            +
                            print("✓✓✓ Logprobs are identical across all batch sizes! ✓✓✓")
         
     | 
| 
      
 472 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 473 
     | 
    
         
            +
                            print("✗✗✗ Some logprobs differ across batch sizes! ✗✗✗")
         
     | 
| 
      
 474 
     | 
    
         
            +
             
     | 
| 
       259 
475 
     | 
    
         
             
                    return results
         
     | 
| 
       260 
476 
     | 
    
         | 
| 
       261 
477 
     | 
    
         
             
                elif args.test_mode == "radix_cache":
         
     | 
| 
         @@ -415,6 +631,13 @@ def test_deterministic(args): 
     | 
|
| 
       415 
631 
     | 
    
         
             
                        print("✗✗✗ TEST FAILED - Radix cache produces different results! ✗✗✗")
         
     | 
| 
       416 
632 
     | 
    
         
             
                        return [0]
         
     | 
| 
       417 
633 
     | 
    
         | 
| 
      
 634 
     | 
    
         
            +
                elif args.test_mode == "p_vs_d":
         
     | 
| 
      
 635 
     | 
    
         
            +
                    # TODO also extract other modes to functions
         
     | 
| 
      
 636 
     | 
    
         
            +
                    ans = []
         
     | 
| 
      
 637 
     | 
    
         
            +
                    for i in range(1, args.n_trials + 1):
         
     | 
| 
      
 638 
     | 
    
         
            +
                        ans += _test_mode_p_vs_d(args, batch_size=i)
         
     | 
| 
      
 639 
     | 
    
         
            +
                    return ans
         
     | 
| 
      
 640 
     | 
    
         
            +
             
     | 
| 
       418 
641 
     | 
    
         
             
                else:
         
     | 
| 
       419 
642 
     | 
    
         
             
                    raise ValueError(f"Invalid test mode: {args.test_mode}")
         
     | 
| 
       420 
643 
     | 
    
         | 
| 
         @@ -60,7 +60,7 @@ class TestDeterministicBase(CustomTestCase): 
     | 
|
| 
       60 
60 
     | 
    
         
             
                    for result in results:
         
     | 
| 
       61 
61 
     | 
    
         
             
                        assert result == 1
         
     | 
| 
       62 
62 
     | 
    
         | 
| 
       63 
     | 
    
         
            -
                def  
     | 
| 
      
 63 
     | 
    
         
            +
                def test_prefix_with_logprobs(self):
         
     | 
| 
       64 
64 
     | 
    
         
             
                    args = BenchArgs()
         
     | 
| 
       65 
65 
     | 
    
         
             
                    url = DEFAULT_URL_FOR_TEST
         
     | 
| 
       66 
66 
     | 
    
         
             
                    args.host, args.port = self._extract_host_and_port(url)
         
     | 
| 
         @@ -68,6 +68,7 @@ class TestDeterministicBase(CustomTestCase): 
     | 
|
| 
       68 
68 
     | 
    
         
             
                    args.n_start = 10
         
     | 
| 
       69 
69 
     | 
    
         
             
                    args.n_trials = 10
         
     | 
| 
       70 
70 
     | 
    
         
             
                    args.temperature = 0.5  # test for deterministic sampling
         
     | 
| 
      
 71 
     | 
    
         
            +
                    args.return_logprob = True  # Enable logprobs comparison
         
     | 
| 
       71 
72 
     | 
    
         
             
                    results = test_deterministic(args)
         
     | 
| 
       72 
73 
     | 
    
         
             
                    for result in results:
         
     | 
| 
       73 
74 
     | 
    
         
             
                        assert result == 1
         
     | 
    
        sglang/version.py
    CHANGED
    
    | 
         @@ -1 +1 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            __version__ = "0.5.4"
         
     | 
| 
      
 1 
     | 
    
         
            +
            __version__ = "0.5.4.post1"
         
     | 
| 
         @@ -1,6 +1,6 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            Metadata-Version: 2.4
         
     | 
| 
       2 
2 
     | 
    
         
             
            Name: sglang
         
     | 
| 
       3 
     | 
    
         
            -
            Version: 0.5.4
         
     | 
| 
      
 3 
     | 
    
         
            +
            Version: 0.5.4.post1
         
     | 
| 
       4 
4 
     | 
    
         
             
            Summary: SGLang is a fast serving framework for large language models and vision language models.
         
     | 
| 
       5 
5 
     | 
    
         
             
            License:                                  Apache License
         
     | 
| 
       6 
6 
     | 
    
         
             
                                               Version 2.0, January 2004
         
     | 
| 
         @@ -223,6 +223,7 @@ Requires-Dist: datasets 
     | 
|
| 
       223 
223 
     | 
    
         
             
            Requires-Dist: einops
         
     | 
| 
       224 
224 
     | 
    
         
             
            Requires-Dist: fastapi
         
     | 
| 
       225 
225 
     | 
    
         
             
            Requires-Dist: flashinfer_python==0.4.1
         
     | 
| 
      
 226 
     | 
    
         
            +
            Requires-Dist: gguf
         
     | 
| 
       226 
227 
     | 
    
         
             
            Requires-Dist: hf_transfer
         
     | 
| 
       227 
228 
     | 
    
         
             
            Requires-Dist: huggingface_hub
         
     | 
| 
       228 
229 
     | 
    
         
             
            Requires-Dist: interegular
         
     | 
| 
         @@ -251,7 +252,7 @@ Requires-Dist: requests 
     | 
|
| 
       251 
252 
     | 
    
         
             
            Requires-Dist: scipy
         
     | 
| 
       252 
253 
     | 
    
         
             
            Requires-Dist: sentencepiece
         
     | 
| 
       253 
254 
     | 
    
         
             
            Requires-Dist: setproctitle
         
     | 
| 
       254 
     | 
    
         
            -
            Requires-Dist: sgl-kernel==0.3.16. 
     | 
| 
      
 255 
     | 
    
         
            +
            Requires-Dist: sgl-kernel==0.3.16.post4
         
     | 
| 
       255 
256 
     | 
    
         
             
            Requires-Dist: soundfile==0.13.1
         
     | 
| 
       256 
257 
     | 
    
         
             
            Requires-Dist: tiktoken
         
     | 
| 
       257 
258 
     | 
    
         
             
            Requires-Dist: timm==1.0.16
         
     | 
| 
         @@ -274,7 +275,6 @@ Requires-Dist: nvidia-modelopt; extra == "modelopt" 
     | 
|
| 
       274 
275 
     | 
    
         
             
            Provides-Extra: test
         
     | 
| 
       275 
276 
     | 
    
         
             
            Requires-Dist: accelerate; extra == "test"
         
     | 
| 
       276 
277 
     | 
    
         
             
            Requires-Dist: expecttest; extra == "test"
         
     | 
| 
       277 
     | 
    
         
            -
            Requires-Dist: gguf; extra == "test"
         
     | 
| 
       278 
278 
     | 
    
         
             
            Requires-Dist: jsonlines; extra == "test"
         
     | 
| 
       279 
279 
     | 
    
         
             
            Requires-Dist: matplotlib; extra == "test"
         
     | 
| 
       280 
280 
     | 
    
         
             
            Requires-Dist: pandas; extra == "test"
         
     | 
| 
         @@ -320,7 +320,7 @@ Dynamic: license-file 
     | 
|
| 
       320 
320 
     | 
    
         | 
| 
       321 
321 
     | 
    
         
             
            --------------------------------------------------------------------------------
         
     | 
| 
       322 
322 
     | 
    
         | 
| 
       323 
     | 
    
         
            -
            | [**Blog**](https://lmsys.org/blog/ 
     | 
| 
      
 323 
     | 
    
         
            +
            | [**Blog**](https://lmsys.org/blog/)
         
     | 
| 
       324 
324 
     | 
    
         
             
            | [**Documentation**](https://docs.sglang.ai/)
         
     | 
| 
       325 
325 
     | 
    
         
             
            | [**Join Slack**](https://slack.sglang.ai/)
         
     | 
| 
       326 
326 
     | 
    
         
             
            | [**Join Bi-Weekly Development Meeting**](https://meeting.sglang.ai/)
         
     | 
| 
         @@ -328,9 +328,10 @@ Dynamic: license-file 
     | 
|
| 
       328 
328 
     | 
    
         
             
            | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
         
     | 
| 
       329 
329 
     | 
    
         | 
| 
       330 
330 
     | 
    
         
             
            ## News
         
     | 
| 
      
 331 
     | 
    
         
            +
            - [2025/10] 🔥 AMD AI Dev Day 2025 SGLang ([slide](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/sglang_amd_ai_devday_2025.pdf)), PyTorch Conference 2025 SGLang ([slide](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/sglang_pytorch_2025.pdf)).
         
     | 
| 
       331 
332 
     | 
    
         
             
            - [2025/09] 🔥 Deploying DeepSeek on GB200 NVL72 with PD and Large Scale EP (Part II): 3.8x Prefill, 4.8x Decode Throughput ([blog](https://lmsys.org/blog/2025-09-25-gb200-part-2/)).
         
     | 
| 
       332 
     | 
    
         
            -
            - [2025/09]  
     | 
| 
       333 
     | 
    
         
            -
            - [2025/08]  
     | 
| 
      
 333 
     | 
    
         
            +
            - [2025/09] SGLang Day 0 Support for DeepSeek-V3.2 with Sparse Attention ([blog](https://lmsys.org/blog/2025-09-29-deepseek-V32/)).
         
     | 
| 
      
 334 
     | 
    
         
            +
            - [2025/08] SGLang x AMD SF Meetup on 8/22: Hands-on GPU workshop, tech talks by AMD/xAI/SGLang, and networking ([Roadmap](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_roadmap.pdf), [Large-scale EP](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_ep.pdf), [Highlights](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_highlights.pdf), [AITER/MoRI](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_aiter_mori.pdf), [Wave](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_wave.pdf)).
         
     | 
| 
       334 
335 
     | 
    
         
             
            - [2025/08] SGLang provides day-0 support for OpenAI gpt-oss model ([instructions](https://github.com/sgl-project/sglang/issues/8833))
         
     | 
| 
       335 
336 
     | 
    
         
             
            - [2025/05] Deploying DeepSeek with PD Disaggregation and Large-scale Expert Parallelism on 96 H100 GPUs ([blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/)).
         
     | 
| 
       336 
337 
     | 
    
         
             
            - [2025/03] SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine ([PyTorch blog](https://pytorch.org/blog/sglang-joins-pytorch/))
         
     |