sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +2 -1
 - sglang/lang/chat_template.py +17 -0
 - sglang/launch_server_llavavid.py +1 -1
 - sglang/srt/configs/__init__.py +3 -0
 - sglang/srt/configs/model_config.py +27 -2
 - sglang/srt/configs/qwen2vl.py +133 -0
 - sglang/srt/constrained/fsm_cache.py +10 -3
 - sglang/srt/conversation.py +27 -0
 - sglang/srt/hf_transformers_utils.py +16 -1
 - sglang/srt/layers/attention/__init__.py +16 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
 - sglang/srt/layers/attention/flashinfer_backend.py +174 -54
 - sglang/srt/layers/attention/triton_backend.py +22 -6
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
 - sglang/srt/layers/linear.py +89 -63
 - sglang/srt/layers/logits_processor.py +5 -5
 - sglang/srt/layers/rotary_embedding.py +112 -0
 - sglang/srt/layers/sampler.py +51 -39
 - sglang/srt/lora/lora.py +3 -1
 - sglang/srt/managers/data_parallel_controller.py +1 -1
 - sglang/srt/managers/detokenizer_manager.py +4 -0
 - sglang/srt/managers/image_processor.py +186 -13
 - sglang/srt/managers/io_struct.py +10 -0
 - sglang/srt/managers/schedule_batch.py +238 -68
 - sglang/srt/managers/scheduler.py +69 -50
 - sglang/srt/managers/tokenizer_manager.py +24 -4
 - sglang/srt/managers/tp_worker.py +26 -111
 - sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
 - sglang/srt/mem_cache/memory_pool.py +56 -10
 - sglang/srt/mem_cache/radix_cache.py +4 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +87 -28
 - sglang/srt/model_executor/forward_batch_info.py +83 -3
 - sglang/srt/model_executor/model_runner.py +32 -11
 - sglang/srt/models/chatglm.py +3 -3
 - sglang/srt/models/deepseek_v2.py +2 -2
 - sglang/srt/models/mllama.py +1004 -0
 - sglang/srt/models/qwen2_vl.py +724 -0
 - sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
 - sglang/srt/sampling/sampling_batch_info.py +13 -3
 - sglang/srt/sampling/sampling_params.py +5 -7
 - sglang/srt/server.py +12 -0
 - sglang/srt/server_args.py +10 -0
 - sglang/srt/utils.py +22 -0
 - sglang/test/run_eval.py +2 -0
 - sglang/test/runners.py +20 -1
 - sglang/test/srt/sampling/penaltylib/utils.py +1 -0
 - sglang/test/test_utils.py +100 -3
 - sglang/version.py +1 -1
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
 
sglang/srt/layers/attention/flashinfer_backend.py

```diff
@@ -11,7 +11,6 @@ from enum import Enum, auto
 from typing import TYPE_CHECKING
 
 import torch
-import torch.nn as nn
 import triton
 import triton.language as tl
 
@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 if is_flashinfer_available():
@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
 
         assert not (
             model_runner.sliding_window_size is not None
-            and model_runner.
+            and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"
 
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
-        elif model_runner.
+        elif model_runner.model_config.is_encoder_decoder:
            self.num_wrappers = 2
            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
        else:
```
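The hunk above extends the existing two-wrapper mechanism (previously used only for sliding-window models) to encoder-decoder models such as the new mllama: one flashinfer wrapper per attention flavor, selected per layer. A minimal sketch of the routing idea, assuming a hypothetical `uses_sliding_window` flag for illustration (the real check lives in `_get_wrapper_idx` further down):

```python
from enum import Enum, auto

class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()

def wrapper_idx(dispatch_reason, layer) -> int:
    # Mirrors the wrapper_id convention in the update_* methods below:
    # wrapper 0 = sliding-window / normal attention, wrapper 1 = full / cross.
    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
        return 0 if layer.uses_sliding_window else 1  # hypothetical flag
    if dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
        return 1 if layer.is_cross_attention else 0
    return 0  # single-wrapper models
```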
```diff
@@ -127,6 +127,9 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                decode_wrappers=None,
+                encoder_lens=forward_batch.encoder_lens,
             )
             self.forward_metadata = (self.decode_wrappers,)
         else:
@@ -134,10 +137,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
             # Some heuristics to check whether to use ragged forward
             use_ragged = False
-            if (
-                torch.sum(forward_batch.seq_lens).item() >= 4096
-                and self.num_wrappers == 1
-            ):
+            if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
                 use_ragged = True
 
             extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
@@ -146,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 prefix_lens,
-                use_ragged,
+                use_ragged=use_ragged,
+                encoder_lens=forward_batch.encoder_lens,
             )
 
-            self.forward_metadata = (
-                use_ragged,
-                extend_no_prefix,
-            )
+            self.forward_metadata = (use_ragged, extend_no_prefix)
 
     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(
@@ -165,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
         ]
 
     def init_forward_metadata_capture_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: torch.Tensor = None,
     ):
         decode_wrappers = []
         for i in range(self.num_wrappers):
@@ -181,37 +183,59 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
             )
 
-
+        seq_lens_sum = seq_lens.sum().item()
+        self.indices_updater_decode.update(
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            decode_wrappers=decode_wrappers,
+            encoder_lens=encoder_lens,
+        )
         self.cuda_graph_metadata[bs] = decode_wrappers
         self.forward_metadata = (decode_wrappers,)
 
     def init_forward_metadata_replay_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: torch.Tensor = None,
     ):
         self.indices_updater_decode.update(
-            req_pool_indices[:bs],
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            seq_lens_sum,
+            decode_wrappers=self.cuda_graph_metadata[bs],
+            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
-    def forward_extend(
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
         ]
 
         use_ragged, extend_no_prefix = self.forward_metadata
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
 
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer.layer_id, forward_batch.out_cache_loc, k, v
-                )
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=
+                causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=layer.logit_cap,
@@ -239,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):
 
             o, _ = merge_state(o1, s1, o2, s2)
 
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
-    def forward_decode(
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
 
         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -263,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
-    def _get_wrapper_idx(self, layer:
+    def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
             return 0
 
@@ -290,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
 
+        self.attn_backend = attn_backend
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -297,55 +326,117 @@ class FlashInferIndicesUpdaterDecode:
         self.decode_wrappers = attn_backend.decode_wrappers
 
         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper
 
-    def
+    def update(
+        self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
+    ):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
+
+    def update_single_wrapper(
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers=None,
+        encoder_lens=None,
+    ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
-            decode_wrappers[0],
+            decode_wrappers[0],
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            self.kv_indptr[0],
+            None,
         )
 
-    def update_sliding_window(
+    def update_sliding_window(
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers=None,
+        encoder_lens=None,
+    ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
 
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
-                paged_kernel_lens = torch.minimum(
+                paged_kernel_lens_tmp = torch.minimum(  # TODO: replace this with clamp
                     seq_lens,
                     torch.tensor(self.sliding_window_size + 1),
                 )
+                paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
+                kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
             else:
                 # Full attention
-                paged_kernel_lens = seq_lens
+                paged_kernel_lens_tmp = seq_lens
+                paged_kernel_lens_sum_tmp = seq_lens_sum
+                kv_start_idx_tmp = None
 
-            kv_start_idx = seq_lens - paged_kernel_lens
+            self.call_begin_forward(
+                decode_wrappers[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens_tmp,
+                paged_kernel_lens_sum_tmp,
+                self.kv_indptr[wrapper_id],
+                kv_start_idx_tmp,
+            )
+
+    def update_cross_attention(
+        self,
+        req_pool_indices,
+        seq_lens,
+        seq_lens_sum,
+        decode_wrappers=None,
+        encoder_lens=None,
+    ):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # Normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # Cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+                seq_lens_sum = encoder_lens.sum().item()
 
             self.call_begin_forward(
                 decode_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
+                seq_lens_sum,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
             )
 
-    def update_cross_attention(self):
-        raise NotImplementedError()
-
     def call_begin_forward(
-        self,
+        self,
+        wrapper,
+        req_pool_indices,
+        paged_kernel_lens,
+        paged_kernel_lens_sum,
+        kv_indptr,
+        kv_start_idx,
     ):
         bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
 
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
```
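Two themes recur in these updater changes. First, `kv_indptr` is a CSR-style offset array: entries i and i+1 bracket request i's slice of the flat `kv_indices` array. Second, a plausible reading of the new `seq_lens_sum` / `paged_kernel_lens_sum` arguments (an inference from the diff, not a stated rationale) is that sizing `kv_indices` from a host-side Python int avoids reading `kv_indptr[-1]` back from a CUDA tensor, which would force a device-to-host sync. A runnable toy illustration:

```python
import torch

# Per-request KV lengths for a batch of 3 requests.
paged_kernel_lens = torch.tensor([3, 5, 2])
bs = len(paged_kernel_lens)

# CSR-style offsets: request i owns kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
print(kv_indptr.tolist())  # [0, 3, 8, 10]

# New style: size the flat index buffer from a host-side int ...
paged_kernel_lens_sum = int(paged_kernel_lens.sum())
kv_indices = torch.empty(paged_kernel_lens_sum, dtype=torch.int32)
# ... instead of from kv_indptr[-1], which on a CUDA tensor would sync.
assert kv_indices.numel() == int(kv_indptr[-1])
```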
```diff
@@ -386,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
 
+        self.attn_backend = attn_backend
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -395,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
         self.wrappers_paged = attn_backend.prefill_wrappers_paged
 
         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper
 
+    def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
+
     def update_single_wrapper(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -425,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
     def update_sliding_window(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -452,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
                 use_ragged,
             )
 
-    def update_cross_attention(
-
+    def update_cross_attention(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+    ):
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+
+            self.call_begin_forward(
+                self.wrapper_ragged,
+                self.wrappers_paged[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens,
+                prefix_lens,
+                kv_start_idx,
+                self.kv_indptr[wrapper_id],
+                self.qo_indptr[wrapper_id],
+                use_ragged,
+            )
 
     def call_begin_forward(
         self,
@@ -469,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
         use_ragged,
     ):
         bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -482,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
             self.max_context_len,
         )
 
+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
         qo_indptr = qo_indptr[: bs + 1]
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
         # extend part
         if use_ragged:
```

sglang/srt/layers/attention/triton_backend.py

```diff
@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
         )
 
     def init_forward_metadata_capture_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
             self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,
@@ -91,15 +97,23 @@ class TritonAttnBackend(AttentionBackend):
         )
 
     def init_forward_metadata_replay_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
-    def forward_extend(
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
```
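Both backends' capture/replay hooks now receive the batch tensors explicitly. The overall shape of the pattern, as a self-contained sketch (standard PyTorch CUDA-graph usage, not sglang's runner): buffers are allocated once, the computation is captured, and replay only refreshes preallocated buffers in place, which is why the metadata updates above write into existing `kv_indptr` and wrapper buffers rather than allocating new tensors.

```python
import torch

# Allocate static buffers once; a CUDA graph replays fixed memory addresses.
static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture (as the PyTorch docs recommend).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

static_in.copy_(torch.arange(8.0, device="cuda"))  # refresh inputs in place
g.replay()                                         # re-runs the captured work
print(static_out)  # tensor([ 0.,  2.,  4., ..., 14.])
```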
@@ -107,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
             o = torch.empty_like(q)
 
         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v
+            layer, forward_batch.out_cache_loc, k, v
         )
 
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
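set_kv_buffer now receives the RadixAttention layer itself instead of a bare layer id, so the pool can reach other per-layer state when writing the cache. A toy sketch of a pool with that calling convention (TinyKVPool and DummyLayer are hypothetical, not sglang's memory_pool API):

import torch

class TinyKVPool:
    # Hypothetical stand-in for sglang's token_to_kv_pool, showing only
    # the layer-object calling convention.
    def __init__(self, size, num_layers, num_heads, head_dim):
        self.k_buffer = [
            torch.zeros(size, num_heads, head_dim) for _ in range(num_layers)
        ]
        self.v_buffer = [
            torch.zeros(size, num_heads, head_dim) for _ in range(num_layers)
        ]

    def set_kv_buffer(self, layer, loc, cache_k, cache_v):
        # The layer carries its own id; other per-layer attributes would
        # also be reachable from here.
        self.k_buffer[layer.layer_id][loc] = cache_k
        self.v_buffer[layer.layer_id][loc] = cache_v

class DummyLayer:
    layer_id = 0

pool = TinyKVPool(size=16, num_layers=2, num_heads=4, head_dim=8)
loc = torch.tensor([3, 5])
pool.set_kv_buffer(DummyLayer(), loc, torch.ones(2, 4, 8), torch.ones(2, 4, 8))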
@@ -129,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
         )
         return o
 
-    def forward_decode(self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
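The reshape guard protects the decode path: if rotary_emb returns q as a 3D tensor under torch.compile, flattening to (-1, heads * head_dim) restores the 2D layout the kernel expects. A standalone illustration with made-up shapes:

import torch

tp_q_head_num, qk_head_dim = 8, 64

# Under torch.compile, rotary_emb can return q as
# [num_tokens, num_heads, head_dim] instead of the flat 2D layout.
q_3d = torch.randn(4, tp_q_head_num, qk_head_dim)

# reshape(-1, heads * head_dim) yields the expected 2D shape either way.
q = q_3d.reshape(-1, tp_q_head_num * qk_head_dim)
assert q.shape == (4, tp_q_head_num * qk_head_dim)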
@@ -143,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
 
         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v
+            layer, forward_batch.out_cache_loc, k, v
         )
 
         self.decode_attention_fwd(
@@ -50,6 +50,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
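Because IS_CAUSAL is a tl.constexpr, the causal branch is resolved when the kernel is compiled: Triton builds one specialized binary per flag value, so the branch costs nothing at run time. A minimal sketch of the same pattern with a hypothetical toy kernel (needs a CUDA device):

import torch
import triton
import triton.language as tl

@triton.jit
def _toy_kernel(X, Y, n, NEGATE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(X + offs, mask=mask)
    if NEGATE:  # compile-time branch: dead code is eliminated per variant
        x = -x
    tl.store(Y + offs, x, mask=mask)

x = torch.randn(1000, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
_toy_kernel[grid](x, y, x.numel(), NEGATE=True, BLOCK=256)
assert torch.allclose(y, -x)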
@@ -78,7 +79,9 @@ def _fwd_kernel(
     mask_d = offs_d < Lk
 
     q = tl.load(
-        Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]), other=0.0
+        Q + off_q,
+        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
+        other=0.0,
     )
 
     k_ptrs = K + off_k
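The q load masks both directions of the tile: rows past the sequence end and feature lanes past the true head dim Lk (the tile is padded up to BLOCK_DMODEL, a power of two) read back as 0.0 rather than out-of-bounds memory. A torch analogue of that boundary mask, with toy sizes:

import torch

BLOCK_M, BLOCK_DMODEL, Lk = 4, 4, 3   # head dim padded to a power of two
cur_batch_seq_len = 2

offs_m = torch.arange(BLOCK_M)        # query rows in this tile
offs_d = torch.arange(BLOCK_DMODEL)   # feature lanes (padded)
mask_d = offs_d < Lk

q_tile = torch.randn(BLOCK_M, BLOCK_DMODEL)
# torch analogue of tl.load(..., mask=..., other=0.0): out-of-range rows
# and padded lanes become zeros instead of reads past the buffer.
mask = (offs_m[:, None] < cur_batch_seq_len) & mask_d[None, :]
q = q_tile.masked_fill(~mask, 0.0)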
@@ -91,7 +94,12 @@ def _fwd_kernel(
 
     block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
 
-    for start_n in range(0, block_mask * cur_batch_seq_len, BLOCK_N):
+    end_n = (
+        cur_batch_seq_len
+        if not IS_CAUSAL
+        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
+    )
+    for start_n in range(0, block_mask * end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = tl.load(
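Under causal masking, the query rows of block start_m end at (start_m + 1) * BLOCK_M - 1, so no key column at or beyond (start_m + 1) * BLOCK_M can ever be visible; stopping the key loop at end_n therefore skips whole key blocks. A quick numeric check of the bound with hypothetical sizes:

# Checking the causal loop bound end_n with hypothetical sizes.
BLOCK_M = 128
cur_batch_seq_len = 1000

for start_m, expected in [(0, 128), (2, 384), (7, 1000)]:
    end_n = min((start_m + 1) * BLOCK_M, cur_batch_seq_len)
    # Rows of block start_m end at (start_m + 1) * BLOCK_M - 1, so no key
    # column >= (start_m + 1) * BLOCK_M is visible under causal masking.
    assert end_n == expected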
@@ -104,7 +112,18 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
-        qk += tl.where(start_n + offs_n[None, :] < cur_batch_seq_len, 0, float("-inf"))
+
+        if IS_CAUSAL:
+            qk += tl.where(
+                (start_n + offs_n[None, :] < cur_batch_seq_len)
+                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
+                0,
+                float("-inf"),
+            )
+        else:
+            qk += tl.where(
+                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
+            )
 
         # -- compute m_ij, p, l_ij
         m_ij = tl.max(qk, 1)
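Inside the loop, the causal branch adds -inf wherever a key column is out of range or ahead of its query row, while the non-causal branch masks only out-of-range columns. A torch reference of both masks, with illustrative block sizes:

import torch

BLOCK_M = BLOCK_N = 4
start_m, start_n, cur_batch_seq_len = 0, 0, 3
offs_m = start_m * BLOCK_M + torch.arange(BLOCK_M)   # query rows
cols = start_n + torch.arange(BLOCK_N)               # key columns

qk = torch.randn(BLOCK_M, BLOCK_N)
in_range = cols[None, :] < cur_batch_seq_len
causal = offs_m[:, None] >= cols[None, :]

qk_causal = qk.masked_fill(~(in_range & causal), float("-inf"))  # IS_CAUSAL
qk_full = qk.masked_fill(~in_range, float("-inf"))               # bidirectional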
@@ -146,7 +165,9 @@ def _fwd_kernel(
     )
 
 
-def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
+def context_attention_fwd(
+    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
+):
     if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
     else:
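Existing call sites keep causal prefill via the is_causal=True default, while encoder-style (bidirectional) prefill can now opt out. A sketch of both call styles against the new signature (shapes and dtypes are assumptions; needs a CUDA device):

import torch

from sglang.srt.layers.attention.triton_ops.prefill_attention import (
    context_attention_fwd,
)

# Two packed sequences of lengths 5 and 3 (sizes are illustrative).
heads, head_dim = 8, 64
q = torch.randn(8, heads, head_dim, device="cuda", dtype=torch.float16)
k, v, o = torch.randn_like(q), torch.randn_like(q), torch.empty_like(q)
b_start_loc = torch.tensor([0, 5], device="cuda", dtype=torch.int64)
b_seq_len = torch.tensor([5, 3], device="cuda", dtype=torch.int64)

# Causal prefill (default), as the decoder path uses it.
context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len=5)

# Bidirectional prefill, e.g. for encoder-style attention.
context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len=5, is_causal=False
)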
@@ -181,6 +202,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         BLOCK_M=BLOCK,
         BLOCK_DMODEL=triton.next_power_of_2(Lk),
         BLOCK_N=BLOCK,
+        IS_CAUSAL=is_causal,
         num_warps=num_warps,
         num_stages=1,
         Lk=Lk,