sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +28 -10
 - sglang/bench_server_latency.py +21 -10
 - sglang/bench_serving.py +101 -7
 - sglang/global_config.py +0 -1
 - sglang/srt/layers/attention/__init__.py +27 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
 - sglang/srt/layers/attention/flashinfer_backend.py +352 -83
 - sglang/srt/layers/attention/triton_backend.py +6 -4
 - sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
 - sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
 - sglang/srt/layers/sampler.py +6 -2
 - sglang/srt/managers/detokenizer_manager.py +31 -10
 - sglang/srt/managers/io_struct.py +4 -0
 - sglang/srt/managers/schedule_batch.py +120 -43
 - sglang/srt/managers/schedule_policy.py +2 -1
 - sglang/srt/managers/scheduler.py +202 -140
 - sglang/srt/managers/tokenizer_manager.py +5 -1
 - sglang/srt/managers/tp_worker.py +111 -1
 - sglang/srt/mem_cache/chunk_cache.py +8 -4
 - sglang/srt/mem_cache/memory_pool.py +77 -4
 - sglang/srt/mem_cache/radix_cache.py +15 -7
 - sglang/srt/model_executor/cuda_graph_runner.py +4 -4
 - sglang/srt/model_executor/forward_batch_info.py +16 -21
 - sglang/srt/model_executor/model_runner.py +60 -1
 - sglang/srt/models/baichuan.py +2 -3
 - sglang/srt/models/chatglm.py +5 -6
 - sglang/srt/models/commandr.py +1 -2
 - sglang/srt/models/dbrx.py +1 -2
 - sglang/srt/models/deepseek.py +4 -5
 - sglang/srt/models/deepseek_v2.py +5 -6
 - sglang/srt/models/exaone.py +1 -2
 - sglang/srt/models/gemma.py +2 -2
 - sglang/srt/models/gemma2.py +5 -5
 - sglang/srt/models/gpt_bigcode.py +5 -5
 - sglang/srt/models/grok.py +1 -2
 - sglang/srt/models/internlm2.py +1 -2
 - sglang/srt/models/llama.py +1 -2
 - sglang/srt/models/llama_classification.py +1 -2
 - sglang/srt/models/llama_reward.py +2 -3
 - sglang/srt/models/llava.py +4 -8
 - sglang/srt/models/llavavid.py +1 -2
 - sglang/srt/models/minicpm.py +1 -2
 - sglang/srt/models/minicpm3.py +5 -6
 - sglang/srt/models/mixtral.py +1 -2
 - sglang/srt/models/mixtral_quant.py +1 -2
 - sglang/srt/models/olmo.py +352 -0
 - sglang/srt/models/olmoe.py +1 -2
 - sglang/srt/models/qwen.py +1 -2
 - sglang/srt/models/qwen2.py +1 -2
 - sglang/srt/models/qwen2_moe.py +4 -5
 - sglang/srt/models/stablelm.py +1 -2
 - sglang/srt/models/torch_native_llama.py +1 -2
 - sglang/srt/models/xverse.py +1 -2
 - sglang/srt/models/xverse_moe.py +4 -5
 - sglang/srt/models/yivl.py +1 -2
 - sglang/srt/openai_api/adapter.py +92 -49
 - sglang/srt/openai_api/protocol.py +10 -2
 - sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
 - sglang/srt/sampling/sampling_batch_info.py +92 -58
 - sglang/srt/sampling/sampling_params.py +2 -0
 - sglang/srt/server.py +116 -17
 - sglang/srt/server_args.py +121 -45
 - sglang/srt/utils.py +11 -3
 - sglang/test/few_shot_gsm8k.py +4 -1
 - sglang/test/few_shot_gsm8k_engine.py +144 -0
 - sglang/test/srt/sampling/penaltylib/utils.py +16 -12
 - sglang/version.py +1 -1
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
 - sglang/srt/layers/attention/flashinfer_utils.py +0 -237
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
 
sglang/srt/layers/attention/flashinfer_backend.py

@@ -7,18 +7,17 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+from enum import Enum, auto
 from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
+import triton
+import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import (
-    WrapperDispatch,
-    update_flashinfer_indices,
-)
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
@@ -34,13 +33,18 @@ if is_flashinfer_available():
     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
-        self.model_runner = model_runner
 
+        # Parse constants
         if not _grouped_size_compiled_for_decode_kernels(
             model_runner.model_config.num_attention_heads // model_runner.tp_size,
             model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
@@ -48,27 +52,43 @@ class FlashInferAttnBackend(AttentionBackend):
             self.decode_use_tensor_cores = True
         else:
             self.decode_use_tensor_cores = False
-
-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device="cuda",
-        )
+        self.max_context_len = model_runner.model_config.context_len
 
         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.has_cross_attention
         ), "Sliding window and cross attention are not supported together"
 
-        self.num_wrappers = 1
-        self.dispatch_reason = None
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
         elif model_runner.has_cross_attention:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
+        else:
+            self.num_wrappers = 1
+            self.dispatch_reason = None
+
+        # Allocate buffers
+        self.workspace_buffer = torch.empty(
+            global_config.flashinfer_workspace_size,
+            dtype=torch.uint8,
+            device=model_runner.device,
+        )
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+        self.qo_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]
 
+        # Create wrappers
         # NOTE: we do not use ragged attention when there are multiple wrappers
         self.prefill_wrapper_ragged = (
             BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
@@ -92,26 +112,23 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
             )
 
+        # Create indices updater
+        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
+        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
+            model_runner, self
+        )
+
+        # Other metadata
         self.forward_metadata = None
         self.cuda_graph_metadata = {}
 
-    def _get_wrapper_idx(self, layer: nn.Module):
-        if self.num_wrappers == 1:
-            return 0
-
-        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-            return layer.sliding_window_size == -1
-        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
-            return layer.is_cross_attention
-
-        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
-            prefix_lens = None
-            use_ragged = False
-            extend_no_prefix = False
-            total_num_tokens = None
+            self.indices_updater_decode.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+            )
+            self.forward_metadata = (self.decode_wrappers,)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
@@ -123,48 +140,32 @@ class FlashInferAttnBackend(AttentionBackend):
             ):
                 use_ragged = True
 
-            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
             extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
 
-        update_flashinfer_indices(
-            forward_batch.forward_mode,
-            self.model_runner,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            prefix_lens,
-            use_ragged=use_ragged,
-        )
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                prefix_lens,
+                use_ragged,
+            )
 
-        self.forward_metadata = (
-            use_ragged,
-            extend_no_prefix,
-            total_num_tokens,
-            self.decode_wrappers,
-        )
+            self.forward_metadata = (
+                use_ragged,
+                extend_no_prefix,
+            )
 
     def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.model_runner.model_config.context_len,),
+        cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.max_context_len,),
             dtype=torch.int32,
             device="cuda",
         )
-        self.cuda_graph_kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-
-        # NOTE: the buffers are always in the form of list
-        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
-            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
-        ]
-        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
-            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
+            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         decode_wrappers = []
         for i in range(self.num_wrappers):
@@ -174,35 +175,21 @@ class FlashInferAttnBackend(AttentionBackend):
                     "NHD",
                     use_cuda_graph=True,
                     use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
                 )
             )
 
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            decode_wrappers,
-        )
-
+        self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)
         self.cuda_graph_metadata[bs] = decode_wrappers
-
-        self.forward_metadata = (False, False, None, decode_wrappers)
+        self.forward_metadata = (decode_wrappers,)
 
     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            None,
-            self.cuda_graph_metadata[bs],
+        self.indices_updater_decode.update(
+            req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
         )
 
     def get_cuda_graph_seq_len_fill_value(self):
@@ -213,7 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self._get_wrapper_idx(layer)
         ]
 
-        use_ragged, extend_no_prefix, _, _ = self.forward_metadata
+        use_ragged, extend_no_prefix = self.forward_metadata
 
         if not use_ragged:
             if k is not None:
@@ -259,7 +246,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
+        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
 
         if k is not None:
             assert v is not None
         @@ -275,3 +262,285 @@ class FlashInferAttnBackend(AttentionBackend): 
     | 
|
| 
       275 
262 
     | 
    
         
             
                    )
         
     | 
| 
       276 
263 
     | 
    
         | 
| 
       277 
264 
     | 
    
         
             
                    return o.view(-1, layer.tp_q_head_num * layer.head_dim)
         
     | 
| 
      
 265 
     | 
    
         
            +
             
     | 
| 
      
 266 
     | 
    
         
            +
                def _get_wrapper_idx(self, layer: nn.Module):
         
     | 
| 
      
 267 
     | 
    
         
            +
                    if self.num_wrappers == 1:
         
     | 
| 
      
 268 
     | 
    
         
            +
                        return 0
         
     | 
| 
      
 269 
     | 
    
         
            +
             
     | 
| 
      
 270 
     | 
    
         
            +
                    if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
         
     | 
| 
      
 271 
     | 
    
         
            +
                        return layer.sliding_window_size == -1
         
     | 
| 
      
 272 
     | 
    
         
            +
                    if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
         
     | 
| 
      
 273 
     | 
    
         
            +
                        return layer.is_cross_attention
         
     | 
| 
      
 274 
     | 
    
         
            +
             
     | 
| 
      
 275 
     | 
    
         
            +
                    raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
         
     | 
| 
      
 276 
     | 
    
         
            +
             
     | 
| 
      
 277 
     | 
    
         
            +
             
     | 
| 
      
 278 
     | 
    
         
            +
            class FlashInferIndicesUpdaterDecode:
         
     | 
| 
      
 279 
     | 
    
         
            +
                def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         
     | 
| 
      
 280 
     | 
    
         
            +
                    # Constants
         
     | 
| 
      
 281 
     | 
    
         
            +
                    self.num_qo_heads = (
         
     | 
| 
      
 282 
     | 
    
         
            +
                        model_runner.model_config.num_attention_heads // model_runner.tp_size
         
     | 
| 
      
 283 
     | 
    
         
            +
                    )
         
     | 
| 
      
 284 
     | 
    
         
            +
                    self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
         
     | 
| 
      
 285 
     | 
    
         
            +
                        model_runner.tp_size
         
     | 
| 
      
 286 
     | 
    
         
            +
                    )
         
     | 
| 
      
 287 
     | 
    
         
            +
                    self.head_dim = model_runner.model_config.head_dim
         
     | 
| 
      
 288 
     | 
    
         
            +
                    self.data_type = model_runner.kv_cache_dtype
         
     | 
| 
      
 289 
     | 
    
         
            +
                    self.q_data_type = model_runner.dtype
         
     | 
| 
      
 290 
     | 
    
         
            +
                    self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         
     | 
| 
      
 291 
     | 
    
         
            +
                    self.sliding_window_size = model_runner.sliding_window_size
         
     | 
| 
      
 292 
     | 
    
         
            +
             
     | 
| 
      
 293 
     | 
    
         
            +
                    # Buffers and wrappers
         
     | 
| 
      
 294 
     | 
    
         
            +
                    self.kv_indptr = attn_backend.kv_indptr
         
     | 
| 
      
 295 
     | 
    
         
            +
                    self.kv_last_page_len = attn_backend.kv_last_page_len
         
     | 
| 
      
 296 
     | 
    
         
            +
                    self.req_to_token = model_runner.req_to_token_pool.req_to_token
         
     | 
| 
      
 297 
     | 
    
         
            +
                    self.decode_wrappers = attn_backend.decode_wrappers
         
     | 
| 
      
 298 
     | 
    
         
            +
             
     | 
| 
      
 299 
     | 
    
         
            +
                    # Dispatch
         
     | 
| 
      
 300 
     | 
    
         
            +
                    if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
         
     | 
| 
      
 301 
     | 
    
         
            +
                        self.update = self.update_sliding_window
         
     | 
| 
      
 302 
     | 
    
         
            +
                    elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
         
     | 
| 
      
 303 
     | 
    
         
            +
                        self.update = self.update_cross_attention
         
     | 
| 
      
 304 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 305 
     | 
    
         
            +
                        assert attn_backend.num_wrappers == 1
         
     | 
| 
      
 306 
     | 
    
         
            +
                        self.update = self.update_single_wrapper
         
     | 
| 
      
 307 
     | 
    
         
            +
             
     | 
| 
      
 308 
     | 
    
         
            +
                def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
         
     | 
| 
      
 309 
     | 
    
         
            +
                    decode_wrappers = decode_wrappers or self.decode_wrappers
         
     | 
| 
      
 310 
     | 
    
         
            +
                    self.call_begin_forward(
         
     | 
| 
      
 311 
     | 
    
         
            +
                        decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
         
     | 
| 
      
 312 
     | 
    
         
            +
                    )
         
     | 
| 
      
 313 
     | 
    
         
            +
             
     | 
| 
      
 314 
     | 
    
         
            +
                def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
         
     | 
| 
      
 315 
     | 
    
         
            +
                    decode_wrappers = decode_wrappers or self.decode_wrappers
         
     | 
| 
      
 316 
     | 
    
         
            +
             
     | 
| 
      
 317 
     | 
    
         
            +
                    for wrapper_id in range(2):
         
     | 
| 
      
 318 
     | 
    
         
            +
                        if wrapper_id == 0:
         
     | 
| 
      
 319 
     | 
    
         
            +
                            # Sliding window attention
         
     | 
| 
      
 320 
     | 
    
         
            +
                            paged_kernel_lens = torch.minimum(  # TODO: replace this with clamp
         
     | 
| 
      
 321 
     | 
    
         
            +
                                seq_lens,
         
     | 
| 
      
 322 
     | 
    
         
            +
                                torch.tensor(self.sliding_window_size + 1),
         
     | 
| 
      
 323 
     | 
    
         
            +
                            )
         
     | 
| 
      
 324 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 325 
     | 
    
         
            +
                            # Full attention
         
     | 
| 
      
 326 
     | 
    
         
            +
                            paged_kernel_lens = seq_lens
         
     | 
| 
      
 327 
     | 
    
         
            +
             
     | 
| 
      
 328 
     | 
    
         
            +
                        kv_start_idx = seq_lens - paged_kernel_lens
         
     | 
| 
      
 329 
     | 
    
         
            +
             
     | 
| 
      
 330 
     | 
    
         
            +
                        self.call_begin_forward(
         
     | 
| 
      
 331 
     | 
    
         
            +
                            decode_wrappers[wrapper_id],
         
     | 
| 
      
 332 
     | 
    
         
            +
                            req_pool_indices,
         
     | 
| 
      
 333 
     | 
    
         
            +
                            paged_kernel_lens,
         
     | 
| 
      
 334 
     | 
    
         
            +
                            self.kv_indptr[wrapper_id],
         
     | 
| 
      
 335 
     | 
    
         
            +
                            kv_start_idx,
         
     | 
| 
      
 336 
     | 
    
         
            +
                        )
         
     | 
| 
      
 337 
     | 
    
         
            +
             
     | 
| 
      
 338 
     | 
    
         
            +
                def update_cross_attention(self):
         
     | 
| 
      
 339 
     | 
    
         
            +
                    raise NotImplementedError()
         
     | 
| 
      
 340 
     | 
    
         
            +
             
     | 
| 
      
 341 
     | 
    
         
            +
                def call_begin_forward(
         
     | 
| 
      
 342 
     | 
    
         
            +
                    self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
         
     | 
| 
      
 343 
     | 
    
         
            +
                ):
         
     | 
| 
      
 344 
     | 
    
         
            +
                    bs = len(req_pool_indices)
         
     | 
| 
      
 345 
     | 
    
         
            +
                    kv_indptr = kv_indptr[: bs + 1]
         
     | 
| 
      
 346 
     | 
    
         
            +
                    # TODO: optimize the blocking call on kv_indptr[-1]
         
     | 
| 
      
 347 
     | 
    
         
            +
                    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         
     | 
| 
      
 348 
     | 
    
         
            +
                    kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         
     | 
| 
      
 349 
     | 
    
         
            +
             
     | 
| 
      
 350 
     | 
    
         
            +
                    create_flashinfer_kv_indices_triton[(bs,)](
         
     | 
| 
      
 351 
     | 
    
         
            +
                        self.req_to_token,
         
     | 
| 
      
 352 
     | 
    
         
            +
                        req_pool_indices,
         
     | 
| 
      
 353 
     | 
    
         
            +
                        paged_kernel_lens,
         
     | 
| 
      
 354 
     | 
    
         
            +
                        kv_indptr,
         
     | 
| 
      
 355 
     | 
    
         
            +
                        kv_start_idx,
         
     | 
| 
      
 356 
     | 
    
         
            +
                        kv_indices,
         
     | 
| 
      
 357 
     | 
    
         
            +
                        self.max_context_len,
         
     | 
| 
      
 358 
     | 
    
         
            +
                    )
         
     | 
| 
      
 359 
     | 
    
         
            +
             
     | 
| 
      
 360 
     | 
    
         
            +
                    wrapper.end_forward()
         
     | 
| 
      
 361 
     | 
    
         
            +
                    wrapper.begin_forward(
         
     | 
| 
      
 362 
     | 
    
         
            +
                        kv_indptr,
         
     | 
| 
      
 363 
     | 
    
         
            +
                        kv_indices,
         
     | 
| 
      
 364 
     | 
    
         
            +
                        self.kv_last_page_len[:bs],
         
     | 
| 
      
 365 
     | 
    
         
            +
                        self.num_qo_heads,
         
     | 
| 
      
 366 
     | 
    
         
            +
                        self.num_kv_heads,
         
     | 
| 
      
 367 
     | 
    
         
            +
                        self.head_dim,
         
     | 
| 
      
 368 
     | 
    
         
            +
                        1,
         
     | 
| 
      
 369 
     | 
    
         
            +
                        data_type=self.data_type,
         
     | 
| 
      
 370 
     | 
    
         
            +
                        q_data_type=self.q_data_type,
         
     | 
| 
      
 371 
     | 
    
         
            +
                    )
         
     | 
| 
      
 372 
     | 
    
         
            +
             
     | 
| 
      
 373 
     | 
    
         
            +
             
     | 
| 
      
 374 
     | 
    
         
            +
            class FlashInferIndicesUpdaterPrefill:
         
     | 
| 
      
 375 
     | 
    
         
            +
                def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         
     | 
| 
      
 376 
     | 
    
         
            +
                    # Constants
         
     | 
| 
      
 377 
     | 
    
         
            +
                    self.num_qo_heads = (
         
     | 
| 
      
 378 
     | 
    
         
            +
                        model_runner.model_config.num_attention_heads // model_runner.tp_size
         
     | 
| 
      
 379 
     | 
    
         
            +
                    )
         
     | 
| 
      
 380 
     | 
    
         
            +
                    self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
         
     | 
| 
      
 381 
     | 
    
         
            +
                        model_runner.tp_size
         
     | 
| 
      
 382 
     | 
    
         
            +
                    )
         
     | 
| 
      
 383 
     | 
    
         
            +
                    self.head_dim = model_runner.model_config.head_dim
         
     | 
| 
      
 384 
     | 
    
         
            +
                    self.data_type = model_runner.kv_cache_dtype
         
     | 
| 
      
 385 
     | 
    
         
            +
                    self.q_data_type = model_runner.dtype
         
     | 
| 
      
 386 
     | 
    
         
            +
                    self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         
     | 
| 
      
 387 
     | 
    
         
            +
                    self.sliding_window_size = model_runner.sliding_window_size
         
     | 
| 
      
 388 
     | 
    
         
            +
             
     | 
| 
      
 389 
     | 
    
         
            +
                    # Buffers and wrappers
         
     | 
| 
      
 390 
     | 
    
         
            +
                    self.kv_indptr = attn_backend.kv_indptr
         
     | 
| 
      
 391 
     | 
    
         
            +
                    self.kv_last_page_len = attn_backend.kv_last_page_len
         
     | 
| 
      
 392 
     | 
    
         
            +
                    self.qo_indptr = attn_backend.qo_indptr
         
     | 
| 
      
 393 
     | 
    
         
            +
                    self.req_to_token = model_runner.req_to_token_pool.req_to_token
         
     | 
| 
      
 394 
     | 
    
         
            +
                    self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
         
     | 
| 
      
 395 
     | 
    
         
            +
                    self.wrappers_paged = attn_backend.prefill_wrappers_paged
         
     | 
| 
      
 396 
     | 
    
         
            +
             
     | 
| 
      
 397 
     | 
    
         
            +
                    # Dispatch
         
     | 
| 
      
 398 
     | 
    
         
            +
                    if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
         
     | 
| 
      
 399 
     | 
    
         
            +
                        self.update = self.update_sliding_window
         
     | 
| 
      
 400 
     | 
    
         
            +
                    elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
         
     | 
| 
      
 401 
     | 
    
         
            +
                        self.update = self.update_cross_attention
         
     | 
| 
      
 402 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 403 
     | 
    
         
            +
                        assert attn_backend.num_wrappers == 1
         
     | 
| 
      
 404 
     | 
    
         
            +
                        self.update = self.update_single_wrapper
         
     | 
| 
      
 405 
     | 
    
         
            +
             
     | 
| 
      
 406 
     | 
    
         
            +
                def update_single_wrapper(
         
     | 
| 
      
 407 
     | 
    
         
            +
                    self, req_pool_indices, seq_lens, prefix_lens, use_ragged
         
     | 
| 
      
 408 
     | 
    
         
            +
                ):
         
     | 
| 
      
 409 
     | 
    
         
            +
                    if use_ragged:
         
     | 
| 
      
 410 
     | 
    
         
            +
                        paged_kernel_lens = prefix_lens
         
     | 
| 
      
 411 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 412 
     | 
    
         
            +
                        paged_kernel_lens = seq_lens
         
     | 
| 
      
 413 
     | 
    
         
            +
             
     | 
| 
      
 414 
     | 
    
         
            +
                    self.call_begin_forward(
         
     | 
| 
      
 415 
     | 
    
         
            +
                        self.wrapper_ragged,
         
     | 
| 
      
 416 
     | 
    
         
            +
                        self.wrappers_paged[0],
         
     | 
| 
      
 417 
     | 
    
         
            +
                        req_pool_indices,
         
     | 
| 
      
 418 
     | 
    
         
            +
                        paged_kernel_lens,
         
     | 
| 
      
 419 
     | 
    
         
            +
                        seq_lens,
         
     | 
| 
      
 420 
     | 
    
         
            +
                        prefix_lens,
         
     | 
| 
      
 421 
     | 
    
         
            +
                        None,
         
     | 
| 
      
 422 
     | 
    
         
            +
                        self.kv_indptr[0],
         
     | 
| 
      
 423 
     | 
    
         
            +
                        self.qo_indptr[0],
         
     | 
| 
      
 424 
     | 
    
         
            +
                        use_ragged,
         
     | 
| 
      
 425 
     | 
    
         
            +
                    )
         
     | 
| 
      
 426 
     | 
    
         
            +
             
     | 
| 
      
 427 
     | 
    
         
            +
                def update_sliding_window(
         
     | 
| 
      
 428 
     | 
    
         
            +
                    self, req_pool_indices, seq_lens, prefix_lens, use_ragged
         
     | 
| 
      
 429 
     | 
    
         
            +
                ):
         
     | 
| 
      
 430 
     | 
    
         
            +
                    for wrapper_id in range(2):
         
     | 
| 
      
 431 
     | 
    
         
            +
                        if wrapper_id == 0:
         
     | 
| 
      
 432 
     | 
    
         
            +
                            # window attention use paged only
         
     | 
| 
      
 433 
     | 
    
         
            +
                            paged_kernel_lens = torch.minimum(
         
     | 
| 
      
 434 
     | 
    
         
            +
                                seq_lens,
         
     | 
| 
      
 435 
     | 
    
         
            +
                                torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
         
     | 
| 
      
 436 
     | 
    
         
            +
                            )
         
     | 
| 
      
 437 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 438 
     | 
    
         
            +
                            # full attention
         
     | 
| 
      
 439 
     | 
    
         
            +
                            paged_kernel_lens = seq_lens
         
     | 
| 
      
 440 
     | 
    
         
            +
                        kv_start_idx = seq_lens - paged_kernel_lens
         
     | 
| 
      
 441 
     | 
    
         
            +
             
     | 
| 
      
 442 
     | 
    
         
            +
                        self.call_begin_forward(
         
     | 
| 
      
 443 
     | 
    
         
            +
                            self.wrapper_ragged,
         
     | 
| 
      
 444 
     | 
    
         
            +
                            self.wrappers_paged[wrapper_id],
         
     | 
| 
      
 445 
     | 
    
         
            +
                            req_pool_indices,
         
     | 
| 
      
 446 
     | 
    
         
            +
                            paged_kernel_lens,
         
     | 
| 
      
 447 
     | 
    
         
            +
                            seq_lens,
         
     | 
| 
      
 448 
     | 
    
         
            +
                            prefix_lens,
         
     | 
| 
      
 449 
     | 
    
         
            +
                            kv_start_idx,
         
     | 
| 
      
 450 
     | 
    
         
            +
                            self.kv_indptr[wrapper_id],
         
     | 
| 
      
 451 
     | 
    
         
            +
                            self.qo_indptr[wrapper_id],
         
     | 
| 
      
 452 
     | 
    
         
            +
                            use_ragged,
         
     | 
| 
      
 453 
     | 
    
         
            +
                        )
         
+
+    def update_cross_attention(self):
+        raise NotImplementedError()
+
+    def call_begin_forward(
+        self,
+        wrapper_ragged,
+        wrapper_paged,
+        req_pool_indices,
+        paged_kernel_lens,
+        seq_lens,
+        prefix_lens,
+        kv_start_idx,
+        kv_indptr,
+        qo_indptr,
+        use_ragged,
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            kv_indices,
+            self.max_context_len,
+        )
+
+        qo_indptr = qo_indptr[: bs + 1]
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        # extend part
+        if use_ragged:
+            wrapper_ragged.end_forward()
+            wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+            )
+
+        # cached part
+        wrapper_paged.end_forward()
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+        )
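`call_begin_forward` lays the batch out in CSR style: `kv_indptr` and `qo_indptr` start at 0, and entry `i + 1` is the cumulative length of requests `0..i`, so request `i` owns the slice `[indptr[i], indptr[i + 1])` of the flat index buffer. A self-contained sketch of the same layout; the lengths are made up for illustration:

```python
import torch

# Made-up per-request lengths, standing in for the batch metadata.
paged_kernel_lens = torch.tensor([5, 3, 7])  # KV entries per request
extend_lens = torch.tensor([2, 1, 4])        # seq_lens - prefix_lens

bs = len(paged_kernel_lens)
kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
qo_indptr = torch.zeros(bs + 1, dtype=torch.int64)
qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)

print(kv_indptr.tolist())  # [0, 5, 8, 15]
print(qo_indptr.tolist())  # [0, 2, 3, 7]
# Request i's KV indices live in kv_indices[kv_indptr[i]:kv_indptr[i + 1]].
```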
         
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    max_context_len: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
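The `create_flashinfer_kv_indices_triton` kernel launches one program per request; each program copies that request's token locations from its row of `req_to_token` into the contiguous `kv_indices` buffer, `BLOCK_SIZE` elements at a time with a bounds mask. A plain-PyTorch restatement of the same gather, intended as a reading aid rather than code from this release:

```python
import torch

def create_kv_indices_reference(
    req_to_token,       # [max_batch, max_context_len] token slots per request
    req_pool_indices,   # [bs] which row of req_to_token each request uses
    paged_kernel_lens,  # [bs] number of KV entries to copy per request
    kv_indptr,          # [bs + 1] CSR offsets into kv_indices
    kv_start_idx,       # [bs] per-request start offset, or None
    kv_indices,         # [kv_indptr[-1]] flat output buffer
):
    for i in range(len(req_pool_indices)):
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        end = start + int(paged_kernel_lens[i])
        row = req_to_token[int(req_pool_indices[i]), start:end]
        kv_indices[int(kv_indptr[i]) : int(kv_indptr[i + 1])] = row
```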
@@ -40,6 +40,8 @@ class TritonAttnBackend(AttentionBackend):
 
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
 
+        self.device = model_runner.device
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
@@ -51,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
-                device="cuda",
+                device=self.device,
             )
 
             max_seq_len = torch.max(forward_batch.seq_lens).item()
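This pair of hunks replaces a hardcoded `"cuda"` with the `self.device` captured from `model_runner` in `__init__` (the next hunk does the same for `cuda_graph_start_loc`), so the backend's scratch tensors follow whatever device the model runner was built on. The pattern, reduced to a hypothetical minimal class that is not part of this diff:

```python
import torch

class Scratch:
    def __init__(self, device: str):
        # Capture the device once at construction time.
        self.device = device

    def alloc(self, n: int) -> torch.Tensor:
        # Previously this would have hardcoded device="cuda".
        return torch.empty((n,), dtype=torch.int32, device=self.device)

print(Scratch("cpu").alloc(8).device)  # cpu
```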
@@ -67,7 +69,7 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
 
         self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
+            (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
             (
@@ -79,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
         )
 
     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.forward_metadata = (
             self.cuda_graph_start_loc,
@@ -89,7 +91,7 @@ class TritonAttnBackend(AttentionBackend):
         )
 
     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
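`cuda_graph_start_loc` is the exclusive prefix sum of `seq_lens`: request 0 starts at offset 0 and request `i` starts after the combined length of requests `0..i-1`, which is why the buffer is zeroed first and only `[1:bs]` is written. A worked example with made-up lengths:

```python
import torch

seq_lens = torch.tensor([3, 5, 2, 4], dtype=torch.int32)  # made-up batch
bs = len(seq_lens)

start_loc = torch.zeros(bs, dtype=torch.int32)
start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
print(start_loc.tolist())  # [0, 3, 8, 10]
```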