sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +73 -14
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +30 -7
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -12
 - sglang/srt/entrypoints/engine.py +31 -20
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +94 -94
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +23 -2
 - sglang/srt/eplb/expert_distribution.py +64 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +19 -8
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +24 -1
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +35 -6
 - sglang/srt/layers/logits_processor.py +9 -20
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +78 -289
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +35 -10
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +13 -84
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +130 -46
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +29 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +185 -144
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +165 -78
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +253 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +55 -14
 - sglang/srt/model_executor/model_runner.py +77 -170
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +296 -78
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +29 -197
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +23 -2
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +35 -5
 - sglang/srt/models/qwen3_moe.py +18 -12
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +459 -199
 - sglang/srt/single_batch_overlap.py +2 -4
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +142 -74
 - sglang/srt/utils/hf_transformers_utils.py +38 -12
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
 - sglang/srt/models/vila.py +0 -306
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
| 
         @@ -4,14 +4,128 @@ from typing import List, Optional 
     | 
|
| 
       4 
4 
     | 
    
         | 
| 
       5 
5 
     | 
    
         
             
            import torch
         
     | 
| 
       6 
6 
     | 
    
         | 
| 
       7 
     | 
    
         
            -
            from sglang.srt.utils import is_cuda, is_hip
         
     | 
| 
      
 7 
     | 
    
         
            +
            from sglang.srt.utils import is_cuda, is_hip, is_npu
         
     | 
| 
       8 
8 
     | 
    
         | 
| 
       9 
     | 
    
         
            -
             
     | 
| 
      
 9 
     | 
    
         
            +
            _is_cuda = is_cuda()
         
     | 
| 
      
 10 
     | 
    
         
            +
            _is_hip = is_hip()
         
     | 
| 
      
 11 
     | 
    
         
            +
            _is_npu = is_npu()
         
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            if _is_cuda or _is_hip:
         
     | 
| 
       10 
14 
     | 
    
         
             
                from sgl_kernel import (
         
     | 
| 
       11 
15 
     | 
    
         
             
                    build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
         
     | 
| 
       12 
16 
     | 
    
         
             
                )
         
     | 
| 
       13 
17 
     | 
    
         | 
| 
       14 
18 
     | 
    
         | 
| 
      
 19 
     | 
    
         
            +
            def build_tree_efficient_native(
         
     | 
| 
      
 20 
     | 
    
         
            +
                parent_list: torch.Tensor,
         
     | 
| 
      
 21 
     | 
    
         
            +
                selected_index: torch.Tensor,
         
     | 
| 
      
 22 
     | 
    
         
            +
                verified_seq_len: torch.Tensor,
         
     | 
| 
      
 23 
     | 
    
         
            +
                tree_mask: torch.Tensor,
         
     | 
| 
      
 24 
     | 
    
         
            +
                retrive_index: torch.Tensor,
         
     | 
| 
      
 25 
     | 
    
         
            +
                retrive_next_token: torch.Tensor,
         
     | 
| 
      
 26 
     | 
    
         
            +
                retrive_next_sibling: torch.Tensor,
         
     | 
| 
      
 27 
     | 
    
         
            +
                topk: int,
         
     | 
| 
      
 28 
     | 
    
         
            +
                draft_token_num: int,
         
     | 
| 
      
 29 
     | 
    
         
            +
                tree_mask_mode: int,
         
     | 
| 
      
 30 
     | 
    
         
            +
                bs: int,
         
     | 
| 
      
 31 
     | 
    
         
            +
            ):
         
     | 
| 
      
 32 
     | 
    
         
            +
                # Generate batch and token index ranges
         
     | 
| 
      
 33 
     | 
    
         
            +
                bs_range = torch.arange(bs, device=tree_mask.device).view(-1, 1)
         
     | 
| 
      
 34 
     | 
    
         
            +
                draft_token_num_range = torch.arange(draft_token_num, device=tree_mask.device)
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
                # Optimized common case for performance.
         
     | 
| 
      
 37 
     | 
    
         
            +
                if draft_token_num == 2 and topk == 1 and tree_mask_mode == TreeMaskMode.FULL_MASK:
         
     | 
| 
      
 38 
     | 
    
         
            +
                    positions = verified_seq_len.repeat_interleave(draft_token_num)
         
     | 
| 
      
 39 
     | 
    
         
            +
                    positions = (positions.view(bs, -1) + draft_token_num_range).view(-1)
         
     | 
| 
      
 40 
     | 
    
         
            +
             
     | 
| 
      
 41 
     | 
    
         
            +
                    retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
         
     | 
| 
      
 42 
     | 
    
         
            +
                    retrive_next_token[:, 0] = 1
         
     | 
| 
      
 43 
     | 
    
         
            +
                    retrive_next_token[:, 1] = -1
         
     | 
| 
      
 44 
     | 
    
         
            +
                    return (
         
     | 
| 
      
 45 
     | 
    
         
            +
                        positions,
         
     | 
| 
      
 46 
     | 
    
         
            +
                        retrive_index,
         
     | 
| 
      
 47 
     | 
    
         
            +
                        retrive_next_token,
         
     | 
| 
      
 48 
     | 
    
         
            +
                        retrive_next_sibling,
         
     | 
| 
      
 49 
     | 
    
         
            +
                        tree_mask,
         
     | 
| 
      
 50 
     | 
    
         
            +
                    )
         
     | 
| 
      
 51 
     | 
    
         
            +
             
     | 
| 
      
 52 
     | 
    
         
            +
                # Precompute sequence tree indices
         
     | 
| 
      
 53 
     | 
    
         
            +
                draft_token_num_range1 = torch.arange(draft_token_num - 1, device=tree_mask.device)
         
     | 
| 
      
 54 
     | 
    
         
            +
                cum_seq_len = torch.cumsum(verified_seq_len * draft_token_num, dim=0)
         
     | 
| 
      
 55 
     | 
    
         
            +
                cum_seq_len = torch.cat((torch.tensor([0], device=tree_mask.device), cum_seq_len))
         
     | 
| 
      
 56 
     | 
    
         
            +
                cum_seq_len = cum_seq_len[:-1]
         
     | 
| 
      
 57 
     | 
    
         
            +
                seq_tree_idx = (
         
     | 
| 
      
 58 
     | 
    
         
            +
                    draft_token_num * draft_token_num * torch.arange(bs, device=tree_mask.device)
         
     | 
| 
      
 59 
     | 
    
         
            +
                    + cum_seq_len
         
     | 
| 
      
 60 
     | 
    
         
            +
                )
         
     | 
| 
      
 61 
     | 
    
         
            +
             
     | 
| 
      
 62 
     | 
    
         
            +
                # Batch processing for tree mask
         
     | 
| 
      
 63 
     | 
    
         
            +
                if tree_mask_mode == TreeMaskMode.FULL_MASK:
         
     | 
| 
      
 64 
     | 
    
         
            +
                    token_tree_base = (
         
     | 
| 
      
 65 
     | 
    
         
            +
                        seq_tree_idx.view(-1, 1)
         
     | 
| 
      
 66 
     | 
    
         
            +
                        + (verified_seq_len.view(-1, 1) + draft_token_num) * draft_token_num_range
         
     | 
| 
      
 67 
     | 
    
         
            +
                    )
         
     | 
| 
      
 68 
     | 
    
         
            +
                    token_tree_indices = token_tree_base + verified_seq_len.view(-1, 1) + 1
         
     | 
| 
      
 69 
     | 
    
         
            +
                else:
         
     | 
| 
      
 70 
     | 
    
         
            +
                    token_tree_indices = (
         
     | 
| 
      
 71 
     | 
    
         
            +
                        bs_range * draft_token_num**2 + draft_token_num_range * draft_token_num + 1
         
     | 
| 
      
 72 
     | 
    
         
            +
                    )
         
     | 
| 
      
 73 
     | 
    
         
            +
             
     | 
| 
      
 74 
     | 
    
         
            +
                tree_mask[token_tree_indices.flatten() - 1] = True
         
     | 
| 
      
 75 
     | 
    
         
            +
                indices = token_tree_indices.unsqueeze(-1) + draft_token_num_range1.view(1, 1, -1)
         
     | 
| 
      
 76 
     | 
    
         
            +
                tree_mask[indices.view(-1)] = False
         
     | 
| 
      
 77 
     | 
    
         
            +
             
     | 
| 
      
 78 
     | 
    
         
            +
                positions = verified_seq_len.repeat_interleave(draft_token_num)
         
     | 
| 
      
 79 
     | 
    
         
            +
                parent_tb_indices = selected_index // topk
         
     | 
| 
      
 80 
     | 
    
         
            +
                retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
         
     | 
| 
      
 81 
     | 
    
         
            +
                tree_mask[token_tree_indices.view(-1, 1) + draft_token_num_range1] = True
         
     | 
| 
      
 82 
     | 
    
         
            +
             
     | 
| 
      
 83 
     | 
    
         
            +
                for bid in range(bs):
         
     | 
| 
      
 84 
     | 
    
         
            +
                    for tid in range(draft_token_num):
         
     | 
| 
      
 85 
     | 
    
         
            +
                        position = 0
         
     | 
| 
      
 86 
     | 
    
         
            +
                        if tid == 0:
         
     | 
| 
      
 87 
     | 
    
         
            +
                            # Process root node
         
     | 
| 
      
 88 
     | 
    
         
            +
                            for i in range(draft_token_num - 1, 0, -1):
         
     | 
| 
      
 89 
     | 
    
         
            +
                                parent_position = 0
         
     | 
| 
      
 90 
     | 
    
         
            +
                                parent_tb_idx = parent_tb_indices[bid][i - 1]
         
     | 
| 
      
 91 
     | 
    
         
            +
                                if parent_tb_idx > 0:
         
     | 
| 
      
 92 
     | 
    
         
            +
                                    parent_token_idx = parent_list[bid][parent_tb_idx]
         
     | 
| 
      
 93 
     | 
    
         
            +
                                    loop_num = draft_token_num - parent_position
         
     | 
| 
      
 94 
     | 
    
         
            +
                                    for _ in range(loop_num):
         
     | 
| 
      
 95 
     | 
    
         
            +
                                        if selected_index[bid][parent_position] == parent_token_idx:
         
     | 
| 
      
 96 
     | 
    
         
            +
                                            parent_position += 1
         
     | 
| 
      
 97 
     | 
    
         
            +
                                            break
         
     | 
| 
      
 98 
     | 
    
         
            +
                                        parent_position += 1
         
     | 
| 
      
 99 
     | 
    
         
            +
                                if parent_position == draft_token_num:
         
     | 
| 
      
 100 
     | 
    
         
            +
                                    continue
         
     | 
| 
      
 101 
     | 
    
         
            +
             
     | 
| 
      
 102 
     | 
    
         
            +
                                if retrive_next_token[bid][parent_position] != -1:
         
     | 
| 
      
 103 
     | 
    
         
            +
                                    retrive_next_sibling[bid][i] = retrive_next_token[bid][
         
     | 
| 
      
 104 
     | 
    
         
            +
                                        parent_position
         
     | 
| 
      
 105 
     | 
    
         
            +
                                    ]
         
     | 
| 
      
 106 
     | 
    
         
            +
                                retrive_next_token[bid][parent_position] = i
         
     | 
| 
      
 107 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 108 
     | 
    
         
            +
                            # Process no-root nodes
         
     | 
| 
      
 109 
     | 
    
         
            +
                            cur_position = tid - 1
         
     | 
| 
      
 110 
     | 
    
         
            +
                            while True:
         
     | 
| 
      
 111 
     | 
    
         
            +
                                position += 1
         
     | 
| 
      
 112 
     | 
    
         
            +
                                if cur_position >= draft_token_num:
         
     | 
| 
      
 113 
     | 
    
         
            +
                                    tree_mask[token_tree_indices + cur_position] = True
         
     | 
| 
      
 114 
     | 
    
         
            +
                                    parent_tb_idx = selected_index[bid][cur_position] // topk
         
     | 
| 
      
 115 
     | 
    
         
            +
                                else:
         
     | 
| 
      
 116 
     | 
    
         
            +
                                    parent_tb_idx = parent_tb_indices[bid][cur_position]
         
     | 
| 
      
 117 
     | 
    
         
            +
                                if parent_tb_idx == 0:
         
     | 
| 
      
 118 
     | 
    
         
            +
                                    break
         
     | 
| 
      
 119 
     | 
    
         
            +
                                token_idx = parent_list[bid][parent_tb_idx]
         
     | 
| 
      
 120 
     | 
    
         
            +
                                cur_position = 0
         
     | 
| 
      
 121 
     | 
    
         
            +
                                for _ in range(draft_token_num):
         
     | 
| 
      
 122 
     | 
    
         
            +
                                    if selected_index[bid][cur_position] == token_idx:
         
     | 
| 
      
 123 
     | 
    
         
            +
                                        break
         
     | 
| 
      
 124 
     | 
    
         
            +
                                    cur_position += 1
         
     | 
| 
      
 125 
     | 
    
         
            +
                            positions[bid * draft_token_num + tid] += position
         
     | 
| 
      
 126 
     | 
    
         
            +
                return positions, retrive_index, retrive_next_token, retrive_next_sibling, tree_mask
         
     | 
| 
      
 127 
     | 
    
         
            +
             
     | 
| 
      
 128 
     | 
    
         
            +
             
     | 
| 
       15 
129 
     | 
    
         
             
            def organize_draft_results(
         
     | 
| 
       16 
130 
     | 
    
         
             
                score_list: List[torch.Tensor],
         
     | 
| 
       17 
131 
     | 
    
         
             
                token_list: List[torch.Tensor],
         
     | 
| 
         @@ -114,20 +228,41 @@ def build_tree_kernel_efficient( 
     | 
|
| 
       114 
228 
     | 
    
         
             
                        (bs * num_verify_tokens,), device=device, dtype=torch.long
         
     | 
| 
       115 
229 
     | 
    
         
             
                    )
         
     | 
| 
       116 
230 
     | 
    
         | 
| 
       117 
     | 
    
         
            -
                 
     | 
| 
       118 
     | 
    
         
            -
                     
     | 
| 
       119 
     | 
    
         
            -
             
     | 
| 
       120 
     | 
    
         
            -
             
     | 
| 
       121 
     | 
    
         
            -
             
     | 
| 
       122 
     | 
    
         
            -
             
     | 
| 
       123 
     | 
    
         
            -
             
     | 
| 
       124 
     | 
    
         
            -
                     
     | 
| 
       125 
     | 
    
         
            -
             
     | 
| 
       126 
     | 
    
         
            -
             
     | 
| 
       127 
     | 
    
         
            -
             
     | 
| 
       128 
     | 
    
         
            -
             
     | 
| 
       129 
     | 
    
         
            -
             
     | 
| 
       130 
     | 
    
         
            -
             
     | 
| 
      
 231 
     | 
    
         
            +
                if _is_npu:
         
     | 
| 
      
 232 
     | 
    
         
            +
                    (
         
     | 
| 
      
 233 
     | 
    
         
            +
                        positions,
         
     | 
| 
      
 234 
     | 
    
         
            +
                        retrive_index,
         
     | 
| 
      
 235 
     | 
    
         
            +
                        retrive_next_token,
         
     | 
| 
      
 236 
     | 
    
         
            +
                        retrive_next_sibling,
         
     | 
| 
      
 237 
     | 
    
         
            +
                        tree_mask,
         
     | 
| 
      
 238 
     | 
    
         
            +
                    ) = build_tree_efficient_native(
         
     | 
| 
      
 239 
     | 
    
         
            +
                        parent_list,
         
     | 
| 
      
 240 
     | 
    
         
            +
                        top_scores_index,
         
     | 
| 
      
 241 
     | 
    
         
            +
                        seq_lens,
         
     | 
| 
      
 242 
     | 
    
         
            +
                        tree_mask,
         
     | 
| 
      
 243 
     | 
    
         
            +
                        retrive_index,
         
     | 
| 
      
 244 
     | 
    
         
            +
                        retrive_next_token,
         
     | 
| 
      
 245 
     | 
    
         
            +
                        retrive_next_sibling,
         
     | 
| 
      
 246 
     | 
    
         
            +
                        topk,
         
     | 
| 
      
 247 
     | 
    
         
            +
                        num_verify_tokens,
         
     | 
| 
      
 248 
     | 
    
         
            +
                        tree_mask_mode,
         
     | 
| 
      
 249 
     | 
    
         
            +
                        bs,
         
     | 
| 
      
 250 
     | 
    
         
            +
                    )
         
     | 
| 
      
 251 
     | 
    
         
            +
                else:
         
     | 
| 
      
 252 
     | 
    
         
            +
                    sgl_build_tree_kernel_efficient(
         
     | 
| 
      
 253 
     | 
    
         
            +
                        parent_list,
         
     | 
| 
      
 254 
     | 
    
         
            +
                        top_scores_index,
         
     | 
| 
      
 255 
     | 
    
         
            +
                        seq_lens,
         
     | 
| 
      
 256 
     | 
    
         
            +
                        tree_mask,
         
     | 
| 
      
 257 
     | 
    
         
            +
                        positions,
         
     | 
| 
      
 258 
     | 
    
         
            +
                        retrive_index,
         
     | 
| 
      
 259 
     | 
    
         
            +
                        retrive_next_token,
         
     | 
| 
      
 260 
     | 
    
         
            +
                        retrive_next_sibling,
         
     | 
| 
      
 261 
     | 
    
         
            +
                        topk,
         
     | 
| 
      
 262 
     | 
    
         
            +
                        spec_steps,
         
     | 
| 
      
 263 
     | 
    
         
            +
                        num_verify_tokens,
         
     | 
| 
      
 264 
     | 
    
         
            +
                        tree_mask_mode,
         
     | 
| 
      
 265 
     | 
    
         
            +
                    )
         
     | 
| 
       131 
266 
     | 
    
         
             
                return (
         
     | 
| 
       132 
267 
     | 
    
         
             
                    tree_mask,
         
     | 
| 
       133 
268 
     | 
    
         
             
                    positions,
         
     | 
| 
         @@ -136,3 +271,113 @@ def build_tree_kernel_efficient( 
     | 
|
| 
       136 
271 
     | 
    
         
             
                    retrive_next_sibling,
         
     | 
| 
       137 
272 
     | 
    
         
             
                    draft_tokens,
         
     | 
| 
       138 
273 
     | 
    
         
             
                )
         
     | 
| 
      
 274 
     | 
    
         
            +
             
     | 
| 
      
 275 
     | 
    
         
            +
             
     | 
| 
      
 276 
     | 
    
         
            +
            def verify_tree_greedy_native(
         
     | 
| 
      
 277 
     | 
    
         
            +
                predicts: torch.Tensor,
         
     | 
| 
      
 278 
     | 
    
         
            +
                accept_index: torch.Tensor,
         
     | 
| 
      
 279 
     | 
    
         
            +
                accept_token_num: torch.Tensor,
         
     | 
| 
      
 280 
     | 
    
         
            +
                candidates: torch.Tensor,
         
     | 
| 
      
 281 
     | 
    
         
            +
                retrive_index: torch.Tensor,
         
     | 
| 
      
 282 
     | 
    
         
            +
                retrive_next_token: torch.Tensor,
         
     | 
| 
      
 283 
     | 
    
         
            +
                retrive_next_sibling: torch.Tensor,
         
     | 
| 
      
 284 
     | 
    
         
            +
                target_predict: torch.Tensor,
         
     | 
| 
      
 285 
     | 
    
         
            +
                topk: int = -1,
         
     | 
| 
      
 286 
     | 
    
         
            +
            ):
         
     | 
| 
      
 287 
     | 
    
         
            +
                batch_size, num_draft_tokens = candidates.shape
         
     | 
| 
      
 288 
     | 
    
         
            +
             
     | 
| 
      
 289 
     | 
    
         
            +
                # Optimized common case for performance.
         
     | 
| 
      
 290 
     | 
    
         
            +
                if num_draft_tokens == 2 and accept_index.shape[1] == 2 and topk == 1:
         
     | 
| 
      
 291 
     | 
    
         
            +
                    comparison_result = candidates[:, 1] == target_predict[:, 0]
         
     | 
| 
      
 292 
     | 
    
         
            +
             
     | 
| 
      
 293 
     | 
    
         
            +
                    predicts = target_predict.flatten()
         
     | 
| 
      
 294 
     | 
    
         
            +
             
     | 
| 
      
 295 
     | 
    
         
            +
                    accept_index = torch.arange(
         
     | 
| 
      
 296 
     | 
    
         
            +
                        0, num_draft_tokens * batch_size, device=candidates.device, dtype=torch.long
         
     | 
| 
      
 297 
     | 
    
         
            +
                    ).reshape(batch_size, num_draft_tokens)
         
     | 
| 
      
 298 
     | 
    
         
            +
                    comparison_result = comparison_result.to(torch.int64)
         
     | 
| 
      
 299 
     | 
    
         
            +
                    accept_index_mask = accept_index[:, 1] * comparison_result
         
     | 
| 
      
 300 
     | 
    
         
            +
                    accept_index[:, 1] = accept_index_mask - (1 - comparison_result)
         
     | 
| 
      
 301 
     | 
    
         
            +
             
     | 
| 
      
 302 
     | 
    
         
            +
                    accept_token_num = comparison_result.int()
         
     | 
| 
      
 303 
     | 
    
         
            +
                    return predicts, accept_index, accept_token_num
         
     | 
| 
      
 304 
     | 
    
         
            +
             
     | 
| 
      
 305 
     | 
    
         
            +
                # BFS
         
     | 
| 
      
 306 
     | 
    
         
            +
                for bx in range(batch_size):
         
     | 
| 
      
 307 
     | 
    
         
            +
                    cur_candidates = candidates[bx]
         
     | 
| 
      
 308 
     | 
    
         
            +
                    cur_retrive_index = retrive_index[bx]
         
     | 
| 
      
 309 
     | 
    
         
            +
                    cur_next_token = retrive_next_token[bx]
         
     | 
| 
      
 310 
     | 
    
         
            +
                    cur_next_sibling = retrive_next_sibling[bx]
         
     | 
| 
      
 311 
     | 
    
         
            +
                    cur_target = target_predict[bx]
         
     | 
| 
      
 312 
     | 
    
         
            +
             
     | 
| 
      
 313 
     | 
    
         
            +
                    last_accepted_idx = cur_retrive_index[0]
         
     | 
| 
      
 314 
     | 
    
         
            +
                    accept_index[bx, 0] = last_accepted_idx
         
     | 
| 
      
 315 
     | 
    
         
            +
                    num_accepted = 0
         
     | 
| 
      
 316 
     | 
    
         
            +
                    cur_node = 0
         
     | 
| 
      
 317 
     | 
    
         
            +
             
     | 
| 
      
 318 
     | 
    
         
            +
                    for _ in range(1, num_draft_tokens):
         
     | 
| 
      
 319 
     | 
    
         
            +
                        cur_node = cur_next_token[cur_node]
         
     | 
| 
      
 320 
     | 
    
         
            +
                        found = False
         
     | 
| 
      
 321 
     | 
    
         
            +
                        while cur_node != -1:
         
     | 
| 
      
 322 
     | 
    
         
            +
                            draft_idx = cur_retrive_index[cur_node]
         
     | 
| 
      
 323 
     | 
    
         
            +
                            draft_token = cur_candidates[cur_node]
         
     | 
| 
      
 324 
     | 
    
         
            +
                            target_token = cur_target[last_accepted_idx - num_draft_tokens * bx]
         
     | 
| 
      
 325 
     | 
    
         
            +
             
     | 
| 
      
 326 
     | 
    
         
            +
                            if draft_token == target_token:
         
     | 
| 
      
 327 
     | 
    
         
            +
                                predicts[last_accepted_idx] = target_token
         
     | 
| 
      
 328 
     | 
    
         
            +
                                num_accepted += 1
         
     | 
| 
      
 329 
     | 
    
         
            +
                                accept_index[bx, num_accepted] = draft_idx
         
     | 
| 
      
 330 
     | 
    
         
            +
                                last_accepted_idx = draft_idx
         
     | 
| 
      
 331 
     | 
    
         
            +
                                found = True
         
     | 
| 
      
 332 
     | 
    
         
            +
                                break
         
     | 
| 
      
 333 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 334 
     | 
    
         
            +
                                cur_node = cur_next_sibling[cur_node]
         
     | 
| 
      
 335 
     | 
    
         
            +
                        if not found:
         
     | 
| 
      
 336 
     | 
    
         
            +
                            break
         
     | 
| 
      
 337 
     | 
    
         
            +
             
     | 
| 
      
 338 
     | 
    
         
            +
                    accept_token_num[bx] = num_accepted
         
     | 
| 
      
 339 
     | 
    
         
            +
                    predicts[last_accepted_idx] = cur_target[
         
     | 
| 
      
 340 
     | 
    
         
            +
                        last_accepted_idx - num_draft_tokens * bx
         
     | 
| 
      
 341 
     | 
    
         
            +
                    ]
         
     | 
| 
      
 342 
     | 
    
         
            +
                return predicts, accept_index, accept_token_num
         
     | 
| 
      
 343 
     | 
    
         
            +
             
     | 
| 
      
 344 
     | 
    
         
            +
             
     | 
| 
      
 345 
     | 
    
         
            +
            def verify_tree_greedy_func(
         
     | 
| 
      
 346 
     | 
    
         
            +
                predicts: torch.Tensor,
         
     | 
| 
      
 347 
     | 
    
         
            +
                accept_index: torch.Tensor,
         
     | 
| 
      
 348 
     | 
    
         
            +
                accept_token_num: torch.Tensor,
         
     | 
| 
      
 349 
     | 
    
         
            +
                candidates: torch.Tensor,
         
     | 
| 
      
 350 
     | 
    
         
            +
                retrive_index: torch.Tensor,
         
     | 
| 
      
 351 
     | 
    
         
            +
                retrive_next_token: torch.Tensor,
         
     | 
| 
      
 352 
     | 
    
         
            +
                retrive_next_sibling: torch.Tensor,
         
     | 
| 
      
 353 
     | 
    
         
            +
                target_predict: torch.Tensor,
         
     | 
| 
      
 354 
     | 
    
         
            +
                topk: int = -1,
         
     | 
| 
      
 355 
     | 
    
         
            +
            ):
         
     | 
| 
      
 356 
     | 
    
         
            +
                if _is_cuda or _is_hip:
         
     | 
| 
      
 357 
     | 
    
         
            +
                    from sgl_kernel import verify_tree_greedy
         
     | 
| 
      
 358 
     | 
    
         
            +
             
     | 
| 
      
 359 
     | 
    
         
            +
                    verify_tree_greedy(
         
     | 
| 
      
 360 
     | 
    
         
            +
                        predicts=predicts,  # mutable
         
     | 
| 
      
 361 
     | 
    
         
            +
                        accept_index=accept_index,  # mutable
         
     | 
| 
      
 362 
     | 
    
         
            +
                        accept_token_num=accept_token_num,  # mutable
         
     | 
| 
      
 363 
     | 
    
         
            +
                        candidates=candidates,
         
     | 
| 
      
 364 
     | 
    
         
            +
                        retrive_index=retrive_index,
         
     | 
| 
      
 365 
     | 
    
         
            +
                        retrive_next_token=retrive_next_token,
         
     | 
| 
      
 366 
     | 
    
         
            +
                        retrive_next_sibling=retrive_next_sibling,
         
     | 
| 
      
 367 
     | 
    
         
            +
                        target_predict=target_predict,
         
     | 
| 
      
 368 
     | 
    
         
            +
                    )
         
     | 
| 
      
 369 
     | 
    
         
            +
             
     | 
| 
      
 370 
     | 
    
         
            +
                elif _is_npu:
         
     | 
| 
      
 371 
     | 
    
         
            +
                    predicts, accept_index, accept_token_num = verify_tree_greedy_native(
         
     | 
| 
      
 372 
     | 
    
         
            +
                        predicts=predicts,  # mutable
         
     | 
| 
      
 373 
     | 
    
         
            +
                        accept_index=accept_index,  # mutable
         
     | 
| 
      
 374 
     | 
    
         
            +
                        accept_token_num=accept_token_num,  # mutable
         
     | 
| 
      
 375 
     | 
    
         
            +
                        candidates=candidates,
         
     | 
| 
      
 376 
     | 
    
         
            +
                        retrive_index=retrive_index,
         
     | 
| 
      
 377 
     | 
    
         
            +
                        retrive_next_token=retrive_next_token,
         
     | 
| 
      
 378 
     | 
    
         
            +
                        retrive_next_sibling=retrive_next_sibling,
         
     | 
| 
      
 379 
     | 
    
         
            +
                        target_predict=target_predict,
         
     | 
| 
      
 380 
     | 
    
         
            +
                        topk=topk,
         
     | 
| 
      
 381 
     | 
    
         
            +
                    )
         
     | 
| 
      
 382 
     | 
    
         
            +
             
     | 
| 
      
 383 
     | 
    
         
            +
                return predicts, accept_index, accept_token_num
         
     | 
| 
         @@ -5,6 +5,7 @@ from typing import List, Optional, Tuple 
     | 
|
| 
       5 
5 
     | 
    
         
             
            import torch
         
     | 
| 
       6 
6 
     | 
    
         | 
| 
       7 
7 
     | 
    
         
             
            from sglang.srt.distributed import get_tp_group
         
     | 
| 
      
 8 
     | 
    
         
            +
            from sglang.srt.layers.dp_attention import get_attention_tp_group
         
     | 
| 
       8 
9 
     | 
    
         
             
            from sglang.srt.layers.logits_processor import LogitsProcessorOutput
         
     | 
| 
       9 
10 
     | 
    
         
             
            from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
         
     | 
| 
       10 
11 
     | 
    
         
             
            from sglang.srt.managers.schedule_batch import ScheduleBatch
         
     | 
| 
         @@ -52,9 +53,12 @@ from sglang.srt.utils import ( 
     | 
|
| 
       52 
53 
     | 
    
         
             
                get_available_gpu_memory,
         
     | 
| 
       53 
54 
     | 
    
         
             
                get_bool_env_var,
         
     | 
| 
       54 
55 
     | 
    
         
             
                is_cuda,
         
     | 
| 
      
 56 
     | 
    
         
            +
                is_npu,
         
     | 
| 
       55 
57 
     | 
    
         
             
                next_power_of_2,
         
     | 
| 
       56 
58 
     | 
    
         
             
            )
         
     | 
| 
       57 
59 
     | 
    
         | 
| 
      
 60 
     | 
    
         
            +
            _is_npu = is_npu()
         
     | 
| 
      
 61 
     | 
    
         
            +
             
     | 
| 
       58 
62 
     | 
    
         
             
            if is_cuda():
         
     | 
| 
       59 
63 
     | 
    
         
             
                from sgl_kernel import segment_packbits  # noqa: F401
         
     | 
| 
       60 
64 
     | 
    
         | 
| 
         @@ -117,7 +121,11 @@ class EAGLEWorker(TpModelWorker): 
     | 
|
| 
       117 
121 
     | 
    
         
             
                        self.hot_token_id = None
         
     | 
| 
       118 
122 
     | 
    
         | 
| 
       119 
123 
     | 
    
         
             
                    # Init draft worker
         
     | 
| 
       120 
     | 
    
         
            -
                     
     | 
| 
      
 124 
     | 
    
         
            +
                    if server_args.enable_dp_attention and self.speculative_algorithm.is_eagle3():
         
     | 
| 
      
 125 
     | 
    
         
            +
                        ctx = draft_tp_context(get_attention_tp_group())
         
     | 
| 
      
 126 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 127 
     | 
    
         
            +
                        ctx = empty_context()
         
     | 
| 
      
 128 
     | 
    
         
            +
                    with ctx:
         
     | 
| 
       121 
129 
     | 
    
         
             
                        super().__init__(
         
     | 
| 
       122 
130 
     | 
    
         
             
                            server_args=server_args,
         
     | 
| 
       123 
131 
     | 
    
         
             
                            gpu_id=gpu_id,
         
     | 
| 
         @@ -200,7 +208,7 @@ class EAGLEWorker(TpModelWorker): 
     | 
|
| 
       200 
208 
     | 
    
         
             
                    self.cuda_graph_runner = None
         
     | 
| 
       201 
209 
     | 
    
         
             
                    self.cuda_graph_runner_for_draft_extend = None
         
     | 
| 
       202 
210 
     | 
    
         | 
| 
       203 
     | 
    
         
            -
                    if self.server_args.disable_cuda_graph:
         
     | 
| 
      
 211 
     | 
    
         
            +
                    if self.server_args.disable_cuda_graph or _is_npu:
         
     | 
| 
       204 
212 
     | 
    
         
             
                        return
         
     | 
| 
       205 
213 
     | 
    
         | 
| 
       206 
214 
     | 
    
         
             
                    # Capture draft
         
     | 
| 
         @@ -940,7 +948,7 @@ class EAGLEWorker(TpModelWorker): 
     | 
|
| 
       940 
948 
     | 
    
         
             
                    draft_input.hidden_states = logits_output.hidden_states
         
     | 
| 
       941 
949 
     | 
    
         | 
| 
       942 
950 
     | 
    
         | 
| 
       943 
     | 
    
         
            -
            @torch.compile(dynamic=True)
         
     | 
| 
      
 951 
     | 
    
         
            +
            @torch.compile(dynamic=True, disable=_is_npu)
         
     | 
| 
       944 
952 
     | 
    
         
             
            def get_last_loc_large_page_size_top_k_1(
         
     | 
| 
       945 
953 
     | 
    
         
             
                req_to_token: torch.Tensor,
         
     | 
| 
       946 
954 
     | 
    
         
             
                req_pool_indices: torch.Tensor,
         
     | 
| 
         @@ -4,7 +4,6 @@ import time 
     | 
|
| 
       4 
4 
     | 
    
         
             
            from typing import List, Optional, Tuple
         
     | 
| 
       5 
5 
     | 
    
         | 
| 
       6 
6 
     | 
    
         
             
            import torch
         
     | 
| 
       7 
     | 
    
         
            -
            from torch.cuda import Stream as CudaStream
         
     | 
| 
       8 
7 
     | 
    
         | 
| 
       9 
8 
     | 
    
         
             
            from sglang.srt.environ import envs
         
     | 
| 
       10 
9 
     | 
    
         
             
            from sglang.srt.managers.schedule_batch import ModelWorkerBatch
         
     | 
| 
         @@ -38,18 +37,21 @@ from sglang.srt.utils.common import ( 
     | 
|
| 
       38 
37 
     | 
    
         
             
                empty_context,
         
     | 
| 
       39 
38 
     | 
    
         
             
                fast_topk,
         
     | 
| 
       40 
39 
     | 
    
         
             
                get_available_gpu_memory,
         
     | 
| 
      
 40 
     | 
    
         
            +
                is_npu,
         
     | 
| 
       41 
41 
     | 
    
         
             
                next_power_of_2,
         
     | 
| 
       42 
42 
     | 
    
         
             
            )
         
     | 
| 
       43 
43 
     | 
    
         | 
| 
      
 44 
     | 
    
         
            +
            _is_npu = is_npu()
         
     | 
| 
      
 45 
     | 
    
         
            +
             
     | 
| 
       44 
46 
     | 
    
         
             
            logger = logging.getLogger(__name__)
         
     | 
| 
       45 
47 
     | 
    
         | 
| 
       46 
48 
     | 
    
         | 
| 
       47 
49 
     | 
    
         
             
            def _get_plan_stream(
         
     | 
| 
       48 
50 
     | 
    
         
             
                device: str,
         
     | 
| 
       49 
     | 
    
         
            -
            ) -> Tuple[ 
     | 
| 
      
 51 
     | 
    
         
            +
            ) -> Tuple[any, contextlib.AbstractContextManager]:
         
     | 
| 
       50 
52 
     | 
    
         
             
                if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
         
     | 
| 
       51 
     | 
    
         
            -
                    plan_stream 
     | 
| 
       52 
     | 
    
         
            -
                    plan_stream_ctx = torch. 
     | 
| 
      
 53 
     | 
    
         
            +
                    plan_stream = torch.get_device_module(device).Stream()
         
     | 
| 
      
 54 
     | 
    
         
            +
                    plan_stream_ctx = torch.get_device_module(device).stream(plan_stream)
         
     | 
| 
       53 
55 
     | 
    
         
             
                    return plan_stream, plan_stream_ctx
         
     | 
| 
       54 
56 
     | 
    
         
             
                else:
         
     | 
| 
       55 
57 
     | 
    
         
             
                    return None, contextlib.nullcontext()
         
     | 
| 
         @@ -206,7 +208,7 @@ class EagleDraftWorker(BaseDraftWorker): 
     | 
|
| 
       206 
208 
     | 
    
         
             
                    self.cuda_graph_runner = None
         
     | 
| 
       207 
209 
     | 
    
         
             
                    self.cuda_graph_runner_for_draft_extend = None
         
     | 
| 
       208 
210 
     | 
    
         | 
| 
       209 
     | 
    
         
            -
                    if self.server_args.disable_cuda_graph:
         
     | 
| 
      
 211 
     | 
    
         
            +
                    if self.server_args.disable_cuda_graph or _is_npu:
         
     | 
| 
       210 
212 
     | 
    
         
             
                        return
         
     | 
| 
       211 
213 
     | 
    
         | 
| 
       212 
214 
     | 
    
         
             
                    # Capture draft
         
     | 
| 
         @@ -456,7 +458,9 @@ class EagleDraftWorker(BaseDraftWorker): 
     | 
|
| 
       456 
458 
     | 
    
         
             
                        )
         
     | 
| 
       457 
459 
     | 
    
         | 
| 
       458 
460 
     | 
    
         
             
                    if self.plan_stream:
         
     | 
| 
       459 
     | 
    
         
            -
                        torch. 
     | 
| 
      
 461 
     | 
    
         
            +
                        torch.get_device_module(self.device).current_stream().wait_stream(
         
     | 
| 
      
 462 
     | 
    
         
            +
                            self.plan_stream
         
     | 
| 
      
 463 
     | 
    
         
            +
                        )
         
     | 
| 
       460 
464 
     | 
    
         | 
| 
       461 
465 
     | 
    
         
             
                    # Run draft extend batch in the main compute stream
         
     | 
| 
       462 
466 
     | 
    
         
             
                    draft_logits_output = self.draft_runner.model.forward(
         
     | 
| 
         @@ -577,7 +581,9 @@ class EAGLEWorkerV2(BaseSpecWorker): 
     | 
|
| 
       577 
581 
     | 
    
         
             
                    # Since batch.seq_lens is allocated in another stream, we need
         
     | 
| 
       578 
582 
     | 
    
         
             
                    # record_stream() to prevent pytorch gc and reuse the gpu memory
         
     | 
| 
       579 
583 
     | 
    
         
             
                    # while forward_stream is still running.
         
     | 
| 
       580 
     | 
    
         
            -
                    batch.seq_lens.record_stream( 
     | 
| 
      
 584 
     | 
    
         
            +
                    batch.seq_lens.record_stream(
         
     | 
| 
      
 585 
     | 
    
         
            +
                        torch.get_device_module(self.device).current_stream()
         
     | 
| 
      
 586 
     | 
    
         
            +
                    )
         
     | 
| 
       581 
587 
     | 
    
         | 
| 
       582 
588 
     | 
    
         
             
                    # Parse args
         
     | 
| 
       583 
589 
     | 
    
         
             
                    verify_input: EagleVerifyInput = batch.spec_info
         
     | 
| 
         @@ -596,7 +602,7 @@ class EAGLEWorkerV2(BaseSpecWorker): 
     | 
|
| 
       596 
602 
     | 
    
         | 
| 
       597 
603 
     | 
    
         
             
                    # Correct some buffers due to the overlap plan
         
     | 
| 
       598 
604 
     | 
    
         
             
                    if self.plan_stream:
         
     | 
| 
       599 
     | 
    
         
            -
                        torch. 
     | 
| 
      
 605 
     | 
    
         
            +
                        torch.get_device_module().current_stream().wait_stream(self.plan_stream)
         
     | 
| 
       600 
606 
     | 
    
         | 
| 
       601 
607 
     | 
    
         
             
                        # Some values such as custom_mask and position depend on the output of draft,
         
     | 
| 
       602 
608 
     | 
    
         
             
                        # so the previous plan step used the wrong values. Here, we need to run the related
         
     | 
| 
         @@ -628,7 +634,7 @@ class EAGLEWorkerV2(BaseSpecWorker): 
     | 
|
| 
       628 
634 
     | 
    
         
             
                        accept_index,
         
     | 
| 
       629 
635 
     | 
    
         
             
                    ) = verify_input.sample(batch, logits_output)
         
     | 
| 
       630 
636 
     | 
    
         
             
                    new_seq_lens = batch.seq_lens + accept_length
         
     | 
| 
       631 
     | 
    
         
            -
                    verify_done = torch. 
     | 
| 
      
 637 
     | 
    
         
            +
                    verify_done = torch.get_device_module(self.device).Event()
         
     | 
| 
       632 
638 
     | 
    
         
             
                    verify_done.record()
         
     | 
| 
       633 
639 
     | 
    
         | 
| 
       634 
640 
     | 
    
         
             
                    all_verified_id = predict[accept_index]
         
     |