sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
 - sglang/bench_latency.py +28 -10
 - sglang/bench_server_latency.py +21 -10
 - sglang/bench_serving.py +101 -7
 - sglang/global_config.py +0 -1
 - sglang/srt/layers/attention/__init__.py +27 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
 - sglang/srt/layers/attention/flashinfer_backend.py +352 -83
 - sglang/srt/layers/attention/triton_backend.py +6 -4
 - sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
 - sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
 - sglang/srt/layers/sampler.py +6 -2
 - sglang/srt/managers/detokenizer_manager.py +31 -10
 - sglang/srt/managers/io_struct.py +4 -0
 - sglang/srt/managers/schedule_batch.py +120 -43
 - sglang/srt/managers/schedule_policy.py +2 -1
 - sglang/srt/managers/scheduler.py +202 -140
 - sglang/srt/managers/tokenizer_manager.py +5 -1
 - sglang/srt/managers/tp_worker.py +111 -1
 - sglang/srt/mem_cache/chunk_cache.py +8 -4
 - sglang/srt/mem_cache/memory_pool.py +77 -4
 - sglang/srt/mem_cache/radix_cache.py +15 -7
 - sglang/srt/model_executor/cuda_graph_runner.py +4 -4
 - sglang/srt/model_executor/forward_batch_info.py +16 -21
 - sglang/srt/model_executor/model_runner.py +60 -1
 - sglang/srt/models/baichuan.py +2 -3
 - sglang/srt/models/chatglm.py +5 -6
 - sglang/srt/models/commandr.py +1 -2
 - sglang/srt/models/dbrx.py +1 -2
 - sglang/srt/models/deepseek.py +4 -5
 - sglang/srt/models/deepseek_v2.py +5 -6
 - sglang/srt/models/exaone.py +1 -2
 - sglang/srt/models/gemma.py +2 -2
 - sglang/srt/models/gemma2.py +5 -5
 - sglang/srt/models/gpt_bigcode.py +5 -5
 - sglang/srt/models/grok.py +1 -2
 - sglang/srt/models/internlm2.py +1 -2
 - sglang/srt/models/llama.py +1 -2
 - sglang/srt/models/llama_classification.py +1 -2
 - sglang/srt/models/llama_reward.py +2 -3
 - sglang/srt/models/llava.py +4 -8
 - sglang/srt/models/llavavid.py +1 -2
 - sglang/srt/models/minicpm.py +1 -2
 - sglang/srt/models/minicpm3.py +5 -6
 - sglang/srt/models/mixtral.py +1 -2
 - sglang/srt/models/mixtral_quant.py +1 -2
 - sglang/srt/models/olmo.py +352 -0
 - sglang/srt/models/olmoe.py +1 -2
 - sglang/srt/models/qwen.py +1 -2
 - sglang/srt/models/qwen2.py +1 -2
 - sglang/srt/models/qwen2_moe.py +4 -5
 - sglang/srt/models/stablelm.py +1 -2
 - sglang/srt/models/torch_native_llama.py +1 -2
 - sglang/srt/models/xverse.py +1 -2
 - sglang/srt/models/xverse_moe.py +4 -5
 - sglang/srt/models/yivl.py +1 -2
 - sglang/srt/openai_api/adapter.py +92 -49
 - sglang/srt/openai_api/protocol.py +10 -2
 - sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
 - sglang/srt/sampling/sampling_batch_info.py +92 -58
 - sglang/srt/sampling/sampling_params.py +2 -0
 - sglang/srt/server.py +116 -17
 - sglang/srt/server_args.py +121 -45
 - sglang/srt/utils.py +11 -3
 - sglang/test/few_shot_gsm8k.py +4 -1
 - sglang/test/few_shot_gsm8k_engine.py +144 -0
 - sglang/test/srt/sampling/penaltylib/utils.py +16 -12
 - sglang/version.py +1 -1
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
 - sglang/srt/layers/attention/flashinfer_utils.py +0 -237
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
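Before the file-level details, a quick way to apply and verify the upgrade. A minimal sketch, assuming the standard pip workflow and that the top-level package re-exports the version string from sglang/version.py (the +1 -1 change above):

```python
# Minimal upgrade-and-verify sketch. Assumes pip is available and that
# sglang re-exports __version__ from sglang/version.py (the one-line
# change in this release). Run the install step in a shell first:
#
#   pip install --upgrade "sglang==0.3.4"

import sglang

print("installed:", sglang.__version__)
assert sglang.__version__ == "0.3.4"
```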
 
sglang/srt/layers/attention/flashinfer_utils.py (deleted; the full removed file follows)

```diff
@@ -1,237 +0,0 @@
-from enum import Enum, auto
-
-import torch
-import triton
-import triton.language as tl
-
-
-class WrapperDispatch(Enum):
-    SLIDING_WINDOW = auto()
-    CROSS_ATTENTION = auto()
-
-
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    max_context_len: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
-
-
-class FlashinferUpdater:
-    def __init__(
-        self,
-        forward_mode,
-        model_runner,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        decode_wrappers=None,
-        use_ragged=False,
-    ):
-        self.forward_mode = forward_mode
-        self.model_runner = model_runner
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = seq_lens
-        self.prefix_lens = prefix_lens
-        self.use_ragged = use_ragged
-
-        self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
-        )
-        self.head_dim = model_runner.model_config.head_dim
-        self.batch_size = len(req_pool_indices)
-
-        self.decode_wrappers = (
-            decode_wrappers or self.model_runner.attn_backend.decode_wrappers
-        )
-        self.prefill_wrapper_ragged = (
-            self.model_runner.attn_backend.prefill_wrapper_ragged
-        )
-        self.prefill_wrappers_paged = (
-            self.model_runner.attn_backend.prefill_wrappers_paged
-        )
-
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-
-    def _update_decode_indices(self, decode_wrapper):
-        assert not isinstance(decode_wrapper, list)
-        decode_wrapper.end_forward()
-        decode_wrapper.begin_forward(
-            self.kv_indptr,
-            self.kv_indices,
-            self.kv_last_page_len,
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.model_runner.kv_cache_dtype,
-            q_data_type=self.model_runner.dtype,
-        )
-
-    def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
-        assert not isinstance(paged_wrapper, list)
-        assert not isinstance(ragged_wrapper, list)
-
-        # extend part
-        qo_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
-
-        if self.use_ragged:
-            ragged_wrapper.end_forward()
-            ragged_wrapper.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-            )
-
-        # cached part
-        paged_wrapper.end_forward()
-        paged_wrapper.begin_forward(
-            qo_indptr,
-            self.kv_indptr,
-            self.kv_indices,
-            self.kv_last_page_len,
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-        )
-
-    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
-        if dispatch_reason is None:
-            if self.use_ragged:
-                paged_kernel_lens = self.prefix_lens
-            else:
-                paged_kernel_lens = self.seq_lens
-            self.kv_start_idx = None
-        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-            if wrapper_id == 0:
-                # window attention use paged only
-                if self.forward_mode.is_decode():
-                    paged_kernel_lens = torch.minimum(
-                        self.seq_lens,
-                        torch.tensor(self.model_runner.sliding_window_size + 1),
-                    )
-                else:
-                    paged_kernel_lens = torch.minimum(
-                        self.seq_lens,
-                        torch.tensor(self.model_runner.sliding_window_size)
-                        + self.seq_lens
-                        - self.prefix_lens,
-                    )
-            else:
-                # full attention
-                paged_kernel_lens = self.seq_lens
-            self.kv_start_idx = self.seq_lens - paged_kernel_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            self.kv_start_idx,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
-    def _update_indicess_single_wrapper(self):
-        self._get_indices()
-
-        if self.forward_mode.is_decode():
-            self._update_decode_indices(self.decode_wrappers[0])
-        else:
-            self._update_extend_indices(
-                self.prefill_wrapper_ragged,
-                self.prefill_wrappers_paged[0],
-            )
-
-    def _update_indices_cross_attention(self):
-        pass
-
-    def _update_indices_sliding_window(self):
-        assert self.use_ragged is False
-        for wrapper_id in range(2):
-            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
-            if self.forward_mode.is_decode():
-                self._update_decode_indices(self.decode_wrappers[wrapper_id])
-            else:
-                self._update_extend_indices(
-                    None,
-                    self.prefill_wrappers_paged[wrapper_id],
-                )
-
-
-def update_flashinfer_indices(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    decode_wrappers=None,
-    use_ragged=False,
-):
-    updater = FlashinferUpdater(
-        forward_mode,
-        model_runner,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        decode_wrappers,
-        use_ragged,
-    )
-
-    dispatch_reason = model_runner.attn_backend.dispatch_reason
-
-    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-        updater._update_indices_sliding_window()
-    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
-        updater._update_indices_cross_attention()
-    else:
-        assert model_runner.attn_backend.num_wrappers == 1
-        updater._update_indicess_single_wrapper()
```

{sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE: file without changes
{sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt: file without changes
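The heart of the deleted file is `create_flashinfer_kv_indices_triton`, a CSR-style gather: for request `i` it copies `paged_kernel_lens[i]` token indices from row `req_pool_indices[i]` of the `req_to_token` table into the flat `kv_indices` buffer starting at offset `kv_indptr[i]`, optionally skipping `kv_start_idx[i]` leading tokens (the sliding-window case). As a reading aid, here is a minimal pure-PyTorch sketch of the same semantics; the function name and loop-based layout are ours, not sglang API, and the real kernel does this blockwise on the GPU, one Triton program per request:

```python
import torch


def create_kv_indices_reference(
    req_to_token,       # [max_batch, max_context_len]: token slots per request row
    req_pool_indices,   # [batch]: which row of req_to_token each request uses
    paged_kernel_lens,  # [batch]: how many KV entries to gather per request
    kv_indptr,          # [batch + 1]: CSR offsets, cumsum of paged_kernel_lens
    kv_start_idx=None,  # [batch] or None: leading tokens to skip (sliding window)
):
    # Flat output buffer: one contiguous slice per request, like the
    # kernel's kv_indices_ptr destination.
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(len(req_pool_indices)):
        row = req_to_token[int(req_pool_indices[i])]
        start = 0 if kv_start_idx is None else int(kv_start_idx[i])
        end = start + int(paged_kernel_lens[i])
        kv_indices[int(kv_indptr[i]) : int(kv_indptr[i + 1])] = row[start:end]
    return kv_indices


# Worked example: request 0 reads 3 slots from row 1, request 1 reads 2 from row 0.
req_to_token = torch.arange(20, dtype=torch.int32).reshape(2, 10)
req_pool_indices = torch.tensor([1, 0])
paged_kernel_lens = torch.tensor([3, 2])
kv_indptr = torch.zeros(3, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

print(create_kv_indices_reference(
    req_to_token, req_pool_indices, paged_kernel_lens, kv_indptr
))
# tensor([10, 11, 12,  0,  1], dtype=torch.int32)
```

Note how `kv_indptr` acts as a CSR row-pointer array: slice `i` of the output always spans `kv_indptr[i]:kv_indptr[i+1]`, which is what lets the Triton version assign one program id per request with no synchronization. The removal of this file alongside the +352 -83 change to flashinfer_backend.py suggests the logic was folded into the backend itself, though this summary does not show where each piece landed.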