sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
| 
         @@ -24,12 +24,13 @@ from sglang.srt.speculative.eagle_info_v2 import ( 
     | 
|
| 
       24 
24 
     | 
    
         
             
                EagleDraftInputV2Mixin,
         
     | 
| 
       25 
25 
     | 
    
         
             
                EagleVerifyInputV2Mixin,
         
     | 
| 
       26 
26 
     | 
    
         
             
            )
         
     | 
| 
      
 27 
     | 
    
         
            +
            from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
         
     | 
| 
       27 
28 
     | 
    
         
             
            from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
         
     | 
| 
       28 
29 
     | 
    
         
             
            from sglang.srt.speculative.spec_utils import (
         
     | 
| 
       29 
30 
     | 
    
         
             
                SIMULATE_ACC_LEN,
         
     | 
| 
       30 
31 
     | 
    
         
             
                TREE_SPEC_KERNEL_AVAILABLE,
         
     | 
| 
       31 
32 
     | 
    
         
             
                align_evict_mask_to_page_size,
         
     | 
| 
       32 
     | 
    
         
            -
                 
     | 
| 
      
 33 
     | 
    
         
            +
                assign_req_to_token_pool_func,
         
     | 
| 
       33 
34 
     | 
    
         
             
                create_accept_length_filter,
         
     | 
| 
       34 
35 
     | 
    
         
             
                create_extend_after_decode_spec_info,
         
     | 
| 
       35 
36 
     | 
    
         
             
                filter_finished_cache_loc_kernel,
         
     | 
| 
         @@ -37,17 +38,16 @@ from sglang.srt.speculative.spec_utils import ( 
     | 
|
| 
       37 
38 
     | 
    
         
             
                get_src_tgt_cache_loc,
         
     | 
| 
       38 
39 
     | 
    
         
             
                get_target_cache_loc,
         
     | 
| 
       39 
40 
     | 
    
         
             
            )
         
     | 
| 
       40 
     | 
    
         
            -
            from sglang.srt.utils import is_cuda,  
     | 
| 
      
 41 
     | 
    
         
            +
            from sglang.srt.utils import is_cuda, is_npu, next_power_of_2
         
     | 
| 
      
 42 
     | 
    
         
            +
             
     | 
| 
      
 43 
     | 
    
         
            +
            _is_npu = is_npu()
         
     | 
| 
       41 
44 
     | 
    
         | 
| 
       42 
45 
     | 
    
         
             
            if is_cuda():
         
     | 
| 
       43 
46 
     | 
    
         
             
                from sgl_kernel import (
         
     | 
| 
       44 
47 
     | 
    
         
             
                    top_k_renorm_prob,
         
     | 
| 
       45 
48 
     | 
    
         
             
                    top_p_renorm_prob,
         
     | 
| 
       46 
49 
     | 
    
         
             
                    tree_speculative_sampling_target_only,
         
     | 
| 
       47 
     | 
    
         
            -
                    verify_tree_greedy,
         
     | 
| 
       48 
50 
     | 
    
         
             
                )
         
     | 
| 
       49 
     | 
    
         
            -
            elif is_hip():
         
     | 
| 
       50 
     | 
    
         
            -
                from sgl_kernel import verify_tree_greedy
         
     | 
| 
       51 
51 
     | 
    
         | 
| 
       52 
52 
     | 
    
         
             
            logger = logging.getLogger(__name__)
         
     | 
| 
       53 
53 
     | 
    
         | 
| 
         @@ -77,18 +77,22 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       77 
77 
     | 
    
         | 
| 
       78 
78 
     | 
    
         
             
                @classmethod
         
     | 
| 
       79 
79 
     | 
    
         
             
                def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
         
     | 
| 
      
 80 
     | 
    
         
            +
                    if not _is_npu:
         
     | 
| 
      
 81 
     | 
    
         
            +
                        device = "cuda"
         
     | 
| 
      
 82 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 83 
     | 
    
         
            +
                        device = "npu"
         
     | 
| 
       80 
84 
     | 
    
         
             
                    return cls(
         
     | 
| 
       81 
     | 
    
         
            -
                        draft_token=torch.empty((0,), dtype=torch.long, device= 
     | 
| 
       82 
     | 
    
         
            -
                        custom_mask=torch.full((0,), True, dtype=torch.bool, device= 
     | 
| 
       83 
     | 
    
         
            -
                        positions=torch.empty((0,), dtype=torch.int64, device= 
     | 
| 
      
 85 
     | 
    
         
            +
                        draft_token=torch.empty((0,), dtype=torch.long, device=device),
         
     | 
| 
      
 86 
     | 
    
         
            +
                        custom_mask=torch.full((0,), True, dtype=torch.bool, device=device),
         
     | 
| 
      
 87 
     | 
    
         
            +
                        positions=torch.empty((0,), dtype=torch.int64, device=device),
         
     | 
| 
       84 
88 
     | 
    
         
             
                        retrive_index=torch.full(
         
     | 
| 
       85 
     | 
    
         
            -
                            (0, num_verify_tokens), -1, dtype=torch.long, device= 
     | 
| 
      
 89 
     | 
    
         
            +
                            (0, num_verify_tokens), -1, dtype=torch.long, device=device
         
     | 
| 
       86 
90 
     | 
    
         
             
                        ),
         
     | 
| 
       87 
91 
     | 
    
         
             
                        retrive_next_token=torch.full(
         
     | 
| 
       88 
     | 
    
         
            -
                            (0, num_verify_tokens), -1, dtype=torch.long, device= 
     | 
| 
      
 92 
     | 
    
         
            +
                            (0, num_verify_tokens), -1, dtype=torch.long, device=device
         
     | 
| 
       89 
93 
     | 
    
         
             
                        ),
         
     | 
| 
       90 
94 
     | 
    
         
             
                        retrive_next_sibling=torch.full(
         
     | 
| 
       91 
     | 
    
         
            -
                            (0, num_verify_tokens), -1, dtype=torch.long, device= 
     | 
| 
      
 95 
     | 
    
         
            +
                            (0, num_verify_tokens), -1, dtype=torch.long, device=device
         
     | 
| 
       92 
96 
     | 
    
         
             
                        ),
         
     | 
| 
       93 
97 
     | 
    
         
             
                        retrive_cum_len=None,
         
     | 
| 
       94 
98 
     | 
    
         
             
                        topk=topk,
         
     | 
| 
         @@ -134,14 +138,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       134 
138 
     | 
    
         
             
                        self.last_loc = last_loc
         
     | 
| 
       135 
139 
     | 
    
         | 
| 
       136 
140 
     | 
    
         
             
                    bs = batch.batch_size()
         
     | 
| 
       137 
     | 
    
         
            -
                     
     | 
| 
      
 141 
     | 
    
         
            +
                    assign_req_to_token_pool_func(
         
     | 
| 
       138 
142 
     | 
    
         
             
                        batch.req_pool_indices,
         
     | 
| 
       139 
143 
     | 
    
         
             
                        batch.req_to_token_pool.req_to_token,
         
     | 
| 
       140 
144 
     | 
    
         
             
                        batch.seq_lens,
         
     | 
| 
       141 
145 
     | 
    
         
             
                        end_offset,
         
     | 
| 
       142 
146 
     | 
    
         
             
                        batch.out_cache_loc,
         
     | 
| 
       143 
     | 
    
         
            -
                         
     | 
| 
       144 
     | 
    
         
            -
                        next_power_of_2(bs),
         
     | 
| 
      
 147 
     | 
    
         
            +
                        bs,
         
     | 
| 
       145 
148 
     | 
    
         
             
                    )
         
     | 
| 
       146 
149 
     | 
    
         | 
| 
       147 
150 
     | 
    
         
             
                def generate_attn_arg_prefill(
         
     | 
| 
         @@ -151,16 +154,17 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       151 
154 
     | 
    
         
             
                    paged_kernel_lens_sum: int,
         
     | 
| 
       152 
155 
     | 
    
         
             
                    req_to_token: torch.Tensor,
         
     | 
| 
       153 
156 
     | 
    
         
             
                ):
         
     | 
| 
      
 157 
     | 
    
         
            +
                    device = req_pool_indices.device
         
     | 
| 
       154 
158 
     | 
    
         
             
                    batch_size = len(req_pool_indices)
         
     | 
| 
       155 
159 
     | 
    
         
             
                    qo_indptr = torch.arange(
         
     | 
| 
       156 
160 
     | 
    
         
             
                        0,
         
     | 
| 
       157 
161 
     | 
    
         
             
                        (1 + batch_size) * self.draft_token_num,
         
     | 
| 
       158 
162 
     | 
    
         
             
                        step=self.draft_token_num,
         
     | 
| 
       159 
163 
     | 
    
         
             
                        dtype=torch.int32,
         
     | 
| 
       160 
     | 
    
         
            -
                        device= 
     | 
| 
      
 164 
     | 
    
         
            +
                        device=device,
         
     | 
| 
       161 
165 
     | 
    
         
             
                    )
         
     | 
| 
       162 
166 
     | 
    
         
             
                    cum_kv_seq_len = torch.zeros(
         
     | 
| 
       163 
     | 
    
         
            -
                        (batch_size + 1,), dtype=torch.int32, device= 
     | 
| 
      
 167 
     | 
    
         
            +
                        (batch_size + 1,), dtype=torch.int32, device=device
         
     | 
| 
       164 
168 
     | 
    
         
             
                    )
         
     | 
| 
       165 
169 
     | 
    
         | 
| 
       166 
170 
     | 
    
         
             
                    paged_kernel_lens = paged_kernel_lens + self.draft_token_num
         
     | 
| 
         @@ -169,7 +173,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       169 
173 
     | 
    
         
             
                    kv_indices = torch.empty(
         
     | 
| 
       170 
174 
     | 
    
         
             
                        paged_kernel_lens_sum + self.draft_token_num * batch_size,
         
     | 
| 
       171 
175 
     | 
    
         
             
                        dtype=torch.int32,
         
     | 
| 
       172 
     | 
    
         
            -
                        device= 
     | 
| 
      
 176 
     | 
    
         
            +
                        device=device,
         
     | 
| 
       173 
177 
     | 
    
         
             
                    )
         
     | 
| 
       174 
178 
     | 
    
         
             
                    create_flashinfer_kv_indices_triton[(batch_size,)](
         
     | 
| 
       175 
179 
     | 
    
         
             
                        req_to_token,
         
     | 
| 
         @@ -226,11 +230,11 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       226 
230 
     | 
    
         | 
| 
       227 
231 
     | 
    
         
             
                    predict_shape = list(logits_output.next_token_logits.shape)[:-1]
         
     | 
| 
       228 
232 
     | 
    
         
             
                    predict_shape[-1] += 1
         
     | 
| 
       229 
     | 
    
         
            -
                    predict = torch.empty(predict_shape, dtype=torch.int32, device= 
     | 
| 
      
 233 
     | 
    
         
            +
                    predict = torch.empty(predict_shape, dtype=torch.int32, device=batch.device)
         
     | 
| 
       230 
234 
     | 
    
         
             
                    accept_index = torch.full(
         
     | 
| 
       231 
     | 
    
         
            -
                        (bs, self.spec_steps + 1), -1, dtype=torch.int32, device= 
     | 
| 
      
 235 
     | 
    
         
            +
                        (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=batch.device
         
     | 
| 
       232 
236 
     | 
    
         
             
                    )
         
     | 
| 
       233 
     | 
    
         
            -
                    accept_length = torch.empty((bs,), dtype=torch.int32, device= 
     | 
| 
      
 237 
     | 
    
         
            +
                    accept_length = torch.empty((bs,), dtype=torch.int32, device=batch.device)
         
     | 
| 
       234 
238 
     | 
    
         | 
| 
       235 
239 
     | 
    
         
             
                    if bs != len(sampling_info):
         
     | 
| 
       236 
240 
     | 
    
         
             
                        sampling_info = copy.deepcopy(sampling_info)
         
     | 
| 
         @@ -254,7 +258,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       254 
258 
     | 
    
         
             
                        linear_penalty = torch.zeros(
         
     | 
| 
       255 
259 
     | 
    
         
             
                            (bs, logits_output.next_token_logits.shape[1]),
         
     | 
| 
       256 
260 
     | 
    
         
             
                            dtype=torch.float32,
         
     | 
| 
       257 
     | 
    
         
            -
                            device= 
     | 
| 
      
 261 
     | 
    
         
            +
                            device=batch.device,
         
     | 
| 
       258 
262 
     | 
    
         
             
                        )
         
     | 
| 
       259 
263 
     | 
    
         
             
                        sampling_info.apply_logits_bias(linear_penalty)
         
     | 
| 
       260 
264 
     | 
    
         
             
                        logits_output.next_token_logits.add_(
         
     | 
| 
         @@ -276,11 +280,10 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       276 
280 
     | 
    
         
             
                            "Falling back to greedy verification."
         
     | 
| 
       277 
281 
     | 
    
         
             
                        )
         
     | 
| 
       278 
282 
     | 
    
         | 
| 
       279 
     | 
    
         
            -
                    if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
         
     | 
| 
      
 283 
     | 
    
         
            +
                    if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE or _is_npu:
         
     | 
| 
       280 
284 
     | 
    
         
             
                        target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
         
     | 
| 
       281 
285 
     | 
    
         
             
                        target_predict = target_predict.reshape(bs, self.draft_token_num)
         
     | 
| 
       282 
     | 
    
         
            -
             
     | 
| 
       283 
     | 
    
         
            -
                        verify_tree_greedy(
         
     | 
| 
      
 286 
     | 
    
         
            +
                        predict, accept_index, accept_length = verify_tree_greedy_func(
         
     | 
| 
       284 
287 
     | 
    
         
             
                            predicts=predict,  # mutable
         
     | 
| 
       285 
288 
     | 
    
         
             
                            accept_index=accept_index,  # mutable
         
     | 
| 
       286 
289 
     | 
    
         
             
                            accept_token_num=accept_length,  # mutable
         
     | 
| 
         @@ -289,7 +292,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       289 
292 
     | 
    
         
             
                            retrive_next_token=self.retrive_next_token,
         
     | 
| 
       290 
293 
     | 
    
         
             
                            retrive_next_sibling=self.retrive_next_sibling,
         
     | 
| 
       291 
294 
     | 
    
         
             
                            target_predict=target_predict,
         
     | 
| 
      
 295 
     | 
    
         
            +
                            topk=self.topk,
         
     | 
| 
       292 
296 
     | 
    
         
             
                        )
         
     | 
| 
      
 297 
     | 
    
         
            +
             
     | 
| 
       293 
298 
     | 
    
         
             
                    else:
         
     | 
| 
       294 
299 
     | 
    
         
             
                        # apply temperature and get target probs
         
     | 
| 
       295 
300 
     | 
    
         
             
                        expanded_temperature = torch.repeat_interleave(
         
     | 
| 
         @@ -315,14 +320,16 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       315 
320 
     | 
    
         
             
                        target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
         
     | 
| 
       316 
321 
     | 
    
         | 
| 
       317 
322 
     | 
    
         
             
                        draft_probs = torch.zeros(
         
     | 
| 
       318 
     | 
    
         
            -
                            target_probs.shape, dtype=torch.float32, device= 
     | 
| 
      
 323 
     | 
    
         
            +
                            target_probs.shape, dtype=torch.float32, device=batch.device
         
     | 
| 
       319 
324 
     | 
    
         
             
                        )
         
     | 
| 
       320 
325 
     | 
    
         | 
| 
       321 
326 
     | 
    
         
             
                        # coins for rejection sampling
         
     | 
| 
       322 
     | 
    
         
            -
                        coins = torch.rand_like( 
     | 
| 
      
 327 
     | 
    
         
            +
                        coins = torch.rand_like(
         
     | 
| 
      
 328 
     | 
    
         
            +
                            candidates, dtype=torch.float32, device=batch.device
         
     | 
| 
      
 329 
     | 
    
         
            +
                        )
         
     | 
| 
       323 
330 
     | 
    
         
             
                        # coins for final sampling
         
     | 
| 
       324 
331 
     | 
    
         
             
                        coins_for_final_sampling = torch.rand(
         
     | 
| 
       325 
     | 
    
         
            -
                            (bs,), dtype=torch.float32, device= 
     | 
| 
      
 332 
     | 
    
         
            +
                            (bs,), dtype=torch.float32, device=batch.device
         
     | 
| 
       326 
333 
     | 
    
         
             
                        )
         
     | 
| 
       327 
334 
     | 
    
         
             
                        tree_speculative_sampling_target_only(
         
     | 
| 
       328 
335 
     | 
    
         
             
                            predicts=predict,  # mutable
         
     | 
| 
         @@ -468,14 +475,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       468 
475 
     | 
    
         
             
                    if not has_finished:
         
     | 
| 
       469 
476 
     | 
    
         
             
                        if page_size == 1 or self.topk == 1:
         
     | 
| 
       470 
477 
     | 
    
         
             
                            batch.out_cache_loc = batch.out_cache_loc[accept_index]
         
     | 
| 
       471 
     | 
    
         
            -
                             
     | 
| 
      
 478 
     | 
    
         
            +
                            assign_req_to_token_pool_func(
         
     | 
| 
       472 
479 
     | 
    
         
             
                                batch.req_pool_indices,
         
     | 
| 
       473 
480 
     | 
    
         
             
                                batch.req_to_token_pool.req_to_token,
         
     | 
| 
       474 
481 
     | 
    
         
             
                                batch.seq_lens,
         
     | 
| 
       475 
482 
     | 
    
         
             
                                batch.seq_lens + accept_length + 1,
         
     | 
| 
       476 
483 
     | 
    
         
             
                                batch.out_cache_loc,
         
     | 
| 
       477 
     | 
    
         
            -
                                 
     | 
| 
       478 
     | 
    
         
            -
                                next_power_of_2(bs),
         
     | 
| 
      
 484 
     | 
    
         
            +
                                bs,
         
     | 
| 
       479 
485 
     | 
    
         
             
                            )
         
     | 
| 
       480 
486 
     | 
    
         
             
                        else:
         
     | 
| 
       481 
487 
     | 
    
         
             
                            batch.out_cache_loc = tgt_cache_loc
         
     | 
| 
         @@ -501,14 +507,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): 
     | 
|
| 
       501 
507 
     | 
    
         
             
                        )
         
     | 
| 
       502 
508 
     | 
    
         
             
                    else:
         
     | 
| 
       503 
509 
     | 
    
         
             
                        if page_size == 1 or self.topk == 1:
         
     | 
| 
       504 
     | 
    
         
            -
                             
     | 
| 
      
 510 
     | 
    
         
            +
                            assign_req_to_token_pool_func(
         
     | 
| 
       505 
511 
     | 
    
         
             
                                batch.req_pool_indices,
         
     | 
| 
       506 
512 
     | 
    
         
             
                                batch.req_to_token_pool.req_to_token,
         
     | 
| 
       507 
513 
     | 
    
         
             
                                batch.seq_lens,
         
     | 
| 
       508 
514 
     | 
    
         
             
                                batch.seq_lens + accept_length + 1,
         
     | 
| 
       509 
515 
     | 
    
         
             
                                batch.out_cache_loc[accept_index],
         
     | 
| 
       510 
     | 
    
         
            -
                                 
     | 
| 
       511 
     | 
    
         
            -
                                next_power_of_2(bs),
         
     | 
| 
      
 516 
     | 
    
         
            +
                                bs,
         
     | 
| 
       512 
517 
     | 
    
         
             
                            )
         
     | 
| 
       513 
518 
     | 
    
         
             
                            batch.seq_lens.add_(accept_length + 1)
         
     | 
| 
       514 
519 
     | 
    
         
             
                            batch.seq_lens_cpu.add_(accept_length_cpu + 1)
         
     | 
| 
         @@ -695,17 +700,18 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): 
     | 
|
| 
       695 
700 
     | 
    
         
             
                    paged_kernel_lens_sum: int,
         
     | 
| 
       696 
701 
     | 
    
         
             
                    req_to_token: torch.Tensor,
         
     | 
| 
       697 
702 
     | 
    
         
             
                ):
         
     | 
| 
      
 703 
     | 
    
         
            +
                    device = req_pool_indices.device
         
     | 
| 
       698 
704 
     | 
    
         
             
                    bs = self.accept_length.numel()
         
     | 
| 
       699 
     | 
    
         
            -
                    qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device= 
     | 
| 
      
 705 
     | 
    
         
            +
                    qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
         
     | 
| 
       700 
706 
     | 
    
         
             
                    qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
         
     | 
| 
       701 
     | 
    
         
            -
                    cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device= 
     | 
| 
      
 707 
     | 
    
         
            +
                    cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
         
     | 
| 
       702 
708 
     | 
    
         
             
                    cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         
     | 
| 
       703 
709 
     | 
    
         | 
| 
       704 
710 
     | 
    
         
             
                    if paged_kernel_lens_sum is None:
         
     | 
| 
       705 
711 
     | 
    
         
             
                        paged_kernel_lens_sum = cum_kv_seq_len[-1]
         
     | 
| 
       706 
712 
     | 
    
         | 
| 
       707 
713 
     | 
    
         
             
                    kv_indices = torch.empty(
         
     | 
| 
       708 
     | 
    
         
            -
                        paged_kernel_lens_sum, dtype=torch.int32, device= 
     | 
| 
      
 714 
     | 
    
         
            +
                        paged_kernel_lens_sum, dtype=torch.int32, device=device
         
     | 
| 
       709 
715 
     | 
    
         
             
                    )
         
     | 
| 
       710 
716 
     | 
    
         | 
| 
       711 
717 
     | 
    
         
             
                    create_flashinfer_kv_indices_triton[(bs,)](
         
     | 
| 
         @@ -23,11 +23,16 @@ from sglang.srt.model_executor.forward_batch_info import ( 
     | 
|
| 
       23 
23 
     | 
    
         
             
            )
         
     | 
| 
       24 
24 
     | 
    
         
             
            from sglang.srt.model_executor.model_runner import ModelRunner
         
     | 
| 
       25 
25 
     | 
    
         
             
            from sglang.srt.server_args import get_global_server_args
         
     | 
| 
      
 26 
     | 
    
         
            +
            from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
         
     | 
| 
       26 
27 
     | 
    
         
             
            from sglang.srt.speculative.spec_utils import (
         
     | 
| 
       27 
28 
     | 
    
         
             
                SIMULATE_ACC_LEN,
         
     | 
| 
       28 
29 
     | 
    
         
             
                generate_simulated_accept_index,
         
     | 
| 
       29 
30 
     | 
    
         
             
            )
         
     | 
| 
       30 
     | 
    
         
            -
            from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
         
     | 
| 
      
 31 
     | 
    
         
            +
            from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, is_npu, next_power_of_2
         
     | 
| 
      
 32 
     | 
    
         
            +
             
     | 
| 
      
 33 
     | 
    
         
            +
            _is_cuda = is_cuda()
         
     | 
| 
      
 34 
     | 
    
         
            +
            _is_hip = is_hip()
         
     | 
| 
      
 35 
     | 
    
         
            +
            _is_npu = is_npu()
         
     | 
| 
       31 
36 
     | 
    
         | 
| 
       32 
37 
     | 
    
         
             
            if TYPE_CHECKING:
         
     | 
| 
       33 
38 
     | 
    
         
             
                from sglang.srt.managers.tp_worker import TpModelWorker
         
     | 
| 
         @@ -41,11 +46,8 @@ if is_cuda(): 
     | 
|
| 
       41 
46 
     | 
    
         
             
                    top_k_renorm_prob,
         
     | 
| 
       42 
47 
     | 
    
         
             
                    top_p_renorm_prob,
         
     | 
| 
       43 
48 
     | 
    
         
             
                    tree_speculative_sampling_target_only,
         
     | 
| 
       44 
     | 
    
         
            -
                    verify_tree_greedy,
         
     | 
| 
       45 
49 
     | 
    
         
             
                )
         
     | 
| 
       46 
50 
     | 
    
         
             
                from sgl_kernel.top_k import fast_topk
         
     | 
| 
       47 
     | 
    
         
            -
            elif is_hip():
         
     | 
| 
       48 
     | 
    
         
            -
                from sgl_kernel import verify_tree_greedy
         
     | 
| 
       49 
51 
     | 
    
         | 
| 
       50 
52 
     | 
    
         | 
| 
       51 
53 
     | 
    
         
             
            @triton.jit
         
     | 
| 
         @@ -78,7 +80,7 @@ def assign_draft_cache_locs_page_size_1( 
     | 
|
| 
       78 
80 
     | 
    
         
             
            @dataclass
         
     | 
| 
       79 
81 
     | 
    
         
             
            class EagleDraftInputV2Mixin:
         
     | 
| 
       80 
82 
     | 
    
         
             
                def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
         
     | 
| 
       81 
     | 
    
         
            -
                    from sglang.srt.speculative.spec_utils import  
     | 
| 
      
 83 
     | 
    
         
            +
                    from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func
         
     | 
| 
       82 
84 
     | 
    
         | 
| 
       83 
85 
     | 
    
         
             
                    bs = batch.batch_size()
         
     | 
| 
       84 
86 
     | 
    
         | 
| 
         @@ -112,15 +114,15 @@ class EagleDraftInputV2Mixin: 
     | 
|
| 
       112 
114 
     | 
    
         
             
                            extend_num_tokens,
         
     | 
| 
       113 
115 
     | 
    
         
             
                        )
         
     | 
| 
       114 
116 
     | 
    
         | 
| 
       115 
     | 
    
         
            -
                     
     | 
| 
      
 117 
     | 
    
         
            +
                    assign_req_to_token_pool_func(
         
     | 
| 
       116 
118 
     | 
    
         
             
                        batch.req_pool_indices,
         
     | 
| 
       117 
119 
     | 
    
         
             
                        batch.req_to_token_pool.req_to_token,
         
     | 
| 
       118 
120 
     | 
    
         
             
                        self.allocate_lens,
         
     | 
| 
       119 
121 
     | 
    
         
             
                        new_allocate_lens,
         
     | 
| 
       120 
122 
     | 
    
         
             
                        out_cache_loc,
         
     | 
| 
       121 
     | 
    
         
            -
                         
     | 
| 
       122 
     | 
    
         
            -
                        next_power_of_2(bs),
         
     | 
| 
      
 123 
     | 
    
         
            +
                        bs,
         
     | 
| 
       123 
124 
     | 
    
         
             
                    )
         
     | 
| 
      
 125 
     | 
    
         
            +
             
     | 
| 
       124 
126 
     | 
    
         
             
                    self.allocate_lens = new_allocate_lens
         
     | 
| 
       125 
127 
     | 
    
         | 
| 
       126 
128 
     | 
    
         
             
                    # FIXME(lsyin): make this sync optional
         
     | 
| 
         @@ -199,22 +201,16 @@ class EagleVerifyInputV2Mixin: 
     | 
|
| 
       199 
201 
     | 
    
         
             
                    bs = len(batch.req_pool_indices)
         
     | 
| 
       200 
202 
     | 
    
         
             
                    batch.input_ids = self.draft_token
         
     | 
| 
       201 
203 
     | 
    
         
             
                    device = batch.input_ids.device
         
     | 
| 
       202 
     | 
    
         
            -
                    batch.out_cache_loc =  
     | 
| 
       203 
     | 
    
         
            -
                         
     | 
| 
       204 
     | 
    
         
            -
                         
     | 
| 
      
 204 
     | 
    
         
            +
                    batch.out_cache_loc = assign_extend_cache_locs_func(
         
     | 
| 
      
 205 
     | 
    
         
            +
                        req_pool_indices=batch.req_pool_indices,
         
     | 
| 
      
 206 
     | 
    
         
            +
                        req_to_token=req_to_token_pool.req_to_token,
         
     | 
| 
      
 207 
     | 
    
         
            +
                        start_offset=batch.seq_lens,
         
     | 
| 
      
 208 
     | 
    
         
            +
                        end_offset=batch.seq_lens + self.draft_token_num,
         
     | 
| 
      
 209 
     | 
    
         
            +
                        batch_size=bs,
         
     | 
| 
      
 210 
     | 
    
         
            +
                        draft_token_num=self.draft_token_num,
         
     | 
| 
       205 
211 
     | 
    
         
             
                        device=device,
         
     | 
| 
       206 
212 
     | 
    
         
             
                    )
         
     | 
| 
       207 
213 
     | 
    
         | 
| 
       208 
     | 
    
         
            -
                    assign_extend_cache_locs[(bs,)](
         
     | 
| 
       209 
     | 
    
         
            -
                        batch.req_pool_indices,
         
     | 
| 
       210 
     | 
    
         
            -
                        req_to_token_pool.req_to_token,
         
     | 
| 
       211 
     | 
    
         
            -
                        batch.seq_lens,
         
     | 
| 
       212 
     | 
    
         
            -
                        batch.seq_lens + self.draft_token_num,
         
     | 
| 
       213 
     | 
    
         
            -
                        batch.out_cache_loc,
         
     | 
| 
       214 
     | 
    
         
            -
                        req_to_token_pool.req_to_token.shape[1],
         
     | 
| 
       215 
     | 
    
         
            -
                        next_power_of_2(bs),
         
     | 
| 
       216 
     | 
    
         
            -
                    )
         
     | 
| 
       217 
     | 
    
         
            -
             
     | 
| 
       218 
214 
     | 
    
         
             
                    # Get a forward batch
         
     | 
| 
       219 
215 
     | 
    
         
             
                    batch.forward_mode = ForwardMode.TARGET_VERIFY
         
     | 
| 
       220 
216 
     | 
    
         
             
                    batch.capture_hidden_mode = CaptureHiddenMode.FULL
         
     | 
| 
         @@ -258,11 +254,10 @@ class EagleVerifyInputV2Mixin: 
     | 
|
| 
       258 
254 
     | 
    
         
             
                    accept_length = torch.empty((bs,), dtype=torch.int32, device=device)
         
     | 
| 
       259 
255 
     | 
    
         | 
| 
       260 
256 
     | 
    
         
             
                    # Sample tokens
         
     | 
| 
       261 
     | 
    
         
            -
                    if sampling_info.is_all_greedy:
         
     | 
| 
      
 257 
     | 
    
         
            +
                    if sampling_info.is_all_greedy or _is_npu:
         
     | 
| 
       262 
258 
     | 
    
         
             
                        target_predict = torch.argmax(next_token_logits, dim=-1)
         
     | 
| 
       263 
259 
     | 
    
         
             
                        target_predict = target_predict.reshape(bs, self.draft_token_num)
         
     | 
| 
       264 
     | 
    
         
            -
             
     | 
| 
       265 
     | 
    
         
            -
                        verify_tree_greedy(
         
     | 
| 
      
 260 
     | 
    
         
            +
                        predict, accept_index, accept_length = verify_tree_greedy_func(
         
     | 
| 
       266 
261 
     | 
    
         
             
                            predicts=predict,  # mutable
         
     | 
| 
       267 
262 
     | 
    
         
             
                            accept_index=accept_index,  # mutable
         
     | 
| 
       268 
263 
     | 
    
         
             
                            accept_token_num=accept_length,  # mutable
         
     | 
| 
         @@ -271,6 +266,7 @@ class EagleVerifyInputV2Mixin: 
     | 
|
| 
       271 
266 
     | 
    
         
             
                            retrive_next_token=self.retrive_next_token,
         
     | 
| 
       272 
267 
     | 
    
         
             
                            retrive_next_sibling=self.retrive_next_sibling,
         
     | 
| 
       273 
268 
     | 
    
         
             
                            target_predict=target_predict,
         
     | 
| 
      
 269 
     | 
    
         
            +
                            topk=self.topk,
         
     | 
| 
       274 
270 
     | 
    
         
             
                        )
         
     | 
| 
       275 
271 
     | 
    
         
             
                    else:
         
     | 
| 
       276 
272 
     | 
    
         
             
                        # Apply temperature and get target probs
         
     | 
| 
         @@ -338,7 +334,7 @@ class EagleVerifyInputV2Mixin: 
     | 
|
| 
       338 
334 
     | 
    
         
             
                    return predict, accept_length, accept_index
         
     | 
| 
       339 
335 
     | 
    
         | 
| 
       340 
336 
     | 
    
         | 
| 
       341 
     | 
    
         
            -
            @torch.compile(dynamic=True)
         
     | 
| 
      
 337 
     | 
    
         
            +
            @torch.compile(dynamic=True, disable=_is_npu)
         
     | 
| 
       342 
338 
     | 
    
         
             
            def select_top_k_tokens_tmp(
         
     | 
| 
       343 
339 
     | 
    
         
             
                i: int,
         
     | 
| 
       344 
340 
     | 
    
         
             
                topk_p: torch.Tensor,
         
     | 
| 
         @@ -456,3 +452,50 @@ def assign_extend_cache_locs( 
     | 
|
| 
       456 
452 
     | 
    
         
             
                    tl.store(out_cache_ptr + save_offset, data, mask=mask)
         
     | 
| 
       457 
453 
     | 
    
         
             
                    load_offset += BLOCK_SIZE
         
     | 
| 
       458 
454 
     | 
    
         
             
                    save_offset += BLOCK_SIZE
         
     | 
| 
      
 455 
     | 
    
         
            +
             
     | 
| 
      
 456 
     | 
    
         
            +
             
     | 
| 
      
 457 
     | 
    
         
            +
            def assign_extend_cache_locs_func(
         
     | 
| 
      
 458 
     | 
    
         
            +
                req_pool_indices: torch.Tensor,
         
     | 
| 
      
 459 
     | 
    
         
            +
                req_to_token: torch.Tensor,
         
     | 
| 
      
 460 
     | 
    
         
            +
                start_offset: torch.Tensor,
         
     | 
| 
      
 461 
     | 
    
         
            +
                end_offset: torch.Tensor,
         
     | 
| 
      
 462 
     | 
    
         
            +
                batch_size: int,
         
     | 
| 
      
 463 
     | 
    
         
            +
                draft_token_num: int,
         
     | 
| 
      
 464 
     | 
    
         
            +
                device,
         
     | 
| 
      
 465 
     | 
    
         
            +
            ) -> torch.Tensor:
         
     | 
| 
      
 466 
     | 
    
         
            +
                if _is_cuda or _is_hip:
         
     | 
| 
      
 467 
     | 
    
         
            +
                    out_cache_loc = torch.empty(
         
     | 
| 
      
 468 
     | 
    
         
            +
                        (batch_size * draft_token_num,),
         
     | 
| 
      
 469 
     | 
    
         
            +
                        dtype=torch.int64,
         
     | 
| 
      
 470 
     | 
    
         
            +
                        device=device,
         
     | 
| 
      
 471 
     | 
    
         
            +
                    )
         
     | 
| 
      
 472 
     | 
    
         
            +
                    assign_extend_cache_locs[(batch_size,)](
         
     | 
| 
      
 473 
     | 
    
         
            +
                        req_pool_indices,
         
     | 
| 
      
 474 
     | 
    
         
            +
                        req_to_token,
         
     | 
| 
      
 475 
     | 
    
         
            +
                        start_offset,
         
     | 
| 
      
 476 
     | 
    
         
            +
                        end_offset,
         
     | 
| 
      
 477 
     | 
    
         
            +
                        out_cache_loc,
         
     | 
| 
      
 478 
     | 
    
         
            +
                        req_to_token.shape[1],
         
     | 
| 
      
 479 
     | 
    
         
            +
                        next_power_of_2(batch_size),
         
     | 
| 
      
 480 
     | 
    
         
            +
                    )
         
     | 
| 
      
 481 
     | 
    
         
            +
             
     | 
| 
      
 482 
     | 
    
         
            +
                    return out_cache_loc
         
     | 
| 
      
 483 
     | 
    
         
            +
             
     | 
| 
      
 484 
     | 
    
         
            +
                elif _is_npu:
         
     | 
| 
      
 485 
     | 
    
         
            +
                    import sgl_kernel_npu  # noqa: F401
         
     | 
| 
      
 486 
     | 
    
         
            +
             
     | 
| 
      
 487 
     | 
    
         
            +
                    out_cache_loc = torch.empty(
         
     | 
| 
      
 488 
     | 
    
         
            +
                        (batch_size * draft_token_num,),
         
     | 
| 
      
 489 
     | 
    
         
            +
                        dtype=torch.int32,
         
     | 
| 
      
 490 
     | 
    
         
            +
                        device=device,
         
     | 
| 
      
 491 
     | 
    
         
            +
                    )
         
     | 
| 
      
 492 
     | 
    
         
            +
                    torch.ops.npu.cache_loc_update(
         
     | 
| 
      
 493 
     | 
    
         
            +
                        req_pool_indices,
         
     | 
| 
      
 494 
     | 
    
         
            +
                        req_to_token,
         
     | 
| 
      
 495 
     | 
    
         
            +
                        start_offset,
         
     | 
| 
      
 496 
     | 
    
         
            +
                        end_offset,
         
     | 
| 
      
 497 
     | 
    
         
            +
                        out_cache_loc,
         
     | 
| 
      
 498 
     | 
    
         
            +
                    )
         
     | 
| 
      
 499 
     | 
    
         
            +
                    out_cache_loc = out_cache_loc.to(dtype=torch.int64)
         
     | 
| 
      
 500 
     | 
    
         
            +
             
     | 
| 
      
 501 
     | 
    
         
            +
                    return out_cache_loc
         
     | 
| 
         @@ -4,14 +4,128 @@ from typing import List, Optional 
     | 
|
| 
       4 
4 
     | 
    
         | 
| 
       5 
5 
     | 
    
         
             
            import torch
         
     | 
| 
       6 
6 
     | 
    
         | 
| 
       7 
     | 
    
         
            -
            from sglang.srt.utils import is_cuda, is_hip
         
     | 
| 
      
 7 
     | 
    
         
            +
            from sglang.srt.utils import is_cuda, is_hip, is_npu
         
     | 
| 
       8 
8 
     | 
    
         | 
| 
       9 
     | 
    
         
            -
             
     | 
| 
      
 9 
     | 
    
         
            +
            _is_cuda = is_cuda()
         
     | 
| 
      
 10 
     | 
    
         
            +
            _is_hip = is_hip()
         
     | 
| 
      
 11 
     | 
    
         
            +
            _is_npu = is_npu()
         
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            if _is_cuda or _is_hip:
         
     | 
| 
       10 
14 
     | 
    
         
             
                from sgl_kernel import (
         
     | 
| 
       11 
15 
     | 
    
         
             
                    build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
         
     | 
| 
       12 
16 
     | 
    
         
             
                )
         
     | 
| 
       13 
17 
     | 
    
         | 
| 
       14 
18 
     | 
    
         | 
| 
      
 19 
     | 
    
         
            +
            def build_tree_efficient_native(
         
     | 
| 
      
 20 
     | 
    
         
            +
                parent_list: torch.Tensor,
         
     | 
| 
      
 21 
     | 
    
         
            +
                selected_index: torch.Tensor,
         
     | 
| 
      
 22 
     | 
    
         
            +
                verified_seq_len: torch.Tensor,
         
     | 
| 
      
 23 
     | 
    
         
            +
                tree_mask: torch.Tensor,
         
     | 
| 
      
 24 
     | 
    
         
            +
                retrive_index: torch.Tensor,
         
     | 
| 
      
 25 
     | 
    
         
            +
                retrive_next_token: torch.Tensor,
         
     | 
| 
      
 26 
     | 
    
         
            +
                retrive_next_sibling: torch.Tensor,
         
     | 
| 
      
 27 
     | 
    
         
            +
                topk: int,
         
     | 
| 
      
 28 
     | 
    
         
            +
                draft_token_num: int,
         
     | 
| 
      
 29 
     | 
    
         
            +
                tree_mask_mode: int,
         
     | 
| 
      
 30 
     | 
    
         
            +
                bs: int,
         
     | 
| 
      
 31 
     | 
    
         
            +
            ):
         
     | 
| 
      
 32 
     | 
    
         
            +
                # Generate batch and token index ranges
         
     | 
| 
      
 33 
     | 
    
         
            +
                bs_range = torch.arange(bs, device=tree_mask.device).view(-1, 1)
         
     | 
| 
      
 34 
     | 
    
         
            +
                draft_token_num_range = torch.arange(draft_token_num, device=tree_mask.device)
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
                # Optimized common case for performance.
         
     | 
| 
      
 37 
     | 
    
         
            +
                if draft_token_num == 2 and topk == 1 and tree_mask_mode == TreeMaskMode.FULL_MASK:
         
     | 
| 
      
 38 
     | 
    
         
            +
                    positions = verified_seq_len.repeat_interleave(draft_token_num)
         
     | 
| 
      
 39 
     | 
    
         
            +
                    positions = (positions.view(bs, -1) + draft_token_num_range).view(-1)
         
     | 
| 
      
 40 
     | 
    
         
            +
             
     | 
| 
      
 41 
     | 
    
         
            +
                    retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
         
     | 
| 
      
 42 
     | 
    
         
            +
                    retrive_next_token[:, 0] = 1
         
     | 
| 
      
 43 
     | 
    
         
            +
                    retrive_next_token[:, 1] = -1
         
     | 
| 
      
 44 
     | 
    
         
            +
                    return (
         
     | 
| 
      
 45 
     | 
    
         
            +
                        positions,
         
     | 
| 
      
 46 
     | 
    
         
            +
                        retrive_index,
         
     | 
| 
      
 47 
     | 
    
         
            +
                        retrive_next_token,
         
     | 
| 
      
 48 
     | 
    
         
            +
                        retrive_next_sibling,
         
     | 
| 
      
 49 
     | 
    
         
            +
                        tree_mask,
         
     | 
| 
      
 50 
     | 
    
         
            +
                    )
         
     | 
| 
      
 51 
     | 
    
         
            +
             
     | 
| 
      
 52 
     | 
    
         
            +
                # Precompute sequence tree indices
         
     | 
| 
      
 53 
     | 
    
         
            +
                draft_token_num_range1 = torch.arange(draft_token_num - 1, device=tree_mask.device)
         
     | 
| 
      
 54 
     | 
    
         
            +
                cum_seq_len = torch.cumsum(verified_seq_len * draft_token_num, dim=0)
         
     | 
| 
      
 55 
     | 
    
         
            +
                cum_seq_len = torch.cat((torch.tensor([0], device=tree_mask.device), cum_seq_len))
         
     | 
| 
      
 56 
     | 
    
         
            +
                cum_seq_len = cum_seq_len[:-1]
         
     | 
| 
      
 57 
     | 
    
         
            +
                seq_tree_idx = (
         
     | 
| 
      
 58 
     | 
    
         
            +
                    draft_token_num * draft_token_num * torch.arange(bs, device=tree_mask.device)
         
     | 
| 
      
 59 
     | 
    
         
            +
                    + cum_seq_len
         
     | 
| 
      
 60 
     | 
    
         
            +
                )
         
     | 
| 
      
 61 
     | 
    
         
            +
             
     | 
| 
      
 62 
     | 
    
         
            +
                # Batch processing for tree mask
         
     | 
| 
      
 63 
     | 
    
         
            +
                if tree_mask_mode == TreeMaskMode.FULL_MASK:
         
     | 
| 
      
 64 
     | 
    
         
            +
                    token_tree_base = (
         
     | 
| 
      
 65 
     | 
    
         
            +
                        seq_tree_idx.view(-1, 1)
         
     | 
| 
      
 66 
     | 
    
         
            +
                        + (verified_seq_len.view(-1, 1) + draft_token_num) * draft_token_num_range
         
     | 
| 
      
 67 
     | 
    
         
            +
                    )
         
     | 
| 
      
 68 
     | 
    
         
            +
                    token_tree_indices = token_tree_base + verified_seq_len.view(-1, 1) + 1
         
     | 
| 
      
 69 
     | 
    
         
            +
                else:
         
     | 
| 
      
 70 
     | 
    
         
            +
                    token_tree_indices = (
         
     | 
| 
      
 71 
     | 
    
         
            +
                        bs_range * draft_token_num**2 + draft_token_num_range * draft_token_num + 1
         
     | 
| 
      
 72 
     | 
    
         
            +
                    )
         
     | 
| 
      
 73 
     | 
    
         
            +
             
     | 
| 
      
 74 
     | 
    
         
            +
                tree_mask[token_tree_indices.flatten() - 1] = True
         
     | 
| 
      
 75 
     | 
    
         
            +
                indices = token_tree_indices.unsqueeze(-1) + draft_token_num_range1.view(1, 1, -1)
         
     | 
| 
      
 76 
     | 
    
         
            +
                tree_mask[indices.view(-1)] = False
         
     | 
| 
      
 77 
     | 
    
         
            +
             
     | 
| 
      
 78 
     | 
    
         
            +
                positions = verified_seq_len.repeat_interleave(draft_token_num)
         
     | 
| 
      
 79 
     | 
    
         
            +
                parent_tb_indices = selected_index // topk
         
     | 
| 
      
 80 
     | 
    
         
            +
                retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
         
     | 
| 
      
 81 
     | 
    
         
            +
                tree_mask[token_tree_indices.view(-1, 1) + draft_token_num_range1] = True
         
     | 
| 
      
 82 
     | 
    
         
            +
             
     | 
| 
      
 83 
     | 
    
         
            +
                for bid in range(bs):
         
     | 
| 
      
 84 
     | 
    
         
            +
                    for tid in range(draft_token_num):
         
     | 
| 
      
 85 
     | 
    
         
            +
                        position = 0
         
     | 
| 
      
 86 
     | 
    
         
            +
                        if tid == 0:
         
     | 
| 
      
 87 
     | 
    
         
            +
                            # Process root node
         
     | 
| 
      
 88 
     | 
    
         
            +
                            for i in range(draft_token_num - 1, 0, -1):
         
     | 
| 
      
 89 
     | 
    
         
            +
                                parent_position = 0
         
     | 
| 
      
 90 
     | 
    
         
            +
                                parent_tb_idx = parent_tb_indices[bid][i - 1]
         
     | 
| 
      
 91 
     | 
    
         
            +
                                if parent_tb_idx > 0:
         
     | 
| 
      
 92 
     | 
    
         
            +
                                    parent_token_idx = parent_list[bid][parent_tb_idx]
         
     | 
| 
      
 93 
     | 
    
         
            +
                                    loop_num = draft_token_num - parent_position
         
     | 
| 
      
 94 
     | 
    
         
            +
                                    for _ in range(loop_num):
         
     | 
| 
      
 95 
     | 
    
         
            +
                                        if selected_index[bid][parent_position] == parent_token_idx:
         
     | 
| 
      
 96 
     | 
    
         
            +
                                            parent_position += 1
         
     | 
| 
      
 97 
     | 
    
         
            +
                                            break
         
     | 
| 
      
 98 
     | 
    
         
            +
                                        parent_position += 1
         
     | 
| 
      
 99 
     | 
    
         
            +
                                if parent_position == draft_token_num:
         
     | 
| 
      
 100 
     | 
    
         
            +
                                    continue
         
     | 
| 
      
 101 
     | 
    
         
            +
             
     | 
| 
      
 102 
     | 
    
         
            +
                                if retrive_next_token[bid][parent_position] != -1:
         
     | 
| 
      
 103 
     | 
    
         
            +
                                    retrive_next_sibling[bid][i] = retrive_next_token[bid][
         
     | 
| 
      
 104 
     | 
    
         
            +
                                        parent_position
         
     | 
| 
      
 105 
     | 
    
         
            +
                                    ]
         
     | 
| 
      
 106 
     | 
    
         
            +
                                retrive_next_token[bid][parent_position] = i
         
     | 
| 
      
 107 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 108 
     | 
    
         
            +
                            # Process no-root nodes
         
     | 
| 
      
 109 
     | 
    
         
            +
                            cur_position = tid - 1
         
     | 
| 
      
 110 
     | 
    
         
            +
                            while True:
         
     | 
| 
      
 111 
     | 
    
         
            +
                                position += 1
         
     | 
| 
      
 112 
     | 
    
         
            +
                                if cur_position >= draft_token_num:
         
     | 
| 
      
 113 
     | 
    
         
            +
                                    tree_mask[token_tree_indices + cur_position] = True
         
     | 
| 
      
 114 
     | 
    
         
            +
                                    parent_tb_idx = selected_index[bid][cur_position] // topk
         
     | 
| 
      
 115 
     | 
    
         
            +
                                else:
         
     | 
| 
      
 116 
     | 
    
         
            +
                                    parent_tb_idx = parent_tb_indices[bid][cur_position]
         
     | 
| 
      
 117 
     | 
    
         
            +
                                if parent_tb_idx == 0:
         
     | 
| 
      
 118 
     | 
    
         
            +
                                    break
         
     | 
| 
      
 119 
     | 
    
         
            +
                                token_idx = parent_list[bid][parent_tb_idx]
         
     | 
| 
      
 120 
     | 
    
         
            +
                                cur_position = 0
         
     | 
| 
      
 121 
     | 
    
         
            +
                                for _ in range(draft_token_num):
         
     | 
| 
      
 122 
     | 
    
         
            +
                                    if selected_index[bid][cur_position] == token_idx:
         
     | 
| 
      
 123 
     | 
    
         
            +
                                        break
         
     | 
| 
      
 124 
     | 
    
         
            +
                                    cur_position += 1
         
     | 
| 
      
 125 
     | 
    
         
            +
                            positions[bid * draft_token_num + tid] += position
         
     | 
| 
      
 126 
     | 
    
         
            +
                return positions, retrive_index, retrive_next_token, retrive_next_sibling, tree_mask
         
     | 
| 
      
 127 
     | 
    
         
            +
             
     | 
| 
      
 128 
     | 
    
         
            +
             
     | 
| 
       15 
129 
     | 
    
         
             
            def organize_draft_results(
         
     | 
| 
       16 
130 
     | 
    
         
             
                score_list: List[torch.Tensor],
         
     | 
| 
       17 
131 
     | 
    
         
             
                token_list: List[torch.Tensor],
         
     | 
| 
         @@ -114,20 +228,41 @@ def build_tree_kernel_efficient( 
     | 
|
| 
       114 
228 
     | 
    
         
             
                        (bs * num_verify_tokens,), device=device, dtype=torch.long
         
     | 
| 
       115 
229 
     | 
    
         
             
                    )
         
     | 
| 
       116 
230 
     | 
    
         | 
| 
       117 
     | 
    
         
            -
                 
     | 
| 
       118 
     | 
    
         
            -
                     
     | 
| 
       119 
     | 
    
         
            -
             
     | 
| 
       120 
     | 
    
         
            -
             
     | 
| 
       121 
     | 
    
         
            -
             
     | 
| 
       122 
     | 
    
         
            -
             
     | 
| 
       123 
     | 
    
         
            -
             
     | 
| 
       124 
     | 
    
         
            -
                     
     | 
| 
       125 
     | 
    
         
            -
             
     | 
| 
       126 
     | 
    
         
            -
             
     | 
| 
       127 
     | 
    
         
            -
             
     | 
| 
       128 
     | 
    
         
            -
             
     | 
| 
       129 
     | 
    
         
            -
             
     | 
| 
       130 
     | 
    
         
            -
             
     | 
| 
      
 231 
     | 
    
         
            +
                if _is_npu:
         
     | 
| 
      
 232 
     | 
    
         
            +
                    (
         
     | 
| 
      
 233 
     | 
    
         
            +
                        positions,
         
     | 
| 
      
 234 
     | 
    
         
            +
                        retrive_index,
         
     | 
| 
      
 235 
     | 
    
         
            +
                        retrive_next_token,
         
     | 
| 
      
 236 
     | 
    
         
            +
                        retrive_next_sibling,
         
     | 
| 
      
 237 
     | 
    
         
            +
                        tree_mask,
         
     | 
| 
      
 238 
     | 
    
         
            +
                    ) = build_tree_efficient_native(
         
     | 
| 
      
 239 
     | 
    
         
            +
                        parent_list,
         
     | 
| 
      
 240 
     | 
    
         
            +
                        top_scores_index,
         
     | 
| 
      
 241 
     | 
    
         
            +
                        seq_lens,
         
     | 
| 
      
 242 
     | 
    
         
            +
                        tree_mask,
         
     | 
| 
      
 243 
     | 
    
         
            +
                        retrive_index,
         
     | 
| 
      
 244 
     | 
    
         
            +
                        retrive_next_token,
         
     | 
| 
      
 245 
     | 
    
         
            +
                        retrive_next_sibling,
         
     | 
| 
      
 246 
     | 
    
         
            +
                        topk,
         
     | 
| 
      
 247 
     | 
    
         
            +
                        num_verify_tokens,
         
     | 
| 
      
 248 
     | 
    
         
            +
                        tree_mask_mode,
         
     | 
| 
      
 249 
     | 
    
         
            +
                        bs,
         
     | 
| 
      
 250 
     | 
    
         
            +
                    )
         
     | 
| 
      
 251 
     | 
    
         
            +
                else:
         
     | 
| 
      
 252 
     | 
    
         
            +
                    sgl_build_tree_kernel_efficient(
         
     | 
| 
      
 253 
     | 
    
         
            +
                        parent_list,
         
     | 
| 
      
 254 
     | 
    
         
            +
                        top_scores_index,
         
     | 
| 
      
 255 
     | 
    
         
            +
                        seq_lens,
         
     | 
| 
      
 256 
     | 
    
         
            +
                        tree_mask,
         
     | 
| 
      
 257 
     | 
    
         
            +
                        positions,
         
     | 
| 
      
 258 
     | 
    
         
            +
                        retrive_index,
         
     | 
| 
      
 259 
     | 
    
         
            +
                        retrive_next_token,
         
     | 
| 
      
 260 
     | 
    
         
            +
                        retrive_next_sibling,
         
     | 
| 
      
 261 
     | 
    
         
            +
                        topk,
         
     | 
| 
      
 262 
     | 
    
         
            +
                        spec_steps,
         
     | 
| 
      
 263 
     | 
    
         
            +
                        num_verify_tokens,
         
     | 
| 
      
 264 
     | 
    
         
            +
                        tree_mask_mode,
         
     | 
| 
      
 265 
     | 
    
         
            +
                    )
         
     | 
| 
       131 
266 
     | 
    
         
             
                return (
         
     | 
| 
       132 
267 
     | 
    
         
             
                    tree_mask,
         
     | 
| 
       133 
268 
     | 
    
         
             
                    positions,
         
     | 
| 
         @@ -136,3 +271,113 @@ def build_tree_kernel_efficient( 
     | 
|
| 
       136 
271 
     | 
    
         
             
                    retrive_next_sibling,
         
     | 
| 
       137 
272 
     | 
    
         
             
                    draft_tokens,
         
     | 
| 
       138 
273 
     | 
    
         
             
                )
         
     | 
| 
      
 274 
     | 
    
         
            +
             
     | 
| 
      
 275 
     | 
    
         
            +
             
     | 
| 
      
 276 
     | 
    
         
            +
            def verify_tree_greedy_native(
                predicts: torch.Tensor,
                accept_index: torch.Tensor,
                accept_token_num: torch.Tensor,
                candidates: torch.Tensor,
                retrive_index: torch.Tensor,
                retrive_next_token: torch.Tensor,
                retrive_next_sibling: torch.Tensor,
                target_predict: torch.Tensor,
                topk: int = -1,
            ):
                """Greedily verify draft-token trees with plain PyTorch ops.

                Used as the non-``sgl_kernel`` path for tree verification. For each
                request, walk down the draft tree from the root: at every depth, scan the
                children of the last accepted node (first child via ``retrive_next_token``,
                then ``retrive_next_sibling``) and accept the first one whose draft token
                equals the target model's prediction at the last accepted position. Stop
                when no child matches.

                Args:
                    predicts: Flat output buffer indexed by the global node indices stored
                        in ``retrive_index``. On the tree path it is written in place for
                        every accepted position, plus the final bonus position.
                    accept_index: 2-D output; row ``b`` receives the global indices of
                        the accepted nodes, root first. On the tree path, entries past
                        the accepted prefix are left untouched.
                    accept_token_num: Per-request output; the number of accepted draft
                        tokens, not counting the root.
                    candidates: ``(batch_size, num_draft_tokens)`` draft token ids, indexed
                        by local node id.
                    retrive_index: Per-request global index of each local node.
                    retrive_next_token: Per-request first-child links (local node ids,
                        ``-1`` = none).
                    retrive_next_sibling: Per-request next-sibling links (local node ids,
                        ``-1`` = none).
                    target_predict: Per-request target-model tokens, addressed as
                        ``target_predict[b][global_idx - num_draft_tokens * b]``. This
                        presumably means rows of length ``num_draft_tokens``; TODO confirm.
                    topk: Draft top-k. ``topk == 1`` with two draft tokens and a two-wide
                        ``accept_index`` takes a vectorized fast path.

                Returns:
                    ``(predicts, accept_index, accept_token_num)``. The fast path builds new
                    tensors instead of mutating the inputs, so callers must use the
                    returned values.
                """
                batch_size, num_draft_tokens = candidates.shape
                max_accept = accept_index.shape[1]

                # Optimized common case for performance: a two-node chain root -> node 1.
                # NOTE(review): assumes the default layout retrive_index[b] == [2b, 2b + 1].
                if num_draft_tokens == 2 and max_accept == 2 and topk == 1:
                    comparison_result = candidates[:, 1] == target_predict[:, 0]

                    predicts = target_predict.flatten()

                    accept_index = torch.arange(
                        0, num_draft_tokens * batch_size, device=candidates.device, dtype=torch.long
                    ).reshape(batch_size, num_draft_tokens)
                    comparison_result = comparison_result.to(torch.int64)
                    # Accepted -> keep node 1's global index; rejected -> -1.
                    accept_index_mask = accept_index[:, 1] * comparison_result
                    accept_index[:, 1] = accept_index_mask - (1 - comparison_result)

                    accept_token_num = comparison_result.int()
                    return predicts, accept_index, accept_token_num

                # Read the small index/token tensors to host once. Indexing device tensors
                # element by element inside the Python loop would force a device sync on
                # every comparison.
                candidates_host = candidates.tolist()
                retrive_index_host = retrive_index.tolist()
                next_token_host = retrive_next_token.tolist()
                next_sibling_host = retrive_next_sibling.tolist()
                target_host = target_predict.tolist()

                # Bound the walk by the accept row width, as the CUDA kernel does, so a deep
                # chain can never write past the end of an accept_index row. Never loop more
                # often than the tree has non-root nodes.
                max_depth = min(max_accept, num_draft_tokens)

                for bx in range(batch_size):
                    cur_candidates = candidates_host[bx]
                    cur_retrive_index = retrive_index_host[bx]
                    cur_next_token = next_token_host[bx]
                    cur_next_sibling = next_sibling_host[bx]
                    cur_target = target_host[bx]
                    # Converts a global node index into a position inside cur_target.
                    row_offset = num_draft_tokens * bx

                    last_accepted_idx = cur_retrive_index[0]
                    accept_index[bx, 0] = last_accepted_idx
                    num_accepted = 0
                    cur_node = 0

                    for _ in range(1, max_depth):
                        target_token = cur_target[last_accepted_idx - row_offset]
                        # Scan the children of cur_node for one matching the target token.
                        cur_node = cur_next_token[cur_node]
                        while cur_node != -1 and cur_candidates[cur_node] != target_token:
                            cur_node = cur_next_sibling[cur_node]
                        if cur_node == -1:
                            break

                        draft_idx = cur_retrive_index[cur_node]
                        predicts[last_accepted_idx] = target_token
                        num_accepted += 1
                        accept_index[bx, num_accepted] = draft_idx
                        last_accepted_idx = draft_idx

                    accept_token_num[bx] = num_accepted
                    # Bonus token: the target prediction after the last accepted node.
                    predicts[last_accepted_idx] = cur_target[last_accepted_idx - row_offset]
                return predicts, accept_index, accept_token_num
         
     | 
| 
      
 343 
     | 
    
         
            +
             
     | 
| 
      
 344 
     | 
    
         
            +
             
     | 
| 
      
 345 
     | 
    
         
            +
            def verify_tree_greedy_func(
                predicts: torch.Tensor,
                accept_index: torch.Tensor,
                accept_token_num: torch.Tensor,
                candidates: torch.Tensor,
                retrive_index: torch.Tensor,
                retrive_next_token: torch.Tensor,
                retrive_next_sibling: torch.Tensor,
                target_predict: torch.Tensor,
                topk: int = -1,
            ):
                """Greedily verify draft-token trees on the current platform.

                On CUDA/ROCm this calls ``sgl_kernel.verify_tree_greedy``, which fills
                ``predicts``, ``accept_index`` and ``accept_token_num`` in place (``topk`` is
                not forwarded there). Every other platform, NPU included, falls back to
                ``verify_tree_greedy_native``. The native fast path returns newly built
                tensors instead of mutating the inputs.

                Always consume the returned tuple rather than relying on in-place updates.

                Returns:
                    ``(predicts, accept_index, accept_token_num)``.
                """
                if _is_cuda or _is_hip:
                    from sgl_kernel import verify_tree_greedy

                    verify_tree_greedy(
                        predicts=predicts,  # mutable
                        accept_index=accept_index,  # mutable
                        accept_token_num=accept_token_num,  # mutable
                        candidates=candidates,
                        retrive_index=retrive_index,
                        retrive_next_token=retrive_next_token,
                        retrive_next_sibling=retrive_next_sibling,
                        target_predict=target_predict,
                    )
                else:
                    # Without the sgl_kernel op, use the pure-PyTorch implementation.
                    # Otherwise the caller would get back the unverified input buffers.
                    predicts, accept_index, accept_token_num = verify_tree_greedy_native(
                        predicts=predicts,
                        accept_index=accept_index,
                        accept_token_num=accept_token_num,
                        candidates=candidates,
                        retrive_index=retrive_index,
                        retrive_next_token=retrive_next_token,
                        retrive_next_sibling=retrive_next_sibling,
                        target_predict=target_predict,
                        topk=topk,
                    )

                return predicts, accept_index, accept_token_num
         
     |