sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/tracing/trace.py  CHANGED

```diff
@@ -15,6 +15,8 @@
 
 from __future__ import annotations
 
+import base64
+import json
 import logging
 import os
 import random
@@ -24,6 +26,8 @@ import uuid
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
+from sglang.srt.utils import get_int_env_var
+
 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import Req
 
@@ -85,6 +89,8 @@ class SglangTraceReqContext:
     # Indicates whether this instance is a replica from the main process.
     # When True, root_span is None and only root_span_context is preserved.
     is_copy: bool = False
+    bootstrap_room_span: Optional[trace.span.Span] = None
+    bootstrap_room_span_context: Optional[context.Context] = None
     root_span: Optional[trace.span.Span] = None
     root_span_context: Optional[context.Context] = None
 
@@ -96,8 +102,7 @@ class SglangTracePropagateContext:
 
     def to_dict(self):
         carrier: dict[str, str] = {}
-
-        propagate.inject(carrier)
+        propagate.inject(carrier, self.root_span_context)
 
         if self.prev_span_context:
             return {
@@ -149,6 +154,7 @@ class SglangTraceCustomIdGenerator(id_generator.IdGenerator):
 
 
 # global variables
+remote_trace_contexts: Dict[str, SglangTracePropagateContext] = {}
 threads_info: Dict[int, SglangTraceThreadInfo] = {}
 reqs_context: Dict[str, SglangTraceReqContext] = {}
 
@@ -193,8 +199,17 @@ def process_tracing_init(otlp_endpoint, server_name):
             resource=resource, id_generator=SglangTraceCustomIdGenerator()
         )
 
+        schedule_delay_millis = get_int_env_var(
+            "SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", 500
+        )
+        max_export_batch_size = get_int_env_var(
+            "SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", 64
+        )
+
         processor = BatchSpanProcessor(
-            OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
+            OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True),
+            schedule_delay_millis=schedule_delay_millis,
+            max_export_batch_size=max_export_batch_size,
        )
         tracer_provider.add_span_processor(processor)
         trace.set_tracer_provider(tracer_provider)
```
     | 
|
| 
       266 
281 
     | 
    
         
             
                return thread_context
         
     | 
| 
       267 
282 
     | 
    
         | 
| 
       268 
283 
     | 
    
         | 
| 
       269 
     | 
    
         
            -
            def trace_get_proc_propagate_context( 
     | 
| 
      
 284 
     | 
    
         
            +
            def trace_get_proc_propagate_context(
         
     | 
| 
      
 285 
     | 
    
         
            +
                rid, remote_propagate=False
         
     | 
| 
      
 286 
     | 
    
         
            +
            ) -> Optional[Dict[str, Any]]:
         
     | 
| 
       270 
287 
     | 
    
         
             
                if not tracing_enabled:
         
     | 
| 
       271 
288 
     | 
    
         
             
                    return None
         
     | 
| 
       272 
289 
     | 
    
         | 
| 
         @@ -283,9 +300,11 @@ def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]: 
     | 
|
| 
       283 
300 
     | 
    
         
             
                elif thread_context.last_span_context:
         
     | 
| 
       284 
301 
     | 
    
         
             
                    prev_span_context = thread_context.last_span_context
         
     | 
| 
       285 
302 
     | 
    
         | 
| 
       286 
     | 
    
         
            -
                 
     | 
| 
       287 
     | 
    
         
            -
             
     | 
| 
       288 
     | 
    
         
            -
             
     | 
| 
      
 303 
     | 
    
         
            +
                root_span_context = reqs_context[rid].root_span_context
         
     | 
| 
      
 304 
     | 
    
         
            +
                if remote_propagate:
         
     | 
| 
      
 305 
     | 
    
         
            +
                    root_span_context = reqs_context[rid].bootstrap_room_span_context
         
     | 
| 
      
 306 
     | 
    
         
            +
             
     | 
| 
      
 307 
     | 
    
         
            +
                trace_context = SglangTracePropagateContext(root_span_context, prev_span_context)
         
     | 
| 
       289 
308 
     | 
    
         
             
                return trace_context.to_dict()
         
     | 
| 
       290 
309 
     | 
    
         | 
| 
       291 
310 
     | 
    
         | 
| 
         @@ -327,10 +346,54 @@ def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any] 
     | 
|
| 
       327 
346 
     | 
    
         
             
                ].last_span_context = trace_context.prev_span_context
         
     | 
| 
       328 
347 
     | 
    
         | 
| 
       329 
348 
     | 
    
         | 
| 
      
 349 
     | 
    
         
            +
            def trace_get_remote_propagate_context(bootstrap_room_list: List[str]):
         
     | 
| 
      
 350 
     | 
    
         
            +
                if not tracing_enabled:
         
     | 
| 
      
 351 
     | 
    
         
            +
                    return ""
         
     | 
| 
      
 352 
     | 
    
         
            +
             
     | 
| 
      
 353 
     | 
    
         
            +
                reqs_trace_contexts = {}
         
     | 
| 
      
 354 
     | 
    
         
            +
                for bootstrap_room in bootstrap_room_list:
         
     | 
| 
      
 355 
     | 
    
         
            +
                    # In the router, rid is also the bootstrap room.
         
     | 
| 
      
 356 
     | 
    
         
            +
                    bootstrap_room = str(bootstrap_room)
         
     | 
| 
      
 357 
     | 
    
         
            +
             
     | 
| 
      
 358 
     | 
    
         
            +
                    if bootstrap_room not in reqs_context:
         
     | 
| 
      
 359 
     | 
    
         
            +
                        continue
         
     | 
| 
      
 360 
     | 
    
         
            +
             
     | 
| 
      
 361 
     | 
    
         
            +
                    _context = trace_get_proc_propagate_context(
         
     | 
| 
      
 362 
     | 
    
         
            +
                        bootstrap_room, remote_propagate=True
         
     | 
| 
      
 363 
     | 
    
         
            +
                    )
         
     | 
| 
      
 364 
     | 
    
         
            +
                    reqs_trace_contexts[bootstrap_room] = _context
         
     | 
| 
      
 365 
     | 
    
         
            +
             
     | 
| 
      
 366 
     | 
    
         
            +
                json_str = json.dumps(reqs_trace_contexts, ensure_ascii=False)
         
     | 
| 
      
 367 
     | 
    
         
            +
                return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
         
     | 
| 
      
 368 
     | 
    
         
            +
             
     | 
| 
      
 369 
     | 
    
         
            +
             
     | 
| 
      
 370 
     | 
    
         
            +
            def trace_set_remote_propagate_context(base64_str):
         
     | 
| 
      
 371 
     | 
    
         
            +
                if not tracing_enabled:
         
     | 
| 
      
 372 
     | 
    
         
            +
                    return
         
     | 
| 
      
 373 
     | 
    
         
            +
             
     | 
| 
      
 374 
     | 
    
         
            +
                if base64_str is None or base64_str == "" or base64_str == "None":
         
     | 
| 
      
 375 
     | 
    
         
            +
                    return
         
     | 
| 
      
 376 
     | 
    
         
            +
             
     | 
| 
      
 377 
     | 
    
         
            +
                base64_bytes = base64.b64decode(base64_str)
         
     | 
| 
      
 378 
     | 
    
         
            +
                json_str = base64_bytes.decode("utf-8")
         
     | 
| 
      
 379 
     | 
    
         
            +
                remote_reqs_trace_contexts = json.loads(json_str)
         
     | 
| 
      
 380 
     | 
    
         
            +
             
     | 
| 
      
 381 
     | 
    
         
            +
                for bootstrap_room in remote_reqs_trace_contexts:
         
     | 
| 
      
 382 
     | 
    
         
            +
                    if bootstrap_room in remote_trace_contexts:
         
     | 
| 
      
 383 
     | 
    
         
            +
                        continue
         
     | 
| 
      
 384 
     | 
    
         
            +
             
     | 
| 
      
 385 
     | 
    
         
            +
                    remote_trace_contexts[bootstrap_room] = (
         
     | 
| 
      
 386 
     | 
    
         
            +
                        SglangTracePropagateContext.instance_from_dict(
         
     | 
| 
      
 387 
     | 
    
         
            +
                            remote_reqs_trace_contexts[bootstrap_room]
         
     | 
| 
      
 388 
     | 
    
         
            +
                        )
         
     | 
| 
      
 389 
     | 
    
         
            +
                    )
         
     | 
| 
      
 390 
     | 
    
         
            +
             
     | 
| 
      
 391 
     | 
    
         
            +
             
     | 
| 
       330 
392 
     | 
    
         
             
            def trace_req_start(
         
     | 
| 
       331 
393 
     | 
    
         
             
                rid: str,
         
     | 
| 
       332 
394 
     | 
    
         
             
                bootstrap_room: Optional[int] = None,
         
     | 
| 
       333 
395 
     | 
    
         
             
                ts: Optional[int] = None,
         
     | 
| 
      
 396 
     | 
    
         
            +
                role: Optional[str] = "null",
         
     | 
| 
       334 
397 
     | 
    
         
             
            ):
         
     | 
| 
       335 
398 
     | 
    
         
             
                if not tracing_enabled:
         
     | 
| 
       336 
399 
     | 
    
         
             
                    return
         
     | 
| 
         @@ -344,6 +407,7 @@ def trace_req_start( 
     | 
|
| 
       344 
407 
     | 
    
         
             
                    return
         
     | 
| 
       345 
408 
     | 
    
         | 
| 
       346 
409 
     | 
    
         
             
                # create req context and root span
         
     | 
| 
      
 410 
     | 
    
         
            +
                bootstrap_room = 0 if bootstrap_room is None else bootstrap_room
         
     | 
| 
       347 
411 
     | 
    
         
             
                reqs_context[rid] = SglangTraceReqContext(
         
     | 
| 
       348 
412 
     | 
    
         
             
                    rid=rid,
         
     | 
| 
       349 
413 
     | 
    
         
             
                    start_time_ns=ts,
         
     | 
| 
         @@ -352,23 +416,42 @@ def trace_req_start( 
     | 
|
| 
       352 
416 
     | 
    
         
             
                    is_copy=False,
         
     | 
| 
       353 
417 
     | 
    
         
             
                )
         
     | 
| 
       354 
418 
     | 
    
         | 
| 
      
 419 
     | 
    
         
            +
                # create bootstrap room span
         
     | 
| 
      
 420 
     | 
    
         
            +
                tracer = threads_info[pid].tracer
         
     | 
| 
      
 421 
     | 
    
         
            +
                if str(bootstrap_room) not in remote_trace_contexts:
         
     | 
| 
      
 422 
     | 
    
         
            +
                    attrs = {"bootstrap_room": str(hex(bootstrap_room))}
         
     | 
| 
      
 423 
     | 
    
         
            +
                    bootstrap_room_span = tracer.start_span(
         
     | 
| 
      
 424 
     | 
    
         
            +
                        name=f"Bootstrap Room {hex(bootstrap_room)}",
         
     | 
| 
      
 425 
     | 
    
         
            +
                        start_time=ts,
         
     | 
| 
      
 426 
     | 
    
         
            +
                        attributes=attrs,
         
     | 
| 
      
 427 
     | 
    
         
            +
                    )
         
     | 
| 
      
 428 
     | 
    
         
            +
                    reqs_context[rid].bootstrap_room_span = bootstrap_room_span
         
     | 
| 
      
 429 
     | 
    
         
            +
                    bootstrap_room_span_context = trace.set_span_in_context(bootstrap_room_span)
         
     | 
| 
      
 430 
     | 
    
         
            +
                else:
         
     | 
| 
      
 431 
     | 
    
         
            +
                    bootstrap_room_span_context = remote_trace_contexts[
         
     | 
| 
      
 432 
     | 
    
         
            +
                        str(bootstrap_room)
         
     | 
| 
      
 433 
     | 
    
         
            +
                    ].root_span_context
         
     | 
| 
      
 434 
     | 
    
         
            +
             
     | 
| 
       355 
435 
     | 
    
         
             
                # Drop the worker_id added by MultiTokenizer
         
     | 
| 
       356 
436 
     | 
    
         
             
                orig_rid = rid.split("_")[-1]
         
     | 
| 
       357 
     | 
    
         
            -
                 
     | 
| 
      
 437 
     | 
    
         
            +
                role = "" if role == "null" else role
         
     | 
| 
      
 438 
     | 
    
         
            +
                attrs = {"rid": orig_rid}
         
     | 
| 
       358 
439 
     | 
    
         
             
                root_span = tracer.start_span(
         
     | 
| 
       359 
     | 
    
         
            -
                    name=f"Req {orig_rid[:8]}",
         
     | 
| 
      
 440 
     | 
    
         
            +
                    name=f"{role} Req {orig_rid[:8]}",
         
     | 
| 
       360 
441 
     | 
    
         
             
                    start_time=ts,
         
     | 
| 
      
 442 
     | 
    
         
            +
                    context=bootstrap_room_span_context,
         
     | 
| 
      
 443 
     | 
    
         
            +
                    attributes=attrs,
         
     | 
| 
       361 
444 
     | 
    
         
             
                )
         
     | 
| 
       362 
445 
     | 
    
         | 
| 
       363 
446 
     | 
    
         
             
                root_span.set_attributes(
         
     | 
| 
       364 
447 
     | 
    
         
             
                    {
         
     | 
| 
       365 
448 
     | 
    
         
             
                        "rid": rid,
         
     | 
| 
       366 
     | 
    
         
            -
                        "bootstrap_room": bootstrap_room if bootstrap_room else "None",
         
     | 
| 
       367 
449 
     | 
    
         
             
                    }
         
     | 
| 
       368 
450 
     | 
    
         
             
                )
         
     | 
| 
       369 
451 
     | 
    
         | 
| 
       370 
452 
     | 
    
         
             
                reqs_context[rid].root_span = root_span
         
     | 
| 
       371 
453 
     | 
    
         
             
                reqs_context[rid].root_span_context = trace.set_span_in_context(root_span)
         
     | 
| 
      
 454 
     | 
    
         
            +
                reqs_context[rid].bootstrap_room_span_context = bootstrap_room_span_context
         
     | 
| 
       372 
455 
     | 
    
         | 
| 
       373 
456 
     | 
    
         
             
                # create thread context and thread span
         
     | 
| 
       374 
457 
     | 
    
         
             
                reqs_context[rid].threads_context[pid] = __create_thread_context(
         
     | 
| 
         @@ -376,6 +459,10 @@ def trace_req_start( 
     | 
|
| 
       376 
459 
     | 
    
         
             
                    reqs_context[rid].root_span_context,
         
     | 
| 
       377 
460 
     | 
    
         
             
                    ts,
         
     | 
| 
       378 
461 
     | 
    
         
             
                )
         
     | 
| 
      
 462 
     | 
    
         
            +
                if str(bootstrap_room) in remote_trace_contexts:
         
     | 
| 
      
 463 
     | 
    
         
            +
                    reqs_context[rid].threads_context[pid].last_span_context = (
         
     | 
| 
      
 464 
     | 
    
         
            +
                        remote_trace_contexts[str(bootstrap_room)].prev_span_context
         
     | 
| 
      
 465 
     | 
    
         
            +
                    )
         
     | 
| 
       379 
466 
     | 
    
         | 
| 
       380 
467 
     | 
    
         | 
| 
       381 
468 
     | 
    
         
             
            def trace_req_finish(
         
     | 
| 
         @@ -399,6 +486,10 @@ def trace_req_finish( 
     | 
|
| 
       399 
486 
     | 
    
         
             
                    req_context.root_span.set_attributes(attrs)
         
     | 
| 
       400 
487 
     | 
    
         | 
| 
       401 
488 
     | 
    
         
             
                req_context.root_span.end(end_time=ts)
         
     | 
| 
      
 489 
     | 
    
         
            +
                if str(req_context.bootstrap_room) in remote_trace_contexts:
         
     | 
| 
      
 490 
     | 
    
         
            +
                    del remote_trace_contexts[str(req_context.bootstrap_room)]
         
     | 
| 
      
 491 
     | 
    
         
            +
                else:
         
     | 
| 
      
 492 
     | 
    
         
            +
                    req_context.bootstrap_room_span.end(end_time=ts)
         
     | 
| 
       402 
493 
     | 
    
         | 
| 
       403 
494 
     | 
    
         
             
                del reqs_context[rid]
         
     | 
| 
       404 
495 
     | 
    
         | 
| 
         @@ -518,7 +609,9 @@ trace_slice = trace_slice_end 
     | 
|
| 
       518 
609 
     | 
    
         | 
| 
       519 
610 
     | 
    
         | 
| 
       520 
611 
     | 
    
         
             
            # Add event to the current slice on the same thread with the same rid.
         
     | 
| 
       521 
     | 
    
         
            -
            def trace_event( 
     | 
| 
      
 612 
     | 
    
         
            +
            def trace_event(
         
     | 
| 
      
 613 
     | 
    
         
            +
                name: str, rid: str, ts: Optional[int] = None, attrs: Dict[str, Any] = None
         
     | 
| 
      
 614 
     | 
    
         
            +
            ):
         
     | 
| 
       522 
615 
     | 
    
         
             
                if not tracing_enabled:
         
     | 
| 
       523 
616 
     | 
    
         
             
                    return
         
     | 
| 
       524 
617 
     | 
    
         | 
| 
         @@ -539,7 +632,7 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None): 
     | 
|
| 
       539 
632 
     | 
    
         
             
                ts = ts or __get_cur_time_ns()
         
     | 
| 
       540 
633 
     | 
    
         | 
| 
       541 
634 
     | 
    
         
             
                slice_info = thread_context.cur_slice_stack[-1]
         
     | 
| 
       542 
     | 
    
         
            -
                slice_info.span.add_event(name=name, timestamp=ts)
         
     | 
| 
      
 635 
     | 
    
         
            +
                slice_info.span.add_event(name=name, timestamp=ts, attributes=attrs)
         
     | 
| 
       543 
636 
     | 
    
         | 
| 
       544 
637 
     | 
    
         | 
| 
       545 
638 
     | 
    
         
             
            # Add attrs to the current slice on the same thread with the same rid.
         
     | 
| 
         @@ -569,6 +662,9 @@ def trace_slice_batch( 
     | 
|
| 
       569 
662 
     | 
    
         
             
                name: str,
         
     | 
| 
       570 
663 
     | 
    
         
             
                reqs: List[Req],
         
     | 
| 
       571 
664 
     | 
    
         
             
            ):
         
     | 
| 
      
 665 
     | 
    
         
            +
                if not tracing_enabled:
         
     | 
| 
      
 666 
     | 
    
         
            +
                    return
         
     | 
| 
      
 667 
     | 
    
         
            +
             
     | 
| 
       572 
668 
     | 
    
         
             
                for req in reqs:
         
     | 
| 
       573 
669 
     | 
    
         
             
                    trace_slice(
         
     | 
| 
       574 
670 
     | 
    
         
             
                        name,
         
     | 
| 
         @@ -576,3 +672,16 @@ def trace_slice_batch( 
     | 
|
| 
       576 
672 
     | 
    
         
             
                        auto_next_anon=not req.finished(),
         
     | 
| 
       577 
673 
     | 
    
         
             
                        thread_finish_flag=req.finished(),
         
     | 
| 
       578 
674 
     | 
    
         
             
                    )
         
     | 
| 
      
 675 
     | 
    
         
            +
             
     | 
| 
      
 676 
     | 
    
         
            +
             
     | 
| 
      
 677 
     | 
    
         
            +
            def trace_event_batch(
         
     | 
| 
      
 678 
     | 
    
         
            +
                name: str,
         
     | 
| 
      
 679 
     | 
    
         
            +
                reqs: List[Req],
         
     | 
| 
      
 680 
     | 
    
         
            +
                ts: Optional[int] = None,
         
     | 
| 
      
 681 
     | 
    
         
            +
                attrs: Dict[str, Any] = None,
         
     | 
| 
      
 682 
     | 
    
         
            +
            ):
         
     | 
| 
      
 683 
     | 
    
         
            +
                if not tracing_enabled:
         
     | 
| 
      
 684 
     | 
    
         
            +
                    return
         
     | 
| 
      
 685 
     | 
    
         
            +
             
     | 
| 
      
 686 
     | 
    
         
            +
                for req in reqs:
         
     | 
| 
      
 687 
     | 
    
         
            +
                    trace_event(name, req.rid, ts=ts, attrs=attrs)
         
     | 
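
`trace_event` (and the new `trace_event_batch`) now forward an optional `attrs` dict onto the OpenTelemetry event. A hypothetical call site; the event names, attribute keys, and `req`/`reqs` variables are illustrative placeholders, not from the diff:

```python
from sglang.srt.tracing.trace import trace_event, trace_event_batch

# `req` / `reqs` would be the scheduler's Req object(s) in a real call site.
trace_event("cache_hit", req.rid, attrs={"prefix_len": 128})
trace_event_batch("batch_scheduled", reqs, attrs={"batch_size": len(reqs)})
```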
    
sglang/srt/utils/common.py  CHANGED

```diff
@@ -188,7 +188,16 @@ is_hopper_with_cuda_12_3 = lambda: _check(9)
 def is_blackwell():
     if not is_cuda():
         return False
-    return torch.cuda.get_device_capability()[0] == 10
+    return torch.cuda.get_device_capability()[0] in [10, 12]
+
+
+@lru_cache(maxsize=1)
+def is_blackwell_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
+    return (torch.cuda.get_device_capability(device)[0] in [10, 12]) and (
+        torch.version.cuda >= "12.8"
+    )
 
 
 @lru_cache(maxsize=1)
```
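
For reference, `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple; datacenter Blackwell (B100/B200) reports major 10 and consumer Blackwell (sm_120, RTX 50-series) reports major 12, which is what the `in [10, 12]` test keys on:

```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    is_blackwell = major in (10, 12)
    print(f"sm_{major}{minor}, blackwell={is_blackwell}")
```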
```diff
@@ -1230,42 +1239,34 @@ def point_to_point_pyobj(
     dst: int = 1,
 ):
     """Send data from src to dst in group using DeviceToDevice communication."""
-
+    device = torch.get_device_module().current_device()
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor(
-                [0], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            ).cuda(
-                device=torch.cuda.current_device()
+            ).to(
+                device=device
             )  # Move to GPU
-            tensor_size = torch.tensor(
-                [size], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([size], dtype=torch.long, device=device)
 
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
 
     elif rank == dst:
-        tensor_size = torch.tensor(
-            [0], dtype=torch.long, device=torch.cuda.current_device()
-        )
+        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
-        tensor_data = torch.empty(
-            size, dtype=torch.uint8, device=torch.cuda.current_device()
-        )
+        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
         dist.recv(tensor_data, src=src, group=group)
 
         serialized_data = bytes(
```
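
Replacing the hard-coded `torch.cuda.current_device()` with `torch.get_device_module()` makes the helper backend-agnostic: on recent PyTorch releases, `get_device_module()` resolves to the module for the active accelerator (`torch.cuda`, `torch.xpu`, ...), so the same `current_device()` call works beyond CUDA. A quick probe:

```python
import torch

# Resolves to the current accelerator's module (e.g. torch.cuda on a CUDA
# build); falls back to the CPU module when no accelerator is present.
mod = torch.get_device_module()
print(mod.__name__, mod.current_device())
```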
```diff
@@ -2350,16 +2351,24 @@ def launch_dummy_health_check_server(host, port, enable_metrics):
     )
     server = uvicorn.Server(config=config)
 
-
-
-
-
-
-
+    # Run server in a background daemon thread with its own event loop
+    # This prevents blocking the main thread while still serving health checks
+    def run_server():
+        try:
+            asyncio.run(server.serve())
+        except Exception as e:
+            logger.error(f"Dummy health check server failed to start: {e}")
+            raise
+        finally:
+            logger.info(f"Dummy health check server stopped at {host}:{port}")
 
-
-
-
+    thread = threading.Thread(
+        target=run_server, daemon=True, name="health-check-server"
+    )
+    thread.start()
+    logger.info(
+        f"Dummy health check server started in background thread at {host}:{port}"
+    )
 
 
 def create_checksum(directory: str):
```

(The bodies of the removed lines were not preserved in the source rendering; only the removal markers survive.)
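
The launcher now runs uvicorn in a daemon thread with its own event loop instead of blocking the caller. A self-contained sketch of the same pattern, assuming a trivial FastAPI app (the app, endpoint, and port are illustrative):

```python
import asyncio
import logging
import threading

import uvicorn
from fastapi import FastAPI

logger = logging.getLogger(__name__)
app = FastAPI()


@app.get("/health")
async def health():
    return {"status": "ok"}


def launch_health_server(host: str = "0.0.0.0", port: int = 8080) -> threading.Thread:
    config = uvicorn.Config(app, host=host, port=port, log_level="warning")
    server = uvicorn.Server(config)

    def run_server():
        # asyncio.run() gives the daemon thread its own event loop.
        asyncio.run(server.serve())

    thread = threading.Thread(
        target=run_server, daemon=True, name="health-check-server"
    )
    thread.start()
    return thread
```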
```diff
@@ -3105,12 +3114,16 @@ def apply_module_patch(target_module, target_function, wrappers):
         setattr(original_module, target_function, candidate)
 
     for key, value in sys.modules.copy().items():
-        if (
-            target_function is not None
-            and hasattr(value, target_function)
-            and id(getattr(value, target_function)) == original_function_id
-        ):
-            setattr(value, target_function, candidate)
+        try:
+            if (
+                target_function is not None
+                and hasattr(value, target_function)
+                and id(getattr(value, target_function)) == original_function_id
+            ):
+                setattr(value, target_function, candidate)
+        except ImportError as e:
+            # Ignore some modules reporting ImportError when calling hasattr
+            logger.warning(f"Ignore {value} reports ImportError with:\n{str(e)}")
 
 
 def parse_module_path(module_path, function_name, create_dummy):
```
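
The try/except around `hasattr` guards against modules that implement lazy attribute loading (PEP 562 module-level `__getattr__`): probing such a module can trigger an import of a missing optional dependency, and `hasattr` only swallows `AttributeError`, not `ImportError`. A small standalone repro of the failure mode:

```python
import types


def _lazy_getattr(name):
    # Stands in for a module that imports an optional dependency on access.
    raise ImportError(f"optional dependency for {name!r} is missing")


lazy = types.ModuleType("lazy_mod")
lazy.__getattr__ = _lazy_getattr

try:
    hasattr(lazy, "anything")  # hasattr only suppresses AttributeError
except ImportError as e:
    print("hasattr raised:", e)
```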
```diff
@@ -3562,7 +3575,17 @@ def cached_triton_kernel(key_fn=None):
     """
 
     def decorator(fn):
-        return CachedKernel(fn, key_fn)
+        if envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.get():
+            logger.debug(
+                f"{envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.name} = True. Using custom triton kernel cache."
+            )
+            return CachedKernel(fn, key_fn)
+        else:
+            # Fallback to the native triton cache.
+            logger.debug(
+                f"{envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.name} = False. Using native triton kernel cache."
+            )
+            return fn
 
     return decorator
 
```
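
`cached_triton_kernel` now dispatches on an environment flag instead of always wrapping the kernel. A generic sketch of that env-gated decorator pattern; `env_gated` and `heavy_compute` are illustrative names, and sglang's real wrapper is its internal `CachedKernel`:

```python
import functools
import logging
import os

logger = logging.getLogger(__name__)


def env_gated(flag: str, wrapper_factory):
    """Apply `wrapper_factory` to a function only when the env flag is truthy.

    Mirrors the new gating in cached_triton_kernel(): when the flag is unset,
    the undecorated function (and hence the native cache) is used.
    """

    def decorator(fn):
        if os.environ.get(flag, "").lower() in ("1", "true", "yes"):
            logger.debug("%s = True. Using custom wrapper.", flag)
            return wrapper_factory(fn)
        logger.debug("%s = False. Using native behavior.", flag)
        return fn

    return decorator


# Hypothetical usage: memoize only when the flag is set in the environment.
@env_gated("SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE", functools.cache)
def heavy_compute(x: int) -> int:
    return x * x
```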
         @@ -43,6 +43,7 @@ from sglang.srt.configs import ( 
     | 
|
| 
       43 
43 
     | 
    
         
             
                DotsVLMConfig,
         
     | 
| 
       44 
44 
     | 
    
         
             
                ExaoneConfig,
         
     | 
| 
       45 
45 
     | 
    
         
             
                FalconH1Config,
         
     | 
| 
      
 46 
     | 
    
         
            +
                KimiLinearConfig,
         
     | 
| 
       46 
47 
     | 
    
         
             
                KimiVLConfig,
         
     | 
| 
       47 
48 
     | 
    
         
             
                LongcatFlashConfig,
         
     | 
| 
       48 
49 
     | 
    
         
             
                MultiModalityConfig,
         
     | 
| 
         @@ -54,6 +55,7 @@ from sglang.srt.configs import ( 
     | 
|
| 
       54 
55 
     | 
    
         
             
            from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
         
     | 
| 
       55 
56 
     | 
    
         
             
            from sglang.srt.configs.internvl import InternVLChatConfig
         
     | 
| 
       56 
57 
     | 
    
         
             
            from sglang.srt.connector import create_remote_connector
         
     | 
| 
      
 58 
     | 
    
         
            +
            from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR
         
     | 
| 
       57 
59 
     | 
    
         
             
            from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
         
     | 
| 
       58 
60 
     | 
    
         | 
| 
       59 
61 
     | 
    
         
             
            _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
         
     | 
| 
         @@ -67,6 +69,7 @@ _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [ 
     | 
|
| 
       67 
69 
     | 
    
         
             
                Step3VLConfig,
         
     | 
| 
       68 
70 
     | 
    
         
             
                LongcatFlashConfig,
         
     | 
| 
       69 
71 
     | 
    
         
             
                Olmo3Config,
         
     | 
| 
      
 72 
     | 
    
         
            +
                KimiLinearConfig,
         
     | 
| 
       70 
73 
     | 
    
         
             
                Qwen3NextConfig,
         
     | 
| 
       71 
74 
     | 
    
         
             
                FalconH1Config,
         
     | 
| 
       72 
75 
     | 
    
         
             
                DotsVLMConfig,
         
     | 
| 
         @@ -172,6 +175,16 @@ def _load_deepseek_v32_model( 
     | 
|
| 
       172 
175 
     | 
    
         
             
                )
         
     | 
| 
       173 
176 
     | 
    
         | 
| 
       174 
177 
     | 
    
         | 
| 
      
 178 
     | 
    
         
            +
            def _is_deepseek_ocr_model(config: PretrainedConfig) -> bool:
         
     | 
| 
      
 179 
     | 
    
         
            +
                # TODO: Remove this workaround related when AutoConfig correctly identifies deepseek-ocr.
         
     | 
| 
      
 180 
     | 
    
         
            +
                # Hugging Face's AutoConfig currently misidentifies it as deepseekvl2.
         
     | 
| 
      
 181 
     | 
    
         
            +
                return (
         
     | 
| 
      
 182 
     | 
    
         
            +
                    getattr(config, "auto_map", None) is not None
         
     | 
| 
      
 183 
     | 
    
         
            +
                    and config.auto_map.get("AutoModel")
         
     | 
| 
      
 184 
     | 
    
         
            +
                    == "modeling_deepseekocr.DeepseekOCRForCausalLM"
         
     | 
| 
      
 185 
     | 
    
         
            +
                )
         
     | 
| 
      
 186 
     | 
    
         
            +
             
     | 
| 
      
 187 
     | 
    
         
            +
             
     | 
| 
       175 
188 
     | 
    
         
             
            @lru_cache_frozenset(maxsize=32)
         
     | 
| 
       176 
189 
     | 
    
         
             
            def get_config(
         
     | 
| 
       177 
190 
     | 
    
         
             
                model: str,
         
     | 
| 
@@ -197,14 +210,6 @@ def get_config(
         config = AutoConfig.from_pretrained(
             model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
         )
-        if (
-            getattr(config, "auto_map", None) is not None
-            and config.auto_map.get("AutoModel")
-            == "modeling_deepseekocr.DeepseekOCRForCausalLM"
-        ):
-            config.model_type = "deepseek-ocr"
-            # TODO: Remove this workaround when AutoConfig correctly identifies deepseek-ocr.
-            # Hugging Face's AutoConfig currently misidentifies it as deepseekvl2.
 
     except ValueError as e:
         if not "deepseek_v32" in str(e):
@@ -241,7 +246,11 @@ def get_config(
             setattr(config, key, val)
 
     if config.model_type in _CONFIG_REGISTRY:
-        config_class = _CONFIG_REGISTRY[config.model_type]
+        model_type = config.model_type
+        if model_type == "deepseek_vl_v2":
+            if _is_deepseek_ocr_model(config):
+                model_type = "deepseek-ocr"
+        config_class = _CONFIG_REGISTRY[model_type]
         config = config_class.from_pretrained(model, revision=revision)
         # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
         setattr(config, "_name_or_path", model)
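To see the effect end to end, here is a self-contained sketch of the lookup-key remap this hunk adds; the toy registry values are placeholders for illustration, not the real config classes:

from types import SimpleNamespace

# Toy stand-in for the real _CONFIG_REGISTRY lookup.
_TOY_REGISTRY = {
    "deepseek_vl_v2": "vl2 config class",
    "deepseek-ocr": "ocr config class",
}

def resolve_registry_key(config) -> str:
    # Mirrors the added logic: deepseek_vl_v2 checkpoints whose auto_map
    # names the OCR model are resolved under the deepseek-ocr key.
    model_type = config.model_type
    if model_type == "deepseek_vl_v2" and (
        getattr(config, "auto_map", None) is not None
        and config.auto_map.get("AutoModel")
        == "modeling_deepseekocr.DeepseekOCRForCausalLM"
    ):
        model_type = "deepseek-ocr"
    return model_type

cfg = SimpleNamespace(
    model_type="deepseek_vl_v2",
    auto_map={"AutoModel": "modeling_deepseekocr.DeepseekOCRForCausalLM"},
)
print(_TOY_REGISTRY[resolve_registry_key(cfg)])  # -> ocr config class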
@@ -445,6 +454,10 @@ def get_processor(
         **kwargs,
     )
 
+    if _is_deepseek_ocr_model(config):
+        # Temporary hack for loading deepseek-ocr
+        config.model_type = "deepseek-ocr"
+
     # fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
     if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
         if "size" not in kwargs:
@@ -462,13 +475,22 @@ def get_processor(
                 **kwargs,
             )
         else:
-            processor = AutoProcessor.from_pretrained(
-                tokenizer_name,
-                *args,
-                trust_remote_code=trust_remote_code,
-                revision=revision,
-                **kwargs,
-            )
+            if config.model_type in _CUSTOMIZED_MM_PROCESSOR:
+                processor = _CUSTOMIZED_MM_PROCESSOR[config.model_type].from_pretrained(
+                    tokenizer_name,
+                    *args,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    **kwargs,
+                )
+            else:
+                processor = AutoProcessor.from_pretrained(
+                    tokenizer_name,
+                    *args,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    **kwargs,
+                )
 
     except ValueError as e:
         error_message = str(e)
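The else branch now dispatches through a processor registry before falling back to AutoProcessor. A hedged sketch of the pattern; DummyProcessor and the registry entry are invented for illustration, while AutoProcessor.from_pretrained is the real transformers API:

from transformers import AutoProcessor

# Stand-in for the _CUSTOMIZED_MM_PROCESSOR registry imported above.
CUSTOM_PROCESSORS = {}

class DummyProcessor:
    # Hypothetical customized multimodal processor.
    @classmethod
    def from_pretrained(cls, name, *args, **kwargs):
        return cls()

CUSTOM_PROCESSORS["my-model-type"] = DummyProcessor

def load_processor(model_type, tokenizer_name, *args, **kwargs):
    # Registry entries take precedence; everything else uses AutoProcessor.
    cls = CUSTOM_PROCESSORS.get(model_type, AutoProcessor)
    return cls.from_pretrained(tokenizer_name, *args, **kwargs)

processor = load_processor("my-model-type", "some/checkpoint")  # DummyProcessor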
@@ -41,6 +41,12 @@ class TorchMemorySaverAdapter(ABC):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         raise NotImplementedError
 
+    def cuda_graph(self, **kwargs):
+        raise NotImplementedError
+
+    def disable(self):
+        raise NotImplementedError
+
     def pause(self, tag: str):
         raise NotImplementedError
 
@@ -61,6 +67,12 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)
 
+    def cuda_graph(self, **kwargs):
+        return _memory_saver.cuda_graph(**kwargs)
+
+    def disable(self):
+        return _memory_saver.disable()
+
     def pause(self, tag: str):
         return _memory_saver.pause(tag=tag)
 
@@ -81,6 +93,14 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         yield
 
+    @contextmanager
+    def cuda_graph(self, **kwargs):
+        yield
+
+    @contextmanager
+    def disable(self):
+        yield
+
     def pause(self, tag: str):
         pass
 
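Across these three hunks, cuda_graph() and disable() join the adapter interface: the ABC raises NotImplementedError, the real adapter delegates to torch_memory_saver, and the no-op adapter yields immediately. A simplified sketch (class names here are stand-ins, not the real adapters) of why the no-op variants carry @contextmanager, so callers can use them in `with` blocks regardless of which adapter is active:

from abc import ABC
from contextlib import contextmanager

class SaverAdapter(ABC):
    def cuda_graph(self, **kwargs):
        raise NotImplementedError

    def disable(self):
        raise NotImplementedError

class NoopAdapter(SaverAdapter):
    # With no real memory saver, both hooks still need to work as context
    # managers, so they yield once and do nothing else.
    @contextmanager
    def cuda_graph(self, **kwargs):
        yield

    @contextmanager
    def disable(self):
        yield

adapter = NoopAdapter()
with adapter.cuda_graph():
    pass  # CUDA-graph capture would run here under the real adapter
with adapter.disable():
    pass  # memory-saver bookkeeping is skipped inside this block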
@@ -0,0 +1,50 @@
+import random
+
+import requests
+
+
+def gen_radix_tree(num_nodes=400, chunk_len=256):
+    num0 = num_nodes // 2
+    num1 = num_nodes - num0
+    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
+    for _ in range(num0):
+        parent = random.choice(nodes)
+        unique_len = random.randint(0, chunk_len)
+        decode_len = random.randint(0, chunk_len)
+        token_id = random.randint(0, 32000)
+        child = {
+            "input_ids": parent["input_ids"] + [token_id] * unique_len,
+            "decode_len": decode_len,
+        }
+        nodes.append(child)
+
+    while num1 > 0:
+        num_branch = random.randint(1, min(num1, 10))
+        parent = random.choice(nodes)
+        for _ in range(num_branch):
+            unique_len = random.randint(0, chunk_len)
+            decode_len = random.randint(0, chunk_len)
+            token_id = random.randint(0, 32000)
+            child = {
+                "input_ids": parent["input_ids"] + [token_id] * unique_len,
+                "decode_len": decode_len,
+            }
+            nodes.append(child)
+
+        num1 -= num_branch
+
+    random.shuffle(nodes)
+    return nodes
+
+
+def run_radix_attention_test(base_url: str):
+    nodes = gen_radix_tree()
+    data = {
+        "input_ids": [node["input_ids"] for node in nodes],
+        "sampling_params": [
+            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
+        ],
+    }
+
+    res = requests.post(base_url + "/generate", json=data)
+    assert res.status_code == 200
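The new file generates a random radix tree of prompts with shared prefixes and replays them against a running server, exercising prefix caching. One possible way to invoke it; the URL and port are assumptions (30000 is the common sglang default), not something this file fixes:

if __name__ == "__main__":
    # Start a server first, e.g.:
    #   python -m sglang.launch_server --model-path <model> --port 30000
    run_radix_attention_test("http://127.0.0.1:30000")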