sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_latency.py +2 -1
 - sglang/lang/chat_template.py +17 -0
 - sglang/launch_server_llavavid.py +1 -1
 - sglang/srt/configs/__init__.py +3 -0
 - sglang/srt/configs/model_config.py +27 -2
 - sglang/srt/configs/qwen2vl.py +133 -0
 - sglang/srt/constrained/fsm_cache.py +10 -3
 - sglang/srt/conversation.py +27 -0
 - sglang/srt/hf_transformers_utils.py +16 -1
 - sglang/srt/layers/attention/__init__.py +16 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
 - sglang/srt/layers/attention/flashinfer_backend.py +174 -54
 - sglang/srt/layers/attention/triton_backend.py +22 -6
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
 - sglang/srt/layers/linear.py +89 -63
 - sglang/srt/layers/logits_processor.py +5 -5
 - sglang/srt/layers/rotary_embedding.py +112 -0
 - sglang/srt/layers/sampler.py +51 -39
 - sglang/srt/lora/lora.py +3 -1
 - sglang/srt/managers/data_parallel_controller.py +1 -1
 - sglang/srt/managers/detokenizer_manager.py +4 -0
 - sglang/srt/managers/image_processor.py +186 -13
 - sglang/srt/managers/io_struct.py +10 -0
 - sglang/srt/managers/schedule_batch.py +238 -68
 - sglang/srt/managers/scheduler.py +69 -50
 - sglang/srt/managers/tokenizer_manager.py +24 -4
 - sglang/srt/managers/tp_worker.py +26 -111
 - sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
 - sglang/srt/mem_cache/memory_pool.py +56 -10
 - sglang/srt/mem_cache/radix_cache.py +4 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +87 -28
 - sglang/srt/model_executor/forward_batch_info.py +83 -3
 - sglang/srt/model_executor/model_runner.py +32 -11
 - sglang/srt/models/chatglm.py +3 -3
 - sglang/srt/models/deepseek_v2.py +2 -2
 - sglang/srt/models/mllama.py +1004 -0
 - sglang/srt/models/qwen2_vl.py +724 -0
 - sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
 - sglang/srt/sampling/sampling_batch_info.py +13 -3
 - sglang/srt/sampling/sampling_params.py +5 -7
 - sglang/srt/server.py +12 -0
 - sglang/srt/server_args.py +10 -0
 - sglang/srt/utils.py +22 -0
 - sglang/test/run_eval.py +2 -0
 - sglang/test/runners.py +20 -1
 - sglang/test/srt/sampling/penaltylib/utils.py +1 -0
 - sglang/test/test_utils.py +100 -3
 - sglang/version.py +1 -1
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
 
    
sglang/bench_latency.py
CHANGED

@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
+        model_config=model_runner.model_config,
     )
-    batch.prepare_for_extend(
+    batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
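
In short, the model config now travels with the batch: it is supplied once when the batch is created, and prepare_for_extend() takes no arguments. A minimal sketch of the updated call pattern, assuming the enclosing constructor is ScheduleBatch.init_new (the hunk does not show the call target):

    from sglang.srt.managers.schedule_batch import ScheduleBatch

    # Sketch only; names outside this hunk are assumptions.
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        tree_cache=None,
        model_config=model_runner.model_config,  # new in 0.3.4.post2
    )
    batch.prepare_for_extend()  # no longer takes the model config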
    
sglang/lang/chat_template.py
CHANGED

@@ -133,6 +133,22 @@ register_chat_template(
     )
 )
 
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+register_chat_template(
+    ChatTemplate(
+        name="qwen2-vl",
+        default_system_prompt="You are a helpful assistant.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>"),
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
+
 
 register_chat_template(
     ChatTemplate(

@@ -213,6 +229,7 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )
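
With ChatTemplateStyle.PLAIN, the role prefixes and suffixes above are concatenated directly. A hand-assembled illustration of the prompt this template yields for a single image-bearing user turn (illustrative, not produced by calling sglang):

    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "<|vision_start|><|image_pad|><|vision_end|>"  # image_token, once per image
        "Describe this image.<|im_end|>\n"
        "<|im_start|>assistant\n"  # generation stops at <|im_end|>
    )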
    
sglang/launch_server_llavavid.py
CHANGED

@@ -14,7 +14,7 @@ if __name__ == "__main__":
     model_override_args["num_frames"] = 16
     model_override_args["model_type"] = "llavavid"
     if model_override_args["num_frames"] == 32:
-        model_override_args["rope_scaling"] = {"factor": 2.0, "
+        model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
         model_override_args["max_sequence_length"] = 4096 * 2
         model_override_args["tokenizer_model_max_length"] = 4096 * 2
         model_override_args["model_max_length"] = 4096 * 2
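
The substantive change is the rope-scaling key: the override now uses "rope_type", matching the schema that newer transformers releases validate. A hedged sketch of forwarding such overrides to the server; the --json-model-override-args flag is an assumption based on sglang's server arguments, not something shown in this hunk:

    import json

    model_override_args = {
        "num_frames": 16,
        "model_type": "llavavid",
        "rope_scaling": {"factor": 2.0, "rope_type": "linear"},
    }
    # Pass the overrides as one JSON string on the server command line.
    argv = ["--json-model-override-args", json.dumps(model_override_args)]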
    
sglang/srt/configs/model_config.py
CHANGED

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import logging
+import os
 from enum import IntEnum, auto
 from typing import Optional
 
@@ -20,6 +22,8 @@ from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
+logger = logging.getLogger(__name__)
+
 
 class AttentionArch(IntEnum):
     MLA = auto()

@@ -46,10 +50,29 @@ class ModelConfig:
             model_override_args=model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        derived_context_len = get_context_length(self.hf_text_config)
+        allow_long_context = os.environ.get(
+            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
+        )
+
         if context_length is not None:
-            self.context_len = context_length
+            if context_length > derived_context_len:
+                if allow_long_context:
+                    logger.warning(
+                        f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+                        f"This may lead to incorrect model outputs or CUDA errors."
+                    )
+                    self.context_len = context_length
+                else:
+                    raise ValueError(
+                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
+                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
+                    )
+            else:
+                self.context_len = context_length
         else:
-            self.context_len = get_context_length(self.hf_text_config)
+            self.context_len = derived_context_len
 
         # Unify the config keys for hf_text_config
         self.head_dim = getattr(

@@ -89,6 +112,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
+        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
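
Behaviorally, a user-specified context_length larger than the value derived from the HF config is now rejected unless an environment variable opts in. A sketch of the opt-in path; the ModelConfig arguments are inferred from the hunk and the model path is a placeholder:

    import os

    from sglang.srt.configs.model_config import ModelConfig

    # Without the env var, exceeding the derived context length raises
    # ValueError; with it set, the override is accepted and a warning logged.
    os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
    config = ModelConfig("org/some-model", context_length=262144)  # placeholder path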
sglang/srt/configs/qwen2vl.py
ADDED

@@ -0,0 +1,133 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen2VL model configuration"""
+
+import os
+from typing import Union
+
+from transformers import PretrainedConfig
+
+
+class Qwen2VLVisionConfig(PretrainedConfig):
+    model_type = "qwen2_vl"
+
+    def __init__(
+        self,
+        depth=32,
+        embed_dim=1280,
+        hidden_size=3584,
+        hidden_act="quick_gelu",
+        mlp_ratio=4,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs
+        )
+
+        if config_dict.get("model_type") == "qwen2_vl":
+            config_dict = config_dict["vision_config"]
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+class Qwen2VLConfig(PretrainedConfig):
+    model_type = "qwen2_vl"
+
+    def __init__(
+        self,
+        vocab_size=152064,
+        hidden_size=8192,
+        intermediate_size=29568,
+        num_hidden_layers=80,
+        num_attention_heads=64,
+        num_key_value_heads=8,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=1000000.0,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=80,
+        attention_dropout=0.0,
+        vision_config=None,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = Qwen2VLVisionConfig(**vision_config)
+        elif vision_config is None:
+            self.vision_config = Qwen2VLVisionConfig()
+
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.rope_scaling = rope_scaling
+
+        # NOTE: the following section from original transformers config
+        # for Qwen2-VL is commented out to address rope config loading issue
+        #
+        # if self.rope_scaling is not None and "type" in self.rope_scaling:
+        #     if self.rope_scaling["type"] == "mrope":
+        #         self.rope_scaling["type"] = "default"
+        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        # rope_config_validation(self)
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
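
Because the vendored config mirrors the upstream transformers one, it can be constructed directly; a minimal sketch based solely on the class above:

    from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig

    # Defaults fill in a Qwen2-VL-72B-scale text config and a default vision tower.
    config = Qwen2VLConfig()
    assert isinstance(config.vision_config, Qwen2VLVisionConfig)

    # A nested dict is promoted to Qwen2VLVisionConfig automatically.
    config = Qwen2VLConfig(vision_config={"depth": 32, "embed_dim": 1280})
    assert config.vision_config.embed_dim == 1280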
sglang/srt/constrained/fsm_cache.py
CHANGED

@@ -73,9 +73,16 @@ class FSMCache(BaseToolCache):
     def init_value(self, key):
         key_type, key_string = key
         if key_type == "json":
-            regex = build_regex_from_schema(
-                key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
-            )
+            try:
+                regex = build_regex_from_schema(
+                    key_string,
+                    whitespace_pattern=self.constrained_json_whitespace_pattern,
+                )
+            except NotImplementedError as e:
+                logger.warning(
+                    f"skip invalid json schema: json_schema={key_string}, {e=}"
+                )
+                return None, key_string
         elif key_type == "regex":
             regex = key_string
         else:
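
The net effect: a JSON schema that outlines cannot compile no longer crashes FSM construction; the entry is skipped with a warning and init_value returns (None, key_string). An illustration of the failure mode being handled; the import path varies across outlines versions, and the schema below is a made-up unsupported case:

    from outlines.fsm.json_schema import build_regex_from_schema

    schema = '{"type": "string", "format": "made-up-unsupported-format"}'
    try:
        regex = build_regex_from_schema(schema)
    except NotImplementedError:
        regex = None  # mirrors the cache's "skip invalid json schema" branch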
    
sglang/srt/conversation.py
CHANGED

@@ -509,6 +509,19 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="llama_3_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="llava_llama_3",

@@ -530,3 +543,17 @@ register_conv_template(
         stop_str=["<|im_end|>", "<|action_end|>"],
     )
 )
+
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+register_conv_template(
+    Conversation(
+        name="qwen2-vl",
+        system_message="You are a helpful assistant.",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=["<|im_end|>"],
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
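
For orientation, a prompt built from the llama_3_vision template above comes out roughly like this (hand-assembled; the exact assembly lives in Conversation's LLAMA3 separator handling, so treat the layout as approximate):

    prompt = (
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful language and vision assistant. You are able to "
        "understand the visual content that the user provides, and assist the "
        "user with a variety of tasks using natural language.<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "<|image|>What is shown here?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )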
sglang/srt/hf_transformers_utils.py
CHANGED

@@ -33,12 +33,13 @@ from transformers import (
 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
 
-    from sglang.srt.configs import ExaoneConfig
+    from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig
 
     _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
         ChatGLMConfig.model_type: ChatGLMConfig,
         DbrxConfig.model_type: DbrxConfig,
         ExaoneConfig.model_type: ExaoneConfig,
+        Qwen2VLConfig.model_type: Qwen2VLConfig,
     }
 except ImportError:
     # We want this file to run without vllm dependency

@@ -162,6 +163,8 @@ def get_tokenizer(
             "Using a slow tokenizer. This might cause a significant "
             "slowdown. Consider using a fast tokenizer instead."
         )
+
+    attach_additional_stop_token_ids(tokenizer)
     return tokenizer
 
 

@@ -180,4 +183,16 @@ def get_processor(
         tokenizer_revision=tokenizer_revision,
         **kwargs,
     )
+
+    attach_additional_stop_token_ids(processor.tokenizer)
     return processor
+
+
+def attach_additional_stop_token_ids(tokenizer):
+    # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
+    if "<|eom_id|>" in tokenizer.get_added_vocab():
+        tokenizer.additional_stop_token_ids = set(
+            [tokenizer.get_added_vocab()["<|eom_id|>"]]
+        )
+    else:
+        tokenizer.additional_stop_token_ids = None
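
A small illustration of the new hook in isolation, using the public transformers API; the model id is only an example of a tokenizer whose added vocab defines <|eom_id|>:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    added = tokenizer.get_added_vocab()
    if "<|eom_id|>" in added:
        tokenizer.additional_stop_token_ids = {added["<|eom_id|>"]}  # llama 3 tool use
    else:
        tokenizer.additional_stop_token_ids = None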
sglang/srt/layers/attention/__init__.py
CHANGED

@@ -1,8 +1,10 @@
 from abc import ABC, abstractmethod
+from typing import Optional
 
 import torch
 from torch import nn
 
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 

@@ -19,13 +21,22 @@ class AttentionBackend(ABC):
         raise NotImplementedError()
 
     def init_forward_metadata_capture_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
 
     def init_forward_metadata_replay_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()

@@ -39,7 +50,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer:
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run forward on an attention layer."""

@@ -53,7 +64,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer:
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run a forward for decode."""

@@ -64,7 +75,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer:
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run a forward for extend."""
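
A skeleton of what a concrete backend now has to override, taken directly from the updated signatures (bodies elided):

    from typing import Optional

    import torch

    from sglang.srt.layers.attention import AttentionBackend
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


    class MyBackend(AttentionBackend):
        def init_forward_metadata_capture_cuda_graph(
            self,
            bs: int,
            req_pool_indices: torch.Tensor,
            seq_lens: torch.Tensor,
            encoder_lens: Optional[torch.Tensor] = None,
        ):
            ...  # build metadata once, at graph-capture time

        def init_forward_metadata_replay_cuda_graph(
            self,
            bs: int,
            req_pool_indices: torch.Tensor,
            seq_lens: torch.Tensor,
            seq_lens_sum: int,
            encoder_lens: Optional[torch.Tensor] = None,
        ):
            ...  # refresh metadata in place before each graph replay

        def forward_decode(
            self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
        ):
            ...  # one decode step over the KV cache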
sglang/srt/layers/attention/double_sparsity_backend.py
CHANGED

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 

@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )
 
     def init_forward_metadata_capture_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
             self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,

@@ -144,15 +150,23 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )
 
     def init_forward_metadata_replay_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
-    def forward_extend(
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))

@@ -168,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )
 
         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer
+            layer, forward_batch.out_cache_loc, k, v, k_label
         )
 
         (

@@ -197,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )
         return o
 
-    def forward_decode(
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

@@ -227,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )
 
         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer
+            layer, forward_batch.out_cache_loc, k, v, k_label
         )
 
         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
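
Both call sites now hand the pool everything in one call: the attention layer, the cache locations, K, V, and the double-sparsity channel-label tensor. A hypothetical mirror of the receiving end, inferred only from these call sites (buffer names are assumptions, not the actual memory_pool.py code):

    def set_kv_buffer(self, layer, loc, cache_k, cache_v, cache_label):
        layer_id = layer.layer_id
        self.k_buffer[layer_id][loc] = cache_k          # keys at the new slots
        self.v_buffer[layer_id][loc] = cache_v          # values at the new slots
        self.label_buffer[layer_id][loc] = cache_label  # per-token channel labels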